Halide 22.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector given as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 expr_match(x + x, 3 + (2*k), result)
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 expr_match(x*(x + y), a*(a + b), result)
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
// Exclusive upper bound on wildcard indices: the wildcard matchers in this
// file static_assert that their index i satisfies 0 <= i < max_wild.
74constexpr int max_wild = 6;
75
// Type recorded when a wildcard constant is bound from a bare int64_t value
// rather than from an IR node. Judging by the name and initializer this is a
// 64-bit signed scalar; the field order of halide_type_t is not shown here.
// NOTE(review): namespace-scope `static` in a header gives every translation
// unit its own copy of this object.
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
144template<typename T>
145struct bindings {
146 constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
147};
148
150 const uint16_t flags = ty.lanes & MatcherState::special_values_mask;
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(std::move(e), lanes);
187 }
188 return e;
189}
190
191// A pattern that matches a specific expression
193 struct pattern_tag {};
194
195 constexpr static uint32_t binds = 0;
196
197 // What is the weakest and strongest IR node this could possibly be
200 constexpr static bool canonical = true;
201
203
204 template<uint32_t bound>
205 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
206 return equal(expr, e);
207 }
208
211 return Expr(&expr);
212 }
213
214 constexpr static bool foldable = false;
215};
216
// Stream a SpecificExpr pattern by printing an Expr constructed from the
// address of the node it refers to. Returns the stream so calls can be chained.
217inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
218 s << Expr(&e.expr);
219 return s;
220}
221
222template<int i>
224 struct pattern_tag {};
225
226 constexpr static uint32_t binds = 1 << i;
227
230 constexpr static bool canonical = true;
231
232 template<uint32_t bound>
233 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
234 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
235 const BaseExprNode *op = &e;
236 if (op->node_type == IRNodeType::Broadcast) {
237 op = ((const Broadcast *)op)->value.get();
238 }
239 if (op->node_type != IRNodeType::IntImm) {
240 return false;
241 }
242 int64_t value = ((const IntImm *)op)->value;
243 if (bound & binds) {
245 halide_type_t type;
246 state.get_bound_const(i, val, type);
247 return (halide_type_t)e.type == type && value == val.u.i64;
248 }
249 state.set_bound_const(i, value, e.type);
250 return true;
251 }
252
253 template<uint32_t bound>
254 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
255 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
256 if (bound & binds) {
258 halide_type_t type;
259 state.get_bound_const(i, val, type);
260 return type == i64_type && value == val.u.i64;
261 }
262 state.set_bound_const(i, value, i64_type);
263 return true;
264 }
265
269 halide_type_t type;
270 state.get_bound_const(i, val, type);
271 return make_const_expr(val, type);
272 }
273
274 constexpr static bool foldable = true;
275
278 state.get_bound_const(i, val, ty);
279 }
280};
281
// Stream a WildConstInt<i> pattern as the text "ci" followed by its index i.
// The pattern object itself is not inspected. Returns the stream.
282template<int i>
283std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
284 s << "ci" << i;
285 return s;
286}
287
288template<int i>
290 struct pattern_tag {};
291
292 constexpr static uint32_t binds = 1 << i;
293
296 constexpr static bool canonical = true;
297
298 template<uint32_t bound>
299 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
300 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
301 const BaseExprNode *op = &e;
302 if (op->node_type == IRNodeType::Broadcast) {
303 op = ((const Broadcast *)op)->value.get();
304 }
305 if (op->node_type != IRNodeType::UIntImm) {
306 return false;
307 }
308 uint64_t value = ((const UIntImm *)op)->value;
309 if (bound & binds) {
311 halide_type_t type;
312 state.get_bound_const(i, val, type);
313 return (halide_type_t)e.type == type && value == val.u.u64;
314 }
315 state.set_bound_const(i, value, e.type);
316 return true;
317 }
318
322 halide_type_t type;
323 state.get_bound_const(i, val, type);
324 return make_const_expr(val, type);
325 }
326
327 constexpr static bool foldable = true;
328
331 state.get_bound_const(i, val, ty);
332 }
333};
334
// Stream a WildConstUInt<i> pattern as the text "cu" followed by its index i.
// The pattern object itself is not inspected. Returns the stream.
335template<int i>
336std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
337 s << "cu" << i;
338 return s;
339}
340
341template<int i>
343 struct pattern_tag {};
344
345 constexpr static uint32_t binds = 1 << i;
346
349 constexpr static bool canonical = true;
350
351 template<uint32_t bound>
352 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
353 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
354 const BaseExprNode *op = &e;
355 if (op->node_type == IRNodeType::Broadcast) {
356 op = ((const Broadcast *)op)->value.get();
357 }
358 if (op->node_type != IRNodeType::FloatImm) {
359 return false;
360 }
361 double value = ((const FloatImm *)op)->value;
362 if (bound & binds) {
364 halide_type_t type;
365 state.get_bound_const(i, val, type);
366 return (halide_type_t)e.type == type && value == val.u.f64;
367 }
368 state.set_bound_const(i, value, e.type);
369 return true;
370 }
371
375 halide_type_t type;
376 state.get_bound_const(i, val, type);
377 return make_const_expr(val, type);
378 }
379
380 constexpr static bool foldable = true;
381
384 state.get_bound_const(i, val, ty);
385 }
386};
387
// Stream a WildConstFloat<i> pattern as the text "cf" followed by its index i.
// The pattern object itself is not inspected. Returns the stream.
388template<int i>
389std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
390 s << "cf" << i;
391 return s;
392}
393
394// Matches and binds to any constant Expr.
395template<int i>
396struct WildConst {
397 struct pattern_tag {};
398
399 constexpr static uint32_t binds = 1 << i;
400
403 constexpr static bool canonical = true;
404
405 template<uint32_t bound>
406 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
407 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
408 const BaseExprNode *op = &e;
409 if (op->node_type == IRNodeType::Broadcast) {
410 op = ((const Broadcast *)op)->value.get();
411 }
412 switch (op->node_type) {
414 return WildConstInt<i>().template match<bound>(e, state);
416 return WildConstUInt<i>().template match<bound>(e, state);
418 return WildConstFloat<i>().template match<bound>(e, state);
419 default:
420 return false;
421 }
422 }
423
424 template<uint32_t bound>
425 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
426 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
427 return WildConstInt<i>().template match<bound>(e, state);
428 }
429
433 halide_type_t type;
434 state.get_bound_const(i, val, type);
435 return make_const_expr(val, type);
436 }
437
438 constexpr static bool foldable = true;
439
442 state.get_bound_const(i, val, ty);
443 }
444};
445
// Stream a WildConst<i> pattern as the text "c" followed by its index i.
// The pattern object itself is not inspected. Returns the stream.
446template<int i>
447std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
448 s << "c" << i;
449 return s;
450}
451
452// Matches and binds to any Expr
453template<int i>
454struct Wild {
455 struct pattern_tag {};
456
457 constexpr static uint32_t binds = 1 << (i + 16);
458
461 constexpr static bool canonical = true;
462
463 template<uint32_t bound>
464 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
465 if (bound & binds) {
466 return equal(*state.get_binding(i), e);
467 }
468 state.set_binding(i, e);
469 return true;
470 }
471
474 return state.get_binding(i);
475 }
476
477 constexpr static bool foldable = false;
478};
479
// Stream a Wild<i> pattern as an underscore followed by its index i
// (e.g. "_0"). The pattern object itself is not inspected. Returns the stream.
480template<int i>
481std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
482 s << "_" << i;
483 return s;
484}
485
486// Matches a specific constant or broadcast of that constant. The
487// constant must be representable as an int64_t.
489 struct pattern_tag {};
491
492 constexpr static uint32_t binds = 0;
493
496 constexpr static bool canonical = true;
497
500 : v(v) {
501 }
502
503 template<uint32_t bound>
504 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
505 const BaseExprNode *op = &e;
506 if (e.node_type == IRNodeType::Broadcast) {
507 op = ((const Broadcast *)op)->value.get();
508 }
509 switch (op->node_type) {
511 return ((const IntImm *)op)->value == (int64_t)v;
513 return ((const UIntImm *)op)->value == (uint64_t)v;
515 return ((const FloatImm *)op)->value == (double)v;
516 default:
517 return false;
518 }
519 }
520
521 template<uint32_t bound>
522 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
523 return v == val;
524 }
525
526 template<uint32_t bound>
527 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
528 return v == b.v;
529 }
530
533 return make_const(type_hint, v);
534 }
535
536 constexpr static bool foldable = true;
537
540 // Assume type is already correct
541 switch (ty.code) {
542 case halide_type_int:
543 val.u.i64 = v;
544 break;
545 case halide_type_uint:
546 val.u.u64 = (uint64_t)v;
547 break;
550 val.u.f64 = (double)v;
551 break;
552 default:
553 // Unreachable
554 ;
555 }
556 }
557};
558
562
563// Convert a provided pattern, expr, or constant int into the internal
564// representation we use in the matcher trees.
565template<typename T,
566 typename = typename std::decay<T>::type::pattern_tag>
568 return t;
569}
572 return IntLiteral{x};
573}
574
575template<typename T>
577 static_assert(!std::is_same_v<std::decay_t<T>, Expr> || std::is_lvalue_reference_v<T>,
578 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
579}
580
582 return {*e.get()};
583}
584
585// Helpers to deref SpecificExprs to const BaseExprNode & rather than
586// passing them by value anywhere (incurring lots of refcounting)
587template<typename T,
588 // T must be a pattern node
589 typename = typename std::decay_t<T>::pattern_tag,
590 // But T may not be SpecificExpr
591 typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, SpecificExpr>>>
593 return t;
594}
595
598 return e.expr;
599}
600
// Stream an IntLiteral pattern by printing its stored value v.
// Returns the stream so calls can be chained.
601inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
602 s << op.v;
603 return s;
604}
605
606template<typename Op>
608
609template<typename Op>
611
612template<typename Op>
613double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
614
// Compile-time predicate: true exactly when t is one of Add, Mul, And, Or,
// Min, Max, EQ or NE, and false for every other IRNodeType. BinOp and CmpOp
// consult it when computing their `canonical` flag.
615constexpr bool commutative(IRNodeType t) {
616 return (t == IRNodeType::Add ||
617 t == IRNodeType::Mul ||
618 t == IRNodeType::And ||
619 t == IRNodeType::Or ||
620 t == IRNodeType::Min ||
621 t == IRNodeType::Max ||
622 t == IRNodeType::EQ ||
623 t == IRNodeType::NE);
624}
625
626// Matches one of the binary operators
627template<typename Op, typename A, typename B>
628struct BinOp {
629 struct pattern_tag {};
632
634
635 constexpr static IRNodeType min_node_type = Op::_node_type;
636 constexpr static IRNodeType max_node_type = Op::_node_type;
637
638 // For commutative bin ops, we expect the weaker IR node type on
639 // the right. That is, for the rule to be canonical it must be
640 // possible that A is at least as strong as B.
641 constexpr static bool canonical =
642 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
643
644 template<uint32_t bound>
645 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
646 if (e.node_type != Op::_node_type) {
647 return false;
648 }
649 const Op &op = (const Op &)e;
650 return (a.template match<bound>(*op.a.get(), state) &&
651 b.template match<(bound | bindings<A>::mask)>(*op.b.get(), state));
652 }
653
654 template<uint32_t bound, typename Op2, typename A2, typename B2>
655 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
656 return (std::is_same_v<Op, Op2> &&
657 a.template match<bound>(unwrap(op.a), state) &&
658 b.template match<(bound | bindings<A>::mask)>(unwrap(op.b), state));
659 }
660
661 constexpr static bool foldable = A::foldable && B::foldable;
662
666 if (std::is_same_v<A, IntLiteral>) {
667 b.make_folded_const(val_b, ty, state);
668 if ((std::is_same_v<Op, And> && val_b.u.u64 == 0) ||
669 (std::is_same_v<Op, Or> && val_b.u.u64 == 1)) {
670 // Short circuit
671 val = val_b;
672 return;
673 }
674 const uint16_t l = ty.lanes;
675 a.make_folded_const(val_a, ty, state);
676 ty.lanes |= l; // Make sure the overflow bits are sticky
677 } else {
678 a.make_folded_const(val_a, ty, state);
679 if ((std::is_same_v<Op, And> && val_a.u.u64 == 0) ||
680 (std::is_same_v<Op, Or> && val_a.u.u64 == 1)) {
681 // Short circuit
682 val = val_a;
683 return;
684 }
685 const uint16_t l = ty.lanes;
686 b.make_folded_const(val_b, ty, state);
687 ty.lanes |= l;
688 }
689 switch (ty.code) {
690 case halide_type_int:
691 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
692 break;
693 case halide_type_uint:
694 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
695 break;
698 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
699 break;
700 default:
701 // unreachable
702 ;
703 }
704 }
705
707 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
708 Expr ea, eb;
709 if (std::is_same_v<A, IntLiteral>) {
710 eb = b.make(state, type_hint);
711 ea = a.make(state, eb.type());
712 } else {
713 ea = a.make(state, type_hint);
714 eb = b.make(state, ea.type());
715 }
716 return Op::make(std::move(ea), std::move(eb));
717 }
718};
719
720template<typename Op>
722
723template<typename Op>
725
726template<typename Op>
727uint64_t constant_fold_cmp_op(double, double) noexcept;
728
729// Matches one of the comparison operators
730template<typename Op, typename A, typename B>
731struct CmpOp {
732 struct pattern_tag {};
735
737
738 constexpr static IRNodeType min_node_type = Op::_node_type;
739 constexpr static IRNodeType max_node_type = Op::_node_type;
740 constexpr static bool canonical = (A::canonical &&
741 B::canonical &&
742 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
743 (Op::_node_type != IRNodeType::GE) &&
744 (Op::_node_type != IRNodeType::GT));
745
746 template<uint32_t bound>
747 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
748 if (e.node_type != Op::_node_type) {
749 return false;
750 }
751 const Op &op = (const Op &)e;
752 return (a.template match<bound>(*op.a.get(), state) &&
753 b.template match<(bound | bindings<A>::mask)>(*op.b.get(), state));
754 }
755
756 template<uint32_t bound, typename Op2, typename A2, typename B2>
757 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
758 return (std::is_same_v<Op, Op2> &&
759 a.template match<bound>(unwrap(op.a), state) &&
760 b.template match<(bound | bindings<A>::mask)>(unwrap(op.b), state));
761 }
762
763 constexpr static bool foldable = A::foldable && B::foldable;
764
768 // If one side is an untyped const, evaluate the other side first to get a type hint.
769 if (std::is_same_v<A, IntLiteral>) {
770 b.make_folded_const(val_b, ty, state);
771 const uint16_t l = ty.lanes;
772 a.make_folded_const(val_a, ty, state);
773 ty.lanes |= l;
774 } else {
775 a.make_folded_const(val_a, ty, state);
776 const uint16_t l = ty.lanes;
777 b.make_folded_const(val_b, ty, state);
778 ty.lanes |= l;
779 }
780
781 switch (ty.code) {
782 case halide_type_int:
783 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
784 break;
785 case halide_type_uint:
786 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
787 break;
790 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
791 break;
792 default:
793 // unreachable
794 ;
795 }
796 ty.code = halide_type_uint;
797 ty.bits = 1;
798 }
799
802 // If one side is an untyped const, evaluate the other side first to get a type hint.
803 Expr ea, eb;
804 if (std::is_same_v<A, IntLiteral>) {
805 eb = b.make(state, {});
806 ea = a.make(state, eb.type());
807 } else {
808 ea = a.make(state, {});
809 eb = b.make(state, ea.type());
810 }
811 return Op::make(std::move(ea), std::move(eb));
812 }
813};
814
// Stream an Add pattern as "(a + b)", printing each operand with its own
// operator<<. Returns the stream.
815template<typename A, typename B>
816std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
817 s << "(" << op.a << " + " << op.b << ")";
818 return s;
819}
820
// Stream a Sub pattern as "(a - b)", printing each operand with its own
// operator<<. Returns the stream.
821template<typename A, typename B>
822std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
823 s << "(" << op.a << " - " << op.b << ")";
824 return s;
825}
826
// Stream a Mul pattern as "(a * b)", printing each operand with its own
// operator<<. Returns the stream.
827template<typename A, typename B>
828std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
829 s << "(" << op.a << " * " << op.b << ")";
830 return s;
831}
832
// Stream a Div pattern as "(a / b)", printing each operand with its own
// operator<<. Returns the stream.
833template<typename A, typename B>
834std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
835 s << "(" << op.a << " / " << op.b << ")";
836 return s;
837}
838
// Stream an And pattern as "(a && b)", printing each operand with its own
// operator<<. Returns the stream.
839template<typename A, typename B>
840std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
841 s << "(" << op.a << " && " << op.b << ")";
842 return s;
843}
844
// Stream an Or pattern as "(a || b)", printing each operand with its own
// operator<<. Returns the stream.
845template<typename A, typename B>
846std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
847 s << "(" << op.a << " || " << op.b << ")";
848 return s;
849}
850
// Stream a Min pattern in call form, "min(a, b)", printing each operand with
// its own operator<<. Returns the stream.
851template<typename A, typename B>
852std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
853 s << "min(" << op.a << ", " << op.b << ")";
854 return s;
855}
856
// Stream a Max pattern in call form, "max(a, b)", printing each operand with
// its own operator<<. Returns the stream.
857template<typename A, typename B>
858std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
859 s << "max(" << op.a << ", " << op.b << ")";
860 return s;
861}
862
// Stream an LE comparison pattern as "(a <= b)", printing each operand with
// its own operator<<. Returns the stream.
863template<typename A, typename B>
864std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
865 s << "(" << op.a << " <= " << op.b << ")";
866 return s;
867}
868
// Stream an LT comparison pattern as "(a < b)", printing each operand with
// its own operator<<. Returns the stream.
869template<typename A, typename B>
870std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
871 s << "(" << op.a << " < " << op.b << ")";
872 return s;
873}
874
// Stream a GE comparison pattern as "(a >= b)", printing each operand with
// its own operator<<. Returns the stream.
875template<typename A, typename B>
876std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
877 s << "(" << op.a << " >= " << op.b << ")";
878 return s;
879}
880
// Stream a GT comparison pattern as "(a > b)", printing each operand with
// its own operator<<. Returns the stream.
881template<typename A, typename B>
882std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
883 s << "(" << op.a << " > " << op.b << ")";
884 return s;
885}
886
// Stream an EQ comparison pattern as "(a == b)", printing each operand with
// its own operator<<. Returns the stream.
887template<typename A, typename B>
888std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
889 s << "(" << op.a << " == " << op.b << ")";
890 return s;
891}
892
// Stream an NE comparison pattern as "(a != b)", printing each operand with
// its own operator<<. Returns the stream.
893template<typename A, typename B>
894std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
895 s << "(" << op.a << " != " << op.b << ")";
896 return s;
897}
898
// Stream a Mod pattern as "(a % b)", printing each operand with its own
// operator<<. Returns the stream.
899template<typename A, typename B>
900std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
901 s << "(" << op.a << " % " << op.b << ")";
902 return s;
903}
904
905template<typename A, typename B>
906HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
909 return {pattern_arg(a), pattern_arg(b)};
910}
911
912template<typename A, typename B>
918
919template<>
921 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
922 int dead_bits = 64 - t.bits;
923 // Drop the high bits then sign-extend them back
924 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
925}
926
927template<>
929 uint64_t ones = (uint64_t)(-1);
930 return (a + b) & (ones >> (64 - t.bits));
931}
932
// Floating-point constant fold for Add: returns a + b. The type argument t is
// neither read nor modified by this specialization.
933template<>
934HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
935 return a + b;
936}
937
938template<typename A, typename B>
939HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
942 return {pattern_arg(a), pattern_arg(b)};
943}
944
945template<typename A, typename B>
951
952template<>
954 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
955 // Drop the high bits then sign-extend them back
956 int dead_bits = 64 - t.bits;
957 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
958}
959
960template<>
962 uint64_t ones = (uint64_t)(-1);
963 return (a - b) & (ones >> (64 - t.bits));
964}
965
// Floating-point constant fold for Sub: returns a - b. The type argument t is
// neither read nor modified by this specialization.
966template<>
967HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
968 return a - b;
969}
970
971template<typename A, typename B>
972HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
975 return {pattern_arg(a), pattern_arg(b)};
976}
977
978template<typename A, typename B>
984
985template<>
987 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
988 int dead_bits = 64 - t.bits;
989 // Drop the high bits then sign-extend them back
990 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
991}
992
993template<>
995 uint64_t ones = (uint64_t)(-1);
996 return (a * b) & (ones >> (64 - t.bits));
997}
998
// Floating-point constant fold for Mul: returns a * b. The type argument t is
// neither read nor modified by this specialization.
999template<>
1000HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1001 return a * b;
1002}
1003
1004template<typename A, typename B>
1005HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1008 return {pattern_arg(a), pattern_arg(b)};
1009}
1010
// Named spelling of the pattern-building division operator: forwards both
// arguments to IRMatcher::operator/ and returns whatever it returns.
1011template<typename A, typename B>
1012HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1013 return IRMatcher::operator/(a, b);
1014}
1015
1016template<>
1020
1021template<>
1025
// Floating-point constant fold for Div: delegates to div_imp(a, b), which is
// defined outside this listing (its exact division semantics are not visible
// here). The type argument t is not used.
1026template<>
1027HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1028 return div_imp(a, b);
1029}
1030
1031template<typename A, typename B>
1032HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1035 return {pattern_arg(a), pattern_arg(b)};
1036}
1037
1038template<typename A, typename B>
1039HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1042 return IRMatcher::operator%(a, b);
1043}
1044
1045template<>
1049
1050template<>
1054
// Floating-point constant fold for Mod: delegates to mod_imp(a, b), which is
// defined outside this listing (its exact modulo semantics are not visible
// here). The type argument t is not used.
1055template<>
1056HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1057 return mod_imp(a, b);
1058}
1059
1060template<typename A, typename B>
1061HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1064 return {pattern_arg(a), pattern_arg(b)};
1065}
1066
1067template<>
1069 return std::min(a, b);
1070}
1071
1072template<>
1074 return std::min(a, b);
1075}
1076
// Floating-point constant fold for Min: returns std::min(a, b). The type
// argument t is not used.
1077template<>
1078HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1079 return std::min(a, b);
1080}
1081
1082template<typename A, typename B>
1083HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1086 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1087}
1088
1089template<>
1091 return std::max(a, b);
1092}
1093
1094template<>
1096 return std::max(a, b);
1097}
1098
// Floating-point constant fold for Max: returns std::max(a, b). The type
// argument t is not used.
1099template<>
1100HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1101 return std::max(a, b);
1102}
1103
// Pattern-building overload of `<`: converts each operand with pattern_arg
// and returns a CmpOp<LT, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a comparison.
1104template<typename A, typename B>
1105HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1106 return {pattern_arg(a), pattern_arg(b)};
1107}
1108
// Named spelling of the pattern-building `<` operator: forwards both
// arguments to IRMatcher::operator< and returns whatever it returns.
1109template<typename A, typename B>
1110HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1111 return IRMatcher::operator<(a, b);
1112}
1113
1114template<>
1118
1119template<>
1123
1124template<>
1126 return a < b;
1127}
1128
// Pattern-building overload of `>`: converts each operand with pattern_arg
// and returns a CmpOp<GT, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a comparison.
1129template<typename A, typename B>
1130HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1131 return {pattern_arg(a), pattern_arg(b)};
1132}
1133
// Named spelling of the pattern-building `>` operator: forwards both
// arguments to IRMatcher::operator> and returns whatever it returns.
1134template<typename A, typename B>
1135HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1136 return IRMatcher::operator>(a, b);
1137}
1138
1139template<>
1143
1144template<>
1148
1149template<>
1151 return a > b;
1152}
1153
// Pattern-building overload of `<=`: converts each operand with pattern_arg
// and returns a CmpOp<LE, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a comparison.
1154template<typename A, typename B>
1155HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1156 return {pattern_arg(a), pattern_arg(b)};
1157}
1158
// Named spelling of the pattern-building `<=` operator: forwards both
// arguments to IRMatcher::operator<= and returns whatever it returns.
1159template<typename A, typename B>
1160HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1161 return IRMatcher::operator<=(a, b);
1162}
1163
1164template<>
1166 return a <= b;
1167}
1168
1169template<>
1173
1174template<>
1176 return a <= b;
1177}
1178
// Pattern-building overload of `>=`: converts each operand with pattern_arg
// and returns a CmpOp<GE, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a comparison.
1179template<typename A, typename B>
1180HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1181 return {pattern_arg(a), pattern_arg(b)};
1182}
1183
// Named spelling of the pattern-building `>=` operator: forwards both
// arguments to IRMatcher::operator>= and returns whatever it returns.
1184template<typename A, typename B>
1185HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1186 return IRMatcher::operator>=(a, b);
1187}
1188
1189template<>
1191 return a >= b;
1192}
1193
1194template<>
1198
1199template<>
1201 return a >= b;
1202}
1203
// Pattern-building overload of `==`: converts each operand with pattern_arg
// and returns a CmpOp<EQ, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a comparison.
1204template<typename A, typename B>
1205HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1206 return {pattern_arg(a), pattern_arg(b)};
1207}
1208
// Named spelling of the pattern-building `==` operator: forwards both
// arguments to IRMatcher::operator== and returns whatever it returns.
1209template<typename A, typename B>
1210HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1211 return IRMatcher::operator==(a, b);
1212}
1213
1214template<>
1216 return a == b;
1217}
1218
1219template<>
1223
1224template<>
1226 return a == b;
1227}
1228
// Pattern-building overload of `!=`: converts each operand with pattern_arg
// and returns a CmpOp<NE, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a comparison.
1229template<typename A, typename B>
1230HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1231 return {pattern_arg(a), pattern_arg(b)};
1232}
1233
// Named spelling of the pattern-building `!=` operator: forwards both
// arguments to IRMatcher::operator!= and returns whatever it returns.
1234template<typename A, typename B>
1235HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1236 return IRMatcher::operator!=(a, b);
1237}
1238
1239template<>
1241 return a != b;
1242}
1243
1244template<>
1248
1249template<>
1251 return a != b;
1252}
1253
// Pattern-building overload of `||`: converts each operand with pattern_arg
// and returns a BinOp<Or, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a logical or.
1254template<typename A, typename B>
1255HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1256 return {pattern_arg(a), pattern_arg(b)};
1257}
1258
// Named spelling of the pattern-building `||` operator: forwards both
// arguments to IRMatcher::operator|| and returns whatever it returns.
1259template<typename A, typename B>
1260HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1261 return IRMatcher::operator||(a, b);
1262}
1263
1264template<>
1266 return (a | b) & 1;
1267}
1268
1269template<>
1271 return (a | b) & 1;
1272}
1273
// Floating-point specialization of the Or fold. It ignores all three
// arguments and returns 0; per the comment in the body it is not expected to
// be reached.
1274template<>
1275HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1276 // Unreachable, as it would be a type mismatch.
1277 return 0;
1278}
1279
// Pattern-building overload of `&&`: converts each operand with pattern_arg
// and returns a BinOp<And, ...> holding the two results. It constructs a
// matcher node; it does not evaluate a logical and.
1280template<typename A, typename B>
1281HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1282 return {pattern_arg(a), pattern_arg(b)};
1283}
1284
// Named spelling of the pattern-building `&&` operator: forwards both
// arguments to IRMatcher::operator&& and returns whatever it returns.
1285template<typename A, typename B>
1286HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1287 return IRMatcher::operator&&(a, b);
1288}
1289
1290template<>
1292 return a & b & 1;
1293}
1294
1295template<>
1299
// Floating-point specialization of the And fold. It ignores all three
// arguments and returns 0; per the comment in the body it is not expected to
// be reached.
1300template<>
1301HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1302 // Unreachable
1303 return 0;
1304}
1305
// NOTE(review): line 1306, the declaration this body belongs to, is missing
// from the listing. It is presumably the zero-argument base case of
// bitwise_or_reduce, which the recursive overload below needs in order to
// terminate -- confirm against the original header.
1307 return 0;
1308}
1309
// Recursive case: bitwise-ORs `first` with the reduction of the remaining
// arguments. Usable in constant expressions.
1310template<typename... Args>
1311constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
1312 return first | bitwise_or_reduce(rest...);
1313}
1314
// Base case of the logical-AND reduction: with no arguments the result is true.
1315constexpr bool and_reduce() {
1316 return true;
1317}
1318
// Recursive case: true only if `first` and every remaining argument are true.
// `&&` short-circuits, so the tail is not evaluated once `first` is false.
1319template<typename... Args>
1320constexpr bool and_reduce(bool first, Args... rest) {
1321 return first && and_reduce(rest...);
1322}
1323
// NOTE(review): the struct header lines (1325 for the primary template, and
// 1332-1333 for the specialization, which also hold its specialized argument
// and the declaration of `type`) are missing from this listing. Presumably
// this is the helper type behind the `optional_type_hint` member that Intrin
// uses -- confirm.
//
// Primary template: check() accepts every Type.
1324template<Call::IntrinsicOp intrin>
1326 bool check(const Type &) const {
1327 return true;
1328 }
1329};
1330
// Specialization: check() accepts a Type only if it equals Type(type), where
// `type` is a member whose declaration is not visible here.
1331template<>
1334 bool check(const Type &t) const {
1335 return t == Type(type);
1336 }
1337};
1338
// Pattern node for a call to the intrinsic `intrin` whose arguments are the
// sub-patterns `Args...`.
// NOTE(review): this listing drops several lines of the struct (visible as
// gaps in the embedded numbering), so some member declarations -- including
// `optional_type_hint` -- and the signatures of make(), make_folded_const()
// and the constructor are not visible here.
1339template<Call::IntrinsicOp intrin, typename... Args>
1340struct Intrin {
1341 struct pattern_tag {};
 // One sub-pattern per intrinsic argument.
1342 std::tuple<Args...> args;
1343 // The type of the output of the intrinsic node.
1344 // Only necessary in cases where it can't be inferred
1345 // from the input types (e.g. saturating_cast).
1346
1348
1350
 // Canonical only if every argument pattern is canonical.
1353 constexpr static bool canonical = and_reduce((Args::canonical)...);
1354
 // Matches call argument i against the i-th sub-pattern, then recurses on
 // i + 1 with that sub-pattern's binding mask OR-ed into `bound`. This
 // overload only exists while i < sizeof...(Args); callers pass the literal
 // 0, which prefers this `int` overload over the `double` one below.
1355 template<int i,
1356 uint32_t bound,
1357 typename = std::enable_if_t<(i < sizeof...(Args))>>
1358 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1359 using T = decltype(std::get<i>(args));
1360 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1361 match_args<i + 1, (bound | bindings<T>::mask)>(0, c, state));
1362 }
1363
 // Fallback chosen once the overload above is disabled (i has run past the
 // last argument): nothing is left to match, so it succeeds.
1364 template<int i, uint32_t binds>
1365 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1366 return true;
1367 }
1368
 // Succeeds when `e` is a Call node for this intrinsic, its type passes
 // optional_type_hint.check(), and every argument matches its sub-pattern.
1369 template<uint32_t bound>
1370 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1371 if (e.node_type != IRNodeType::Call) {
1372 return false;
1373 }
1374 const Call &c = (const Call &)e;
1375 return (c.is_intrinsic(intrin) &&
1376 optional_type_hint.check(e.type) &&
1377 match_args<0, bound>(0, c, state));
1378 }
1379
 // Prints the arguments from index i onwards, writing ", " after every
 // argument except the last.
 // NOTE(review): line 1383 is missing from the listing; presumably it
 // streams argument i itself -- confirm.
1380 template<int i,
1381 typename = std::enable_if_t<(i < sizeof...(Args))>>
1382 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1384 if (i + 1 < sizeof...(Args)) {
1385 s << ", ";
1386 }
1387 print_args<i + 1>(0, s);
1388 }
1389
 // Terminates the recursion above once i has run past the last argument.
1390 template<int i>
1391 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1392 }
1393
 // Entry point: prints every argument, starting at index 0.
1395 void print_args(std::ostream &s) const {
1396 print_args<0>(0, s);
1397 }
1398
 // NOTE(review): the signature of this method (lines 1399-1400) is missing
 // from the listing; the body uses `state` and `type_hint` and returns an
 // Expr. It builds the argument expressions from the sub-patterns and hands
 // them to the helper that corresponds to `intrin`.
1401 Expr arg0 = std::get<0>(args).make(state, type_hint);
1402 if (intrin == Call::likely) {
1403 return likely(std::move(arg0));
1404 } else if (intrin == Call::likely_if_innermost) {
1405 return likely_if_innermost(std::move(arg0));
1406 } else if (intrin == Call::abs) {
1407 return abs(std::move(arg0));
1408 } else if constexpr (intrin == Call::saturating_cast) {
1409 return saturating_cast(optional_type_hint.type, std::move(arg0));
1410 }
1411
 // The index is clamped to sizeof...(Args) - 1 so that std::get stays in
 // range even when this pattern has only one argument.
1412 Expr arg1 = std::get<std::min<size_t>(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1413 if (intrin == Call::absd) {
1414 return absd(std::move(arg0), std::move(arg1));
1415 } else if (intrin == Call::widen_right_add) {
1416 return widen_right_add(std::move(arg0), std::move(arg1));
1417 } else if (intrin == Call::widen_right_mul) {
1418 return widen_right_mul(std::move(arg0), std::move(arg1));
1419 } else if (intrin == Call::widen_right_sub) {
1420 return widen_right_sub(std::move(arg0), std::move(arg1));
1421 } else if (intrin == Call::widening_add) {
1422 return widening_add(std::move(arg0), std::move(arg1));
1423 } else if (intrin == Call::widening_sub) {
1424 return widening_sub(std::move(arg0), std::move(arg1));
1425 } else if (intrin == Call::widening_mul) {
1426 return widening_mul(std::move(arg0), std::move(arg1));
1427 } else if (intrin == Call::saturating_add) {
1428 return saturating_add(std::move(arg0), std::move(arg1));
1429 } else if (intrin == Call::saturating_sub) {
1430 return saturating_sub(std::move(arg0), std::move(arg1));
1431 } else if (intrin == Call::halving_add) {
1432 return halving_add(std::move(arg0), std::move(arg1));
1433 } else if (intrin == Call::halving_sub) {
1434 return halving_sub(std::move(arg0), std::move(arg1));
1435 } else if (intrin == Call::rounding_halving_add) {
1436 return rounding_halving_add(std::move(arg0), std::move(arg1));
1437 } else if (intrin == Call::shift_left) {
1438 return std::move(arg0) << std::move(arg1);
1439 } else if (intrin == Call::shift_right) {
1440 return std::move(arg0) >> std::move(arg1);
1441 } else if (intrin == Call::rounding_shift_left) {
1442 return rounding_shift_left(std::move(arg0), std::move(arg1));
1443 } else if (intrin == Call::rounding_shift_right) {
1444 return rounding_shift_right(std::move(arg0), std::move(arg1));
1445 }
1446
 // Same clamping as for arg1, for the three-argument intrinsics.
 // NOTE(review): line 1448 (the opening `if` of the chain below) is missing
 // from the listing; presumably it tests for Call::mul_shift_right.
1447 Expr arg2 = std::get<std::min<size_t>(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1449 return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1450 } else if (intrin == Call::rounding_mul_shift_right) {
1451 return rounding_mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1452 }
1453
 // No branch above handled this intrinsic.
1454 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1455 return Expr();
1456 }
1457
1458 constexpr static bool foldable = true;
1459
 // NOTE(review): the signature of this method (lines 1460-1461) and the
 // declarations of `signed_ty` and `arg1` (lines 1467-1468) are missing
 // from the listing. The body folds both arguments to constants and then
 // applies the shift to `val` in place.
1462 // Assuming the args have the same type as the intrinsic is incorrect in
1463 // general. But for the intrinsics we can fold (just shifts), the LHS
1464 // has the same type as the intrinsic, and we can always treat the RHS
1465 // as a signed int, because we're using 64 bits for it.
1466 std::get<0>(args).make_folded_const(val, ty, state);
1469 // We can just directly get the second arg here, because we only want to
1470 // instantiate this method for shifts, which have two args.
1471 std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1472
1473 if (intrin == Call::shift_left) {
 // A negative left-shift amount is carried out as a right shift by
 // the negated amount.
1474 if (arg1.u.i64 < 0) {
1475 if (ty.code == halide_type_int) {
1476 // Arithmetic shift
1477 val.u.i64 >>= -arg1.u.i64;
1478 } else {
1479 // Logical shift
1480 val.u.u64 >>= -arg1.u.i64;
1481 }
1482 } else {
1483 val.u.u64 <<= arg1.u.i64;
1484 }
1485 } else if (intrin == Call::shift_right) {
 // Mirror image: a non-positive right-shift amount is carried out as
 // a left shift by the negated amount.
1486 if (arg1.u.i64 > 0) {
1487 if (ty.code == halide_type_int) {
1488 // Arithmetic shift
1489 val.u.i64 >>= arg1.u.i64;
1490 } else {
1491 // Logical shift
1492 val.u.u64 >>= arg1.u.i64;
1493 }
1494 } else {
1495 val.u.u64 <<= -arg1.u.i64;
1496 }
1497 } else {
1498 internal_error << "Folding not implemented for intrinsic: " << intrin;
1499 }
1500 }
1501
 // Constructor tail: stores the given argument patterns in the `args`
 // tuple. NOTE(review): the declaration lines (1502-1503) are missing.
1504 : args(args...) {
1505 }
1506};
1507
// Prints an Intrin pattern as `<intrin>(<args>)`, delegating the argument
// list to Intrin::print_args. Returns the stream to allow chaining.
1508template<Call::IntrinsicOp intrin, typename... Args>
1509std::ostream &operator<<(std::ostream &s, const Intrin<intrin, Args...> &op) {
1510 s << intrin << "(";
1511 op.print_args(s);
1512 s << ")";
1513 return s;
1514}
1515
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::widen_right_add, ...> over them.
1516template<typename A, typename B>
1517auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1518 return {pattern_arg(a), pattern_arg(b)};
1519}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::widen_right_mul, ...> over them.
1520template<typename A, typename B>
1521auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1522 return {pattern_arg(a), pattern_arg(b)};
1523}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::widen_right_sub, ...> over them.
1524template<typename A, typename B>
1525auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1526 return {pattern_arg(a), pattern_arg(b)};
1527}
1528
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::widening_add, ...> over them.
1529template<typename A, typename B>
1530auto widening_add(A &&a, B &&b) noexcept -> Intrin<Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1531 return {pattern_arg(a), pattern_arg(b)};
1532}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::widening_sub, ...> over them.
1533template<typename A, typename B>
1534auto widening_sub(A &&a, B &&b) noexcept -> Intrin<Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1535 return {pattern_arg(a), pattern_arg(b)};
1536}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::widening_mul, ...> over them.
1537template<typename A, typename B>
1538auto widening_mul(A &&a, B &&b) noexcept -> Intrin<Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1539 return {pattern_arg(a), pattern_arg(b)};
1540}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::saturating_add, ...> over them.
1541template<typename A, typename B>
1542auto saturating_add(A &&a, B &&b) noexcept -> Intrin<Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1543 return {pattern_arg(a), pattern_arg(b)};
1544}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::saturating_sub, ...> over them.
1545template<typename A, typename B>
1546auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1547 return {pattern_arg(a), pattern_arg(b)};
1548}
// Pattern factory for a saturating_cast to type `t`: records `t` in the
// pattern's optional_type_hint and returns the pattern.
// NOTE(review): line 1551, which declares `p`, is missing from this listing;
// presumably it constructs the Intrin from pattern_arg(a) -- confirm.
1549template<typename A>
1550auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<Call::saturating_cast, decltype(pattern_arg(a))> {
1552 p.optional_type_hint.type = t;
1553 return p;
1554}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::halving_add, ...> over them.
1555template<typename A, typename B>
1556auto halving_add(A &&a, B &&b) noexcept -> Intrin<Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1557 return {pattern_arg(a), pattern_arg(b)};
1558}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::halving_sub, ...> over them.
1559template<typename A, typename B>
1560auto halving_sub(A &&a, B &&b) noexcept -> Intrin<Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1561 return {pattern_arg(a), pattern_arg(b)};
1562}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::rounding_halving_add, ...> over them.
1563template<typename A, typename B>
1564auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1565 return {pattern_arg(a), pattern_arg(b)};
1566}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::shift_left, ...> over them.
1567template<typename A, typename B>
1568auto shift_left(A &&a, B &&b) noexcept -> Intrin<Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1569 return {pattern_arg(a), pattern_arg(b)};
1570}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::shift_right, ...> over them.
1571template<typename A, typename B>
1572auto shift_right(A &&a, B &&b) noexcept -> Intrin<Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1573 return {pattern_arg(a), pattern_arg(b)};
1574}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::rounding_shift_left, ...> over them.
1575template<typename A, typename B>
1576auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1577 return {pattern_arg(a), pattern_arg(b)};
1578}
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::rounding_shift_right, ...> over them.
1579template<typename A, typename B>
1580auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1581 return {pattern_arg(a), pattern_arg(b)};
1582}
// Pattern factory: wraps a, b and c with pattern_arg() and returns a
// three-argument Intrin<Call::mul_shift_right, ...> over them.
1583template<typename A, typename B, typename C>
1584auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1585 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1586}
// Pattern factory: wraps a, b and c with pattern_arg() and returns a
// three-argument Intrin<Call::rounding_mul_shift_right, ...> over them.
1587template<typename A, typename B, typename C>
1588auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1589 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1590}
1591
// Pattern factory: wraps a with pattern_arg() and returns a one-argument
// Intrin<Call::abs, ...> over it.
1592template<typename A>
1593auto abs(A &&a) noexcept -> Intrin<Call::abs, decltype(pattern_arg(a))> {
1594 return {pattern_arg(a)};
1595}
1596
// Pattern factory: wraps a and b with pattern_arg() and returns an
// Intrin<Call::absd, ...> over them.
1597template<typename A, typename B>
1598auto absd(A &&a, B &&b) noexcept -> Intrin<Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1599 return {pattern_arg(a), pattern_arg(b)};
1600}
1601
// Pattern factory: wraps a with pattern_arg() and returns a one-argument
// Intrin<Call::likely, ...> over it.
1602template<typename A>
1603auto likely(A &&a) noexcept -> Intrin<Call::likely, decltype(pattern_arg(a))> {
1604 return {pattern_arg(a)};
1605}
1606
// NOTE(review): the signature line (1608) is missing from this listing. The
// body matches the other single-argument pattern factories: it returns a
// pattern built from pattern_arg(a). Presumably this is the
// likely_if_innermost factory, given its position after likely() -- confirm.
1607template<typename A>
1609 return {pattern_arg(a)};
1610}
1611
// Pattern node for a logical Not applied to the sub-pattern A.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the declaration of the member `a` and the
// signatures of make() and make_folded_const().
1612template<typename A>
1613struct NotOp {
1614 struct pattern_tag {};
1616
1618
1621 constexpr static bool canonical = A::canonical;
1622
 // Matches an IR node: it must be a Not whose operand matches `a`.
1623 template<uint32_t bound>
1624 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1625 if (e.node_type != IRNodeType::Not) {
1626 return false;
1627 }
1628 const Not &op = (const Not &)e;
1629 return (a.template match<bound>(*op.a.get(), state));
1630 }
1631
 // Matches another NotOp pattern by matching the two operands.
1632 template<uint32_t bound, typename A2>
1633 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1634 return a.template match<bound>(unwrap(op.a), state);
1635 }
1636
 // Body of make() (signature lines 1637-1638 missing): builds the operand
 // expression and wraps it in Not::make.
1639 return Not::make(a.make(state, type_hint));
1640 }
1641
1642 constexpr static bool foldable = A::foldable;
1643
 // Body of make_folded_const() (signature line 1645 missing): folds the
 // operand, inverts its bits, then keeps only the lowest bit so the result
 // is 0 or 1.
1644 template<typename A1 = A>
1646 a.make_folded_const(val, ty, state);
1647 val.u.u64 = ~val.u.u64;
1648 val.u.u64 &= 1;
1649 }
1650};
1651
// Pattern-building `!`: wraps the operand with pattern_arg() and returns a
// NotOp over it.
// NOTE(review): line 1654 (the first statement of the body) is missing from
// this listing.
1652template<typename A>
1653HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1655 return {pattern_arg(a)};
1656}
1657
1658template<typename A>
1663
// Prints a NotOp pattern as `!(<operand>)` and returns the stream.
1664template<typename A>
1665inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1666 s << "!(" << op.a << ")";
1667 return s;
1668}
1669
1670// The simplified negation of some already-bound boolean wildcard. So if x
1671// matched to v < 3, this will bind to 3 <= v. If x matched to v == 3, this will
1672// bind to v != 3, etc. Will also bind to !x.
// NOTE(review): the struct header line (1674) is missing from this listing;
// the name SimplifiedNegateOp<i> is taken from the operator<< overload that
// prints this node. Lines 1679-1680 are missing as well.
1673template<int i>
1675 struct pattern_tag {};
1676
 // This node never introduces bindings of its own.
1677 constexpr static uint32_t binds = 0;
1678
1681 constexpr static bool canonical = true;
1682 constexpr static bool foldable = false;
1683
 // Compares `e` against the negated form of whatever wildcard i is bound
 // to. The static_assert rejects, at compile time, any use where wildcard
 // i is not already in the `bound` mask.
1684 template<uint32_t bound>
1685 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1686 static_assert(bound & Wild<i>::binds, "neg must be applied to an already-bound expr");
1687 const BaseExprNode &b = *state.get_binding(i);
1688
1689 switch (b.node_type) {
 // Bound to p == q: accept p != q with the operands in either order.
1690 case IRNodeType::EQ:
1691 return (e.node_type == IRNodeType::NE &&
1692 ((equal(*((const NE &)e).a.get(), *((const EQ &)b).a.get()) &&
1693 equal(*((const NE &)e).b.get(), *((const EQ &)b).b.get())) ||
1694 (equal(*((const NE &)e).a.get(), *((const EQ &)b).b.get()) &&
1695 equal(*((const NE &)e).b.get(), *((const EQ &)b).a.get()))));
 // Bound to p != q: accept p == q with the operands in either order.
1696 case IRNodeType::NE:
1697 return (e.node_type == IRNodeType::EQ &&
1698 ((equal(*((const EQ &)e).a.get(), *((const NE &)b).a.get()) &&
1699 equal(*((const EQ &)e).b.get(), *((const NE &)b).b.get())) ||
1700 (equal(*((const EQ &)e).a.get(), *((const NE &)b).b.get()) &&
1701 equal(*((const EQ &)e).b.get(), *((const NE &)b).a.get()))));
 // Bound to p < q: accept q <= p (operands swapped).
1702 case IRNodeType::LT:
1703 return (e.node_type == IRNodeType::LE &&
1704 equal(*((const LE &)e).a.get(), *((const LT &)b).b.get()) &&
1705 equal(*((const LE &)e).b.get(), *((const LT &)b).a.get()));
 // Bound to p <= q: accept q < p (operands swapped).
1706 case IRNodeType::LE:
1707 return (e.node_type == IRNodeType::LT &&
1708 equal(*((const LT &)e).a.get(), *((const LE &)b).b.get()) &&
1709 equal(*((const LT &)e).b.get(), *((const LE &)b).a.get()));
 // Bound to !p: accept p itself.
1710 case IRNodeType::Not:
1711 return equal(e, *((const Not &)b).a.get());
 // Anything else: accept only a Not wrapped around the binding.
1712 default:
1713 return (e.node_type == IRNodeType::Not &&
1714 equal(*((const Not &)e).a.get(), b));
1715 }
1716 }
1717};
1718
1719template<int i>
1723
// Prints a SimplifiedNegateOp<i> as `neg(<wildcard i>)` and returns the stream.
1724template<int i>
1725inline std::ostream &operator<<(std::ostream &s, const SimplifiedNegateOp<i> &op) {
1726 s << "neg(" << Wild<i>{} << ")";
1727 return s;
1728}
1729
// Pattern node for a Select with condition pattern C, true-value pattern T
// and false-value pattern F.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the declarations of members `c` and `f` and
// the signatures of make() and make_folded_const().
1730template<typename C, typename T, typename F>
1731struct SelectOp {
1732 struct pattern_tag {};
1734 T t;
1736
1738
1741
1742 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1743
 // Matches an IR node: it must be a Select. The condition, true value and
 // false value are matched in that order, each later sub-pattern seeing
 // the bindings made by the earlier ones.
1744 template<uint32_t bound>
1745 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1746 if (e.node_type != Select::_node_type) {
1747 return false;
1748 }
1749 const Select &op = (const Select &)e;
1750 return (c.template match<bound>(*op.condition.get(), state) &&
1751 t.template match<(bound | bindings<C>::mask)>(*op.true_value.get(), state) &&
1752 f.template match<(bound | bindings<C>::mask | bindings<T>::mask)>(*op.false_value.get(), state));
1753 }
 // Matches another SelectOp pattern component by component, with the same
 // accumulation of bindings as above.
1754 template<uint32_t bound, typename C2, typename T2, typename F2>
1755 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1756 return (c.template match<bound>(unwrap(instance.c), state) &&
1757 t.template match<(bound | bindings<C>::mask)>(unwrap(instance.t), state) &&
1758 f.template match<(bound | bindings<C>::mask | bindings<T>::mask)>(unwrap(instance.f), state));
1759 }
1760
 // Body of make() (signature lines 1761-1762 missing): the condition is
 // built with an empty type hint, while both value branches receive the
 // caller's type_hint.
1763 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1764 }
1765
1766 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1767
 // Body of make_folded_const() (signature and local declarations on lines
 // 1769-1771 missing): folds the condition, then folds only the branch
 // selected by the condition's lowest bit. Line 1778 is also missing.
1768 template<typename C1 = C>
1772 c.make_folded_const(c_val, c_ty, state);
1773 if ((c_val.u.u64 & 1) == 1) {
1774 t.make_folded_const(val, ty, state);
1775 } else {
1776 f.make_folded_const(val, ty, state);
1777 }
1779 }
1780};
1781
// Prints a SelectOp pattern as `select(<c>, <t>, <f>)` and returns the stream.
1782template<typename C, typename T, typename F>
1783std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1784 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1785 return s;
1786}
1787
// Pattern factory: wraps c, t and f with pattern_arg() and returns a SelectOp
// over them.
// NOTE(review): lines 1790-1792 (the first statements of the body) are
// missing from this listing.
1788template<typename C, typename T, typename F>
1789HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1793 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1794}
1795
// Pattern node for a Broadcast of the value pattern A over a lane-count
// pattern B.
// NOTE(review): the struct header line (1797) is missing from this listing;
// the name BroadcastOp<A, B> is taken from the overloads below. The member
// declarations (`a`, `lanes`) and the signatures of make() and
// make_folded_const() are missing as well.
1796template<typename A, typename B>
1798 struct pattern_tag {};
1801
1803
1806
1807 constexpr static bool canonical = A::canonical && B::canonical;
1808
 // Matches an IR node: it must be a Broadcast whose value matches `a` and
 // whose lane count matches `lanes`.
 // NOTE(review): `lanes` is matched with plain `bound` here, whereas the
 // pattern overload below uses `bound | bindings<A>::mask` -- confirm
 // whether the difference is intentional.
1809 template<uint32_t bound>
1810 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1811 if (e.node_type == Broadcast::_node_type) {
1812 const Broadcast &op = (const Broadcast &)e;
1813 if (a.template match<bound>(*op.value.get(), state) &&
1814 lanes.template match<bound>(op.lanes, state)) {
1815 return true;
1816 }
1817 }
1818 return false;
1819 }
1820
 // Matches another BroadcastOp pattern: value first, then lanes with the
 // value's bindings added to the bound mask.
1821 template<uint32_t bound, typename A2, typename B2>
1822 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1823 return (a.template match<bound>(unwrap(op.a), state) &&
1824 lanes.template match<(bound | bindings<A>::mask)>(unwrap(op.lanes), state));
1825 }
1826
 // Body of make() (signature and local declarations on lines 1827-1830
 // missing): folds the lane count to a constant, divides the hint's lane
 // count by it before building the value, and skips the Broadcast node
 // altogether when the lane count is 1.
1831 lanes.make_folded_const(lanes_val, ty, state);
1832 int32_t l = (int32_t)lanes_val.u.i64;
1833 type_hint.lanes /= l;
1834 Expr val = a.make(state, type_hint);
1835 if (l == 1) {
1836 return val;
1837 } else {
1838 return Broadcast::make(std::move(val), l);
1839 }
1840 }
1841
1842 constexpr static bool foldable = false;
1843
 // Body of make_folded_const() (signature and local declarations on lines
 // 1845-1847 missing): folds the lane count and the value, then sets
 // ty.lanes to the folded lane count while keeping whatever bits of the
 // old ty.lanes fall under MatcherState::special_values_mask.
1844 template<typename A1 = A>
1848 lanes.make_folded_const(lanes_val, lanes_ty, state);
1849 uint16_t l = (uint16_t)lanes_val.u.i64;
1850 a.make_folded_const(val, ty, state);
1851 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1852 }
1853};
1854
// Prints a BroadcastOp pattern as `broadcast(<a>, <lanes>)` and returns the
// stream.
1855template<typename A, typename B>
1856inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1857 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1858 return s;
1859}
1860
// Pattern factory: wraps the value and the lane count with pattern_arg() and
// returns a BroadcastOp over them. `lanes` is taken by value.
// NOTE(review): line 1863 (the first statement of the body) is missing from
// this listing.
1861template<typename A, typename B>
1862HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1864 return {pattern_arg(a), pattern_arg(lanes)};
1865}
1866
// Pattern node for a Ramp with base pattern A, stride pattern B and
// lane-count pattern C.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the member declarations (`a`, `b`, `lanes`)
// and the signature of make().
1867template<typename A, typename B, typename C>
1868struct RampOp {
1869 struct pattern_tag {};
1873
1875
1878
1879 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1880
 // Matches an IR node: it must be a Ramp. Base, stride and lane count are
 // matched in that order, each later sub-pattern seeing the bindings made
 // by the earlier ones.
1881 template<uint32_t bound>
1882 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1883 if (e.node_type != Ramp::_node_type) {
1884 return false;
1885 }
1886 const Ramp &op = (const Ramp &)e;
1887 if (a.template match<bound>(*op.base.get(), state) &&
1888 b.template match<(bound | bindings<A>::mask)>(*op.stride.get(), state) &&
1889 lanes.template match<(bound | bindings<A>::mask | bindings<B>::mask)>(op.lanes, state)) {
1890 return true;
1891 } else {
1892 return false;
1893 }
1894 }
1895
 // Matches another RampOp pattern component by component, with the same
 // accumulation of bindings as above.
1896 template<uint32_t bound, typename A2, typename B2, typename C2>
1897 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1898 return (a.template match<bound>(unwrap(op.a), state) &&
1899 b.template match<(bound | bindings<A>::mask)>(unwrap(op.b), state) &&
1900 lanes.template match<(bound | bindings<A>::mask | bindings<B>::mask)>(unwrap(op.lanes), state));
1901 }
1902
 // Body of make() (signature and local declarations on lines 1903-1906
 // missing): folds the lane count, divides the hint's lane count by it,
 // builds the stride first, and then builds the base using the stride's
 // type as its hint.
1907 lanes.make_folded_const(lanes_val, ty, state);
1908 int32_t l = (int32_t)lanes_val.u.i64;
1909 type_hint.lanes /= l;
1910 Expr ea, eb;
1911 eb = b.make(state, type_hint);
1912 ea = a.make(state, eb.type());
1913 return Ramp::make(std::move(ea), std::move(eb), l);
1914 }
1915
1916 constexpr static bool foldable = false;
1917};
1918
// Prints a RampOp pattern as `ramp(<a>, <b>, <lanes>)` and returns the stream.
1919template<typename A, typename B, typename C>
1920std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1921 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1922 return s;
1923}
1924
// Pattern factory: wraps base, stride and lane count with pattern_arg() and
// returns a RampOp over them.
// NOTE(review): lines 1927-1929 (the first statements of the body) are
// missing from this listing.
1925template<typename A, typename B, typename C>
1926HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1930 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1931}
1932
// Pattern node for a VectorReduce using operator `reduce_op` over the value
// pattern A, with the output lane count given by pattern B.
// NOTE(review): the struct header line (1934) is missing from this listing;
// the name VectorReduceOp<A, B, reduce_op> is taken from the operator<<
// overload that prints this node. The member declarations, the signature of
// the second match() overload and the signature of make() are missing too.
1933template<typename A, typename B, VectorReduce::Operator reduce_op>
1935 struct pattern_tag {};
1938
1940
1943 constexpr static bool canonical = A::canonical;
1944
 // Matches an IR node: it must be a VectorReduce with the same reduction
 // operator, whose value matches `a` and whose result lane count
 // (op.type.lanes()) matches `lanes`.
1945 template<uint32_t bound>
1946 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1947 if (e.node_type == VectorReduce::_node_type) {
1948 const VectorReduce &op = (const VectorReduce &)e;
1949 if (op.op == reduce_op &&
1950 a.template match<bound>(*op.value.get(), state) &&
1951 lanes.template match<(bound | bindings<A>::mask)>(op.type.lanes(), state)) {
1952 return true;
1953 }
1954 }
1955 return false;
1956 }
1957
 // Pattern-vs-pattern overload (signature line 1959 missing): requires the
 // same reduction operator, then matches the value and the lane count.
1958 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1960 return (reduce_op == reduce_op_2 &&
1961 a.template match<bound>(unwrap(op.a), state) &&
1962 lanes.template match<(bound | bindings<A>::mask)>(unwrap(op.lanes), state));
1963 }
1964
 // Body of make() (signature and local declarations on lines 1965-1968
 // missing): folds the lane count to a constant and builds the
 // VectorReduce node from the value expression.
1969 lanes.make_folded_const(lanes_val, ty, state);
1970 int l = (int)lanes_val.u.i64;
1971 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1972 }
1973
1974 constexpr static bool foldable = false;
1975};
1976
// Prints a VectorReduceOp pattern as `vector_reduce(<op>, <a>, <lanes>)` and
// returns the stream.
1977template<typename A, typename B, VectorReduce::Operator reduce_op>
1978inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1979 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1980 return s;
1981}
1982
// Pattern factory for a VectorReduce::Add reduction of `a` down to `lanes`
// lanes; both arguments are wrapped with pattern_arg().
// NOTE(review): line 1985 (the first statement of the body) is missing from
// this listing.
1983template<typename A, typename B>
1984HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1986 return {pattern_arg(a), pattern_arg(lanes)};
1987}
1988
// Pattern factory for a VectorReduce::Min reduction of `a` down to `lanes`
// lanes; both arguments are wrapped with pattern_arg().
// NOTE(review): line 1991 (the first statement of the body) is missing from
// this listing.
1989template<typename A, typename B>
1990HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1992 return {pattern_arg(a), pattern_arg(lanes)};
1993}
1994
// Pattern factory for a VectorReduce::Max reduction of `a` down to `lanes`
// lanes; both arguments are wrapped with pattern_arg().
// NOTE(review): line 1997 (the first statement of the body) is missing from
// this listing.
1995template<typename A, typename B>
1996HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1998 return {pattern_arg(a), pattern_arg(lanes)};
1999}
2000
// Pattern factory for a VectorReduce::And reduction of `a` down to `lanes`
// lanes; both arguments are wrapped with pattern_arg().
// NOTE(review): line 2003 (the first statement of the body) is missing from
// this listing.
2001template<typename A, typename B>
2002HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
2004 return {pattern_arg(a), pattern_arg(lanes)};
2005}
2006
// Pattern factory for a VectorReduce::Or reduction of `a` down to `lanes`
// lanes; both arguments are wrapped with pattern_arg().
// NOTE(review): line 2009 (the first statement of the body) is missing from
// this listing.
2007template<typename A, typename B>
2008HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
2010 return {pattern_arg(a), pattern_arg(lanes)};
2011}
2012
// Pattern node for arithmetic negation of the sub-pattern A. In the IR this
// corresponds to a Sub whose left operand is a constant zero.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the declaration of the member `a` and the
// signatures of make() and make_folded_const().
2013template<typename A>
2014struct NegateOp {
2015 struct pattern_tag {};
2017
2019
2022
2023 constexpr static bool canonical = A::canonical;
2024
 // Matches an IR node: it must be a Sub whose right operand matches `a`
 // and whose left operand is a constant zero.
2025 template<uint32_t bound>
2026 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2027 if (e.node_type != Sub::_node_type) {
2028 return false;
2029 }
2030 const Sub &op = (const Sub &)e;
2031 return (a.template match<bound>(*op.b.get(), state) &&
2032 is_const_zero(op.a));
2033 }
2034
 // Matches another NegateOp pattern by matching the two operands.
 // NOTE(review): this overload takes `NegateOp<A2> &&`, whereas the
 // corresponding overloads of the other pattern nodes in this section take
 // a const reference -- confirm whether that is intentional.
2035 template<uint32_t bound, typename A2>
2036 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
2037 return a.template match<bound>(unwrap(p.a), state);
2038 }
2039
 // Body of make() (signature lines 2040-2041 missing): builds the operand
 // and returns (zero of the operand's type) - operand.
2042 Expr ea = a.make(state, type_hint);
2043 Expr z = make_zero(ea.type());
2044 return Sub::make(std::move(z), std::move(ea));
2045 }
2046
2047 constexpr static bool foldable = A::foldable;
2048
 // Body of make_folded_const() (signature line 2050 missing): folds the
 // operand and then negates the constant according to the type code.
2049 template<typename A1 = A>
2051 a.make_folded_const(val, ty, state);
 // Number of bits of the 64-bit storage that lie above the type's width.
2052 int dead_bits = 64 - ty.bits;
2053 switch (ty.code) {
2054 case halide_type_int:
 // The shift by (65 - bits) discards everything except the bits
 // below the sign bit, so this is true for a nonzero value whose
 // low (bits - 1) bits are all zero.
2055 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
2056 // Trying to negate the most negative signed int for a no-overflow type.
 // NOTE(review): line 2057, the statement executed in this
 // branch, is missing from the listing.
2058 } else {
2059 // Negate, drop the high bits, and then sign-extend them back
2060 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2061 }
2062 break;
2063 case halide_type_uint:
 // Negate modulo 2^bits: shifting left then right by dead_bits
 // clears everything above the type's width.
2064 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2065 break;
2066 case halide_type_float:
2067 case halide_type_bfloat:
2068 val.u.f64 = -val.u.f64;
2069 break;
2070 default:
2071 // unreachable
2072 ;
2073 }
2074 }
2075};
2076
// Prints a NegateOp pattern as `-<operand>` and returns the stream.
2077template<typename A>
2078std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2079 s << "-" << op.a;
2080 return s;
2081}
2082
// Pattern-building unary minus: wraps the operand with pattern_arg() and
// returns a NegateOp over it.
// NOTE(review): line 2085 (the first statement of the body) is missing from
// this listing.
2083template<typename A>
2084HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2086 return {pattern_arg(a)};
2087}
2088
2089template<typename A>
2094
// Pattern node for a Cast of the sub-pattern A to a specific type `t`.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the declarations of the members `t` and `a`
// and the signature of make().
2095template<typename A>
2096struct CastOp {
2097 struct pattern_tag {};
2100
2102
2105 constexpr static bool canonical = A::canonical;
2106
 // Matches an IR node: it must be a Cast whose result type equals `t` and
 // whose operand matches `a`.
2107 template<uint32_t bound>
2108 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2109 if (e.node_type != Cast::_node_type) {
2110 return false;
2111 }
2112 const Cast &op = (const Cast &)e;
2113 return (e.type == t &&
2114 a.template match<bound>(*op.value.get(), state));
2115 }
 // Matches another CastOp pattern: same target type and matching operands.
2116 template<uint32_t bound, typename A2>
2117 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2118 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2119 }
2120
 // Body of make() (signature lines 2121-2122 missing): the operand is
 // built with an empty type hint and then cast to `t`.
2123 return cast(t, a.make(state, {}));
2124 }
2125
2126 constexpr static bool foldable = false;
2127};
2128
// Prints a CastOp pattern as `cast(<type>, <operand>)` and returns the stream.
2129template<typename A>
2130std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2131 s << "cast(" << op.t << ", " << op.a << ")";
2132 return s;
2133}
2134
// Pattern factory: returns a CastOp that casts the wrapped operand to type t.
// NOTE(review): line 2137 (the first statement of the body) is missing from
// this listing.
2135template<typename A>
2136HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2138 return {t, pattern_arg(a)};
2139}
2140
// Pattern node for a Cast that widens its operand: the result type must be
// the widen() of the operand's type.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the declaration of the member `a` and the
// signature of make().
2141template<typename A>
2142struct WidenOp {
2143 struct pattern_tag {};
2145
2147
2150 constexpr static bool canonical = A::canonical;
2151
 // Matches an IR node: it must be a Cast whose type equals the widened
 // type of its own operand, and whose operand matches `a`.
2152 template<uint32_t bound>
2153 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2154 if (e.node_type != Cast::_node_type) {
2155 return false;
2156 }
2157 const Cast &op = (const Cast &)e;
2158 return (e.type == op.value.type().widen() &&
2159 a.template match<bound>(*op.value.get(), state));
2160 }
 // Matches another WidenOp pattern by matching the two operands.
2161 template<uint32_t bound, typename A2>
2162 HALIDE_ALWAYS_INLINE bool match(const WidenOp<A2> &op, MatcherState &state) const noexcept {
2163 return a.template match<bound>(unwrap(op.a), state);
2164 }
2165
 // Body of make() (signature lines 2166-2167 missing): builds the operand
 // with an empty type hint and casts it to its own widened type.
2168 Expr e = a.make(state, {});
2169 Type w = e.type().widen();
2170 return cast(w, std::move(e));
2171 }
2172
2173 constexpr static bool foldable = false;
2174};
2175
// Prints a WidenOp pattern as `widen(<operand>)` and returns the stream.
2176template<typename A>
2177std::ostream &operator<<(std::ostream &s, const WidenOp<A> &op) {
2178 s << "widen(" << op.a << ")";
2179 return s;
2180}
2181
// Pattern factory: wraps the operand with pattern_arg() and returns a WidenOp
// over it.
// NOTE(review): line 2184 (the first statement of the body) is missing from
// this listing.
2182template<typename A>
2183HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp<decltype(pattern_arg(a))> {
2185 return {pattern_arg(a)};
2186}
2187
// Pattern node for a slice Shuffle: a single source vector pattern plus
// patterns for the slice's begin index, stride and lane count.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the member declarations (`vec`, `base`,
// `stride`, `lanes`), the last line of match(), and the signatures of make()
// and the constructor.
2188template<typename Vec, typename Base, typename Stride, typename Lanes>
2189struct SliceOp {
2190 struct pattern_tag {};
2195
 // Union of the wildcards bound by the four sub-patterns.
2196 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2197
2200 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2201
 // Matches an IR node: it must be a Shuffle with exactly one source vector
 // that is_slice(). The vector, slice_begin() and slice_stride() are
 // matched in turn, each seeing the bindings of the ones before it.
 // NOTE(review): line 2213, the final operand of this && chain, is missing
 // from the listing; presumably it matches `lanes` -- confirm.
2202 template<uint32_t bound>
2203 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2204 if (e.node_type != IRNodeType::Shuffle) {
2205 return false;
2206 }
2207 const Shuffle &v = (const Shuffle &)e;
2208 return v.vectors.size() == 1 &&
2209 v.is_slice() &&
2210 vec.template match<bound>(*v.vectors[0].get(), state) &&
2211 base.template match<(bound | bindings<Vec>::mask)>(v.slice_begin(), state) &&
2212 stride.template match<(bound | bindings<Vec>::mask | bindings<Base>::mask)>(v.slice_stride(), state) &&
2214 }
2215
 // Body of make() (signature and local declarations on lines 2216-2219
 // missing): folds base, stride and lanes to integer constants and builds
 // the slice from the source vector expression.
2220 base.make_folded_const(base_val, ty, state);
2221 int b = (int)base_val.u.i64;
2222 stride.make_folded_const(stride_val, ty, state);
2223 int s = (int)stride_val.u.i64;
2224 lanes.make_folded_const(lanes_val, ty, state);
2225 int l = (int)lanes_val.u.i64;
2226 return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2227 }
2228
2229 constexpr static bool foldable = false;
2230
 // Constructor tail (declaration lines 2231-2232 missing): stores the four
 // sub-patterns and requires, at compile time, that base, stride and lanes
 // are all constant-foldable, since make() folds them.
2233 : vec(v), base(b), stride(s), lanes(l) {
2234 static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2235 static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2236 static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2237 }
2238};
2239
// Prints a SliceOp pattern as `slice(<vec>, <base>, <stride>, <lanes>)` and
// returns the stream.
2240template<typename Vec, typename Base, typename Stride, typename Lanes>
2241std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2242 s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2243 return s;
2244}
2245
// Pattern factory: wraps the vector, base, stride and lane count with
// pattern_arg() and returns a SliceOp over them. All four arguments are taken
// by value.
2246template<typename Vec, typename Base, typename Stride, typename Lanes>
2247HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2248 -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2249 return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2250}
2251
// Pattern node that constant-folds the sub-pattern A and produces the result
// as a constant expression.
// NOTE(review): several lines are missing from this listing (gaps in the
// embedded numbering), including the declaration of the member `a` and the
// signatures of make() and make_folded_const().
2252template<typename A>
2253struct Fold {
2254 struct pattern_tag {};
2256
2258
2261 constexpr static bool canonical = true;
2262
 // Body of make() (signature and local declarations on lines 2263-2266
 // missing): folds the operand into the constant `c` with type `ty`.
2267 a.make_folded_const(c, ty, state);
2268
2269 // The result of the fold may have an underspecified type
2270 // (e.g. because it's from an int literal). Make the type code
2271 // and bits match the required type, if there is one (we can
2272 // tell from the bits field).
2273 if (type_hint.bits) {
 // An integer constant that must become a float is converted by
 // value before the type code is overwritten.
2274 if (((int)ty.code == (int)halide_type_int) &&
2275 ((int)type_hint.code == (int)halide_type_float)) {
2276 int64_t x = c.u.i64;
2277 c.u.f64 = (double)x;
2278 }
2279 ty.code = type_hint.code;
2280 ty.bits = type_hint.bits;
2281 }
2282
2283 return make_const_expr(c, ty);
2284 }
2285
2286 constexpr static bool foldable = A::foldable;
2287
 // Body of make_folded_const() (signature line 2289 missing): forwards
 // directly to the operand's fold.
2288 template<typename A1 = A>
2290 a.make_folded_const(val, ty, state);
2291 }
2292};
2293
// Pattern factory: wraps the operand with pattern_arg() and returns a Fold
// over it.
// NOTE(review): line 2296 (the first statement of the body) is missing from
// this listing.
2294template<typename A>
2295HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2297 return {pattern_arg(a)};
2298}
2299
// Prints a Fold pattern as `fold(<operand>)` and returns the stream.
2300template<typename A>
2301std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2302 s << "fold(" << op.a << ")";
2303 return s;
2304}
2305
// Predicate node reporting whether constant-folding the sub-pattern A set
// any of the special-value bits in the folded type's lanes field.
// NOTE(review): the struct header line (2307) is missing from this listing;
// the name Overflows<A> is taken from the overloads that follow. The
// declaration of the member `a` and the signature of make_folded_const() are
// missing as well.
2306template<typename A>
2308 struct pattern_tag {};
2310
2312
2313 // This rule is a predicate, so it always evaluates to a boolean,
2314 // which has IRNodeType UIntImm
2317 constexpr static bool canonical = true;
2318
2319 constexpr static bool foldable = A::foldable;
2320
 // Body of make_folded_const() (signature line 2322 missing): folds the
 // operand, then reports the result as a 64-bit unsigned scalar. The
 // special-value bits are read from ty.lanes before lanes is reset to 1.
2321 template<typename A1 = A>
2323 a.make_folded_const(val, ty, state);
2324 ty.code = halide_type_uint;
2325 ty.bits = 64;
2326 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2327 ty.lanes = 1;
2328 }
2329};
2330
// Pattern factory: wraps the operand with pattern_arg() and returns an
// Overflows predicate over it.
// NOTE(review): line 2333 (the first statement of the body) is missing from
// this listing.
2331template<typename A>
2332HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2334 return {pattern_arg(a)};
2335}
2336
2337template<typename A>
2338std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2339 s << "overflows(" << op.a << ")";
2340 return s;
2341}
2342
// Leaf pattern for the overflow intrinsic. It binds no wildcards
// (binds == 0) and can only match Call nodes.
// NOTE(review): the end of match() and the signatures of the remaining
// methods are not present in this listing.
2343struct Overflow {
2344 struct pattern_tag {};
2345
2346 constexpr static uint32_t binds = 0;
2347
2348 // Overflow is an intrinsic, represented as a Call node
2351 constexpr static bool canonical = true;
2352
2353 template<uint32_t bound>
2354 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
 // Anything that is not a Call cannot be the overflow intrinsic.
2355 if (e.node_type != Call::_node_type) {
2356 return false;
2357 }
 // The node type was checked above, so this downcast is valid.
2358 const Call &op = (const Call &)e;
2360 }
2361
2367
2368 constexpr static bool foldable = true;
2369
 // The folded payload is zeroed.
2372 val.u.u64 = 0;
2374 }
2375};
2376
2377inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2378 s << "overflow()";
2379 return s;
2380}
2381
// Predicate pattern: builds the Expr for `a` and tests it with
// ::Halide::Internal::is_const. When check_v is set the test is
// is_const(e, v), i.e. it also takes the specific value v into account.
2382template<typename A>
2383struct IsConst {
2384 struct pattern_tag {};
2385
2387
2388 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2391 constexpr static bool canonical = true;
2392
2396
2397 constexpr static bool foldable = true;
2398
2399 template<typename A1 = A>
 // Materialize the operand so the free function is_const can inspect it.
2401 Expr e = a.make(state, {});
 // The answer is returned as a scalar 64-bit unsigned value holding 0 or 1.
2402 ty.code = halide_type_uint;
2403 ty.bits = 64;
2404 ty.lanes = 1;
2405 if (check_v) {
2406 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2407 } else {
2408 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2409 }
2410 }
2411};
2412
// Factory for an IsConst predicate with no specific value requested
// (the flag initializer is false and the value initializer is 0).
2413template<typename A>
2414HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2416 return {pattern_arg(a), false, 0};
2417}
2418
// Factory for an IsConst predicate that also carries `value`
// (the flag initializer is true).
2419template<typename A>
2420HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2422 return {pattern_arg(a), true, value};
2423}
2424
2425template<typename A>
2426std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2427 if (op.check_v) {
2428 s << "is_const(" << op.a << ")";
2429 } else {
2430 s << "is_const(" << op.a << ", " << op.v << ")";
2431 }
2432 return s;
2433}
2434
// Predicate pattern: builds the Expr for `a`, runs it through the supplied
// prover's mutate(), and holds when the result is the constant one.
2435template<typename A, typename Prover>
2436struct CanProve {
2437 struct pattern_tag {};
2439 Prover *prover; // An existing simplifying mutator
2440
2442
2443 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2446 constexpr static bool canonical = true;
2447
2448 constexpr static bool foldable = true;
2449
2450 // Includes a raw call to an inlined make method, so don't inline.
2452 Expr condition = a.make(state, {});
 // Let the prover rewrite the condition; only a result of constant one counts.
2453 condition = prover->mutate(condition, nullptr);
2454 val.u.u64 = is_const_one(condition);
 // Result type is a 1-bit unsigned value with the lane count of the mutated condition.
2455 ty.code = halide_type_uint;
2456 ty.bits = 1;
2457 ty.lanes = condition.type().lanes();
2458 }
2459};
2460
// Factory: pairs the pattern `a` with the prover pointer `p` in a CanProve
// predicate. The pointer is stored as given, not copied or owned here.
2461template<typename A, typename Prover>
2462HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2464 return {pattern_arg(a), p};
2465}
2466
2467template<typename A, typename Prover>
2468std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2469 s << "can_prove(" << op.a << ")";
2470 return s;
2471}
2472
// Predicate pattern: true when the type of the Expr built from `a` is a
// floating point type (Type::is_float()).
2473template<typename A>
2474struct IsFloat {
2475 struct pattern_tag {};
2477
2479
2480 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2483 constexpr static bool canonical = true;
2484
2485 constexpr static bool foldable = true;
2486
2489 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2490 Type t = a.make(state, {}).type();
2491 val.u.u64 = t.is_float();
 // Result type is a 1-bit unsigned value with the operand's lane count.
2492 ty.code = halide_type_uint;
2493 ty.bits = 1;
2494 ty.lanes = t.lanes();
2495 }
2496};
2497
// Factory: wraps `a` (after pattern_arg conversion) in an IsFloat predicate.
2498template<typename A>
2499HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2501 return {pattern_arg(a)};
2502}
2503
2504template<typename A>
2505std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2506 s << "is_float(" << op.a << ")";
2507 return s;
2508}
2509
// Predicate pattern: true when the type of the Expr built from `a` is a
// signed integer type, optionally constrained to a given bit width and/or
// lane count.
2510template<typename A>
2511struct IsInt {
2512 struct pattern_tag {};
2516
2518
2519 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2522 constexpr static bool canonical = true;
2523
2524 constexpr static bool foldable = true;
2525
2528 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2529 Type t = a.make(state, {}).type();
 // A bits or lanes value of 0 means "do not check this property".
2530 val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
 // Result type is a 1-bit unsigned value with the operand's lane count.
2531 ty.code = halide_type_uint;
2532 ty.bits = 1;
2533 ty.lanes = t.lanes();
2534 }
2535};
2536
// Factory for an IsInt predicate carrying the optional bits/lanes constraints.
// NOTE(review): the signature line is not present in this listing.
2537template<typename A>
2540 return {pattern_arg(a), bits, lanes};
2541}
2542
2543template<typename A>
2544std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2545 s << "is_int(" << op.a;
2546 if (op.bits > 0) {
2547 s << ", " << op.bits;
2548 }
2549 if (op.lanes > 0) {
2550 s << ", " << op.lanes;
2551 }
2552 s << ")";
2553 return s;
2554}
2555
// Predicate pattern: true when the type of the Expr built from `a` is an
// unsigned integer type, optionally constrained to a given bit width and/or
// lane count.
2556template<typename A>
2557struct IsUInt {
2558 struct pattern_tag {};
2562
2564
2565 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2568 constexpr static bool canonical = true;
2569
2570 constexpr static bool foldable = true;
2571
2574 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2575 Type t = a.make(state, {}).type();
 // A bits or lanes value of 0 means "do not check this property".
2576 val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
 // Result type is a 1-bit unsigned value with the operand's lane count.
2577 ty.code = halide_type_uint;
2578 ty.bits = 1;
2579 ty.lanes = t.lanes();
2580 }
2581};
2582
// Factory for an IsUInt predicate carrying the optional bits/lanes constraints.
// NOTE(review): the signature line is not present in this listing.
2583template<typename A>
2586 return {pattern_arg(a), bits, lanes};
2587}
2588
2589template<typename A>
2590std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2591 s << "is_uint(" << op.a;
2592 if (op.bits > 0) {
2593 s << ", " << op.bits;
2594 }
2595 if (op.lanes > 0) {
2596 s << ", " << op.lanes;
2597 }
2598 s << ")";
2599 return s;
2600}
2601
// Predicate pattern: true when the type of the Expr built from `a` is a
// scalar type (Type::is_scalar()).
2602template<typename A>
2603struct IsScalar {
2604 struct pattern_tag {};
2606
2608
2609 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2612 constexpr static bool canonical = true;
2613
2614 constexpr static bool foldable = true;
2615
2618 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2619 Type t = a.make(state, {}).type();
2620 val.u.u64 = t.is_scalar();
 // Result type is a 1-bit unsigned value with the operand's lane count.
2621 ty.code = halide_type_uint;
2622 ty.bits = 1;
2623 ty.lanes = t.lanes();
2624 }
2625};
2626
// Factory: wraps `a` (after pattern_arg conversion) in an IsScalar predicate.
2627template<typename A>
2628HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2630 return {pattern_arg(a)};
2631}
2632
2633template<typename A>
2634std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2635 s << "is_scalar(" << op.a << ")";
2636 return s;
2637}
2638
// Predicate pattern: folds `a` to a constant and tests whether it equals the
// largest value of its (integer) type. Non-integer types always give false.
2639template<typename A>
2641 struct pattern_tag {};
2643
2645
2646 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2649 constexpr static bool canonical = true;
2650
2651 constexpr static bool foldable = true;
2652
2655 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2656 a.make_folded_const(val, ty, state);
 // All-ones mask of ty.bits bits, or one bit fewer for signed ints (the
 // top bit is the sign): the bit pattern of the type's maximum value.
 // NOTE(review): this shift is evaluated for every type code, before the
 // integer check below; a ty.bits of 0 would give a shift amount of 64.
 // Confirm that cannot occur here.
2657 const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2658 if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2659 val.u.u64 = (val.u.u64 == max_bits);
2660 } else {
2661 val.u.u64 = 0;
2662 }
 // Result is a 1-bit unsigned value; ty.lanes is left as set by the fold.
2663 ty.code = halide_type_uint;
2664 ty.bits = 1;
2665 }
2666};
2667
// Factory: wraps `a` (after pattern_arg conversion) in an IsMaxValue predicate.
2668template<typename A>
2669HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2671 return {pattern_arg(a)};
2672}
2673
2674template<typename A>
2675std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2676 s << "is_max_value(" << op.a << ")";
2677 return s;
2678}
2679
// Predicate pattern: folds `a` to a constant and tests whether it equals the
// smallest value of its (integer) type. Non-integer types always give false.
2680template<typename A>
2682 struct pattern_tag {};
2684
2686
2687 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2690 constexpr static bool canonical = true;
2691
2692 constexpr static bool foldable = true;
2693
2696 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2697 a.make_folded_const(val, ty, state);
2698 if (ty.code == halide_type_int) {
 // Ones from bit (ty.bits - 1) upward: the 64-bit pattern of the signed
 // minimum, presumably matching a sign-extended folded value.
2699 const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2700 val.u.u64 = (val.u.u64 == min_bits);
2701 } else if (ty.code == halide_type_uint) {
 // The unsigned minimum is simply zero.
2702 val.u.u64 = (val.u.u64 == 0);
2703 } else {
2704 val.u.u64 = 0;
2705 }
 // Result is a 1-bit unsigned value; ty.lanes is left as set by the fold.
2706 ty.code = halide_type_uint;
2707 ty.bits = 1;
2708 }
2709};
2710
// Factory: wraps `a` (after pattern_arg conversion) in an IsMinValue predicate.
2711template<typename A>
2712HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2714 return {pattern_arg(a)};
2715}
2716
2717template<typename A>
2718std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2719 s << "is_min_value(" << op.a << ")";
2720 return s;
2721}
2722
// Foldable pattern that evaluates to the number of lanes in the type of the
// Expr built from `a`, as a scalar 32-bit unsigned constant.
2723template<typename A>
2724struct LanesOf {
2725 struct pattern_tag {};
2727
2729
2730 // This rule evaluates to an unsigned integer (a lane count, not a boolean), which has type UIntImm.
2733 constexpr static bool canonical = true;
2734
2735 constexpr static bool foldable = true;
2736
2739 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2740 Type t = a.make(state, {}).type();
2741 val.u.u64 = t.lanes();
 // The count itself is a scalar, regardless of the operand's lane count.
2742 ty.code = halide_type_uint;
2743 ty.bits = 32;
2744 ty.lanes = 1;
2745 }
2746};
2747
// Factory: wraps `a` (after pattern_arg conversion) in a LanesOf pattern.
2748template<typename A>
2749HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2751 return {pattern_arg(a)};
2752}
2753
2754template<typename A>
2755std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2756 s << "lanes_of(" << op.a << ")";
2757 return s;
2758}
2759
2760// Verify properties of each rewrite rule. Currently just fuzz tests them.
// This overload is selected only when both Before and After are foldable.
// It binds random constants to every wildcard, folds both sides, and
// reports any trial where the two folded values disagree.
// NOTE(review): the line carrying the function name and first parameters
// is not present in this listing, and neither are a few interior lines
// (the val_before/val_after declarations and one `if` header).
2761template<typename Before,
2762 typename After,
2763 typename Predicate,
2764 typename = std::enable_if_t<std::decay_t<Before>::foldable &&
2765 std::decay_t<After>::foldable>>
2767 halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2768
2769 // We only validate the rules in the scalar case
2770 wildcard_type.lanes = output_type.lanes = 1;
2771
2772 // Track which types this rule has been tested for before
2773 static std::set<uint32_t> tested;
2774
 // insert().second is false when this wildcard type is already recorded,
 // so each type is fuzzed at most once per instantiation of this template.
2775 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2776 return;
2777 }
2778
2779 // Print it in a form where it can be piped into a python/z3 validator
2780 debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2781
2782 // Substitute some random constants into the before and after
2783 // expressions and see if the rule holds true. This should catch
2784 // silly errors, but not necessarily corner cases.
 // Function-local static generator with a fixed seed of 0.
2785 static std::mt19937_64 rng(0);
2786 MatcherState state;
2787
 // Keeps the generated constants alive while state refers to them.
2788 Expr exprs[max_wild];
2789
2790 for (int trials = 0; trials < 100; trials++) {
2791 // We want to test small constants more frequently than
2792 // large ones, otherwise we'll just get coverage of
2793 // overflow rules.
 // Masking with (bits - 1) yields a shift in [0, bits - 1], assuming bits
 // is a power of two.
2794 int shift = (int)(rng() & (wildcard_type.bits - 1));
2795
2796 for (int i = 0; i < max_wild; i++) {
2797 // Bind all the exprs and constants
 // Each slot i gets two independent random values: one stored as a
 // bound constant and one stored as an Expr binding.
2798 switch (wildcard_type.code) {
2799 case halide_type_uint: {
2800 // Normalize to the type's range by adding zero
2801 uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2802 state.set_bound_const(i, val, wildcard_type);
2803 val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2804 exprs[i] = make_const(wildcard_type, val);
 // NOTE(review): the same binding is set again after the switch, so this
 // call looks redundant.
2805 state.set_binding(i, *exprs[i].get());
2806 } break;
2807 case halide_type_int: {
2808 int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2809 state.set_bound_const(i, val, wildcard_type);
2810 val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2811 exprs[i] = make_const(wildcard_type, val);
2812 } break;
2813 case halide_type_float:
2814 case halide_type_bfloat: {
2815 // Use a very narrow range of precise floats, so
2816 // that none of the rules a human is likely to
2817 // write have instabilities.
 // (rng() & 15) - 8 is in [-8, 7]; halving gives multiples of 0.5 in [-4, 3.5].
2818 double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2819 state.set_bound_const(i, val, wildcard_type);
2820 val = ((int64_t)(rng() & 15) - 8) / 2.0;
2821 exprs[i] = make_const(wildcard_type, val);
2822 } break;
2823 default:
2824 return; // Don't care about handles
2825 }
2826 state.set_binding(i, *exprs[i].get());
2827 }
2828
2830 halide_type_t type = output_type;
 // Trials whose random values do not satisfy the predicate say nothing
 // about the rule, so skip them.
2831 if (!evaluate_predicate(pred, state)) {
2832 continue;
2833 }
 // Fold both sides, OR-ing together the lanes field each fold leaves in
 // `type` (presumably so flag bits raised by either side can be tested
 // together; the test itself is on a line not present here).
2834 before.make_folded_const(val_before, type, state);
2835 uint16_t lanes = type.lanes;
2836 after.make_folded_const(val_after, type, state);
2837 lanes |= type.lanes;
2838
2840 continue;
2841 }
2842
2843 bool ok = true;
2844 switch (output_type.code) {
2845 case halide_type_uint:
2846 // Compare normalized representations
2847 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2848 constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2849 break;
2850 case halide_type_int:
2851 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2852 constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2853 break;
2854 case halide_type_float:
2855 case halide_type_bfloat: {
2856 double error = std::abs(val_before.u.f64 - val_after.u.f64);
2857 // We accept an equal bit pattern (e.g. inf vs inf),
2858 // a small floating point difference, or turning a nan into not-a-nan.
2859 ok &= (error < 0.01 ||
2860 val_before.u.u64 == val_after.u.u64 ||
2861 std::isnan(val_before.u.f64));
2862 break;
2863 }
2864 default:
2865 return;
2866 }
2867
 // On a mismatch, dump every bound constant and binding plus both results.
 // The loop continues with the remaining trials afterwards unless a line
 // not present here aborts.
2868 if (!ok) {
2869 debug(0) << "Fails with values:\n";
2870 for (int i = 0; i < max_wild; i++) {
2872 state.get_bound_const(i, val, wildcard_type);
2873 debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2874 }
2875 for (int i = 0; i < max_wild; i++) {
2876 debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2877 }
2878 debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2879 debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2880 debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2882 }
2883 }
2884}
2885
// No-op counterpart of the fuzz tester, enabled when Before and After are
// not both foldable. The unused trailing `int dummy = 0` parameter makes
// this signature differ from the foldable overload.
// NOTE(review): the line carrying the function name and first parameters
// is not present in this listing.
2886template<typename Before,
2887 typename After,
2888 typename Predicate,
2889 typename = std::enable_if_t<!(std::decay_t<Before>::foldable &&
2890 std::decay_t<After>::foldable)>>
2892 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2893 // We can't verify rewrite rules that can't be constant-folded.
2894}
2895
// A plain bool predicate is already evaluated: return it unchanged. The
// matcher state is ignored.
2897bool evaluate_predicate(bool x, MatcherState &) noexcept {
2898 return x;
2899}
2900
// Evaluate a predicate given as a pattern: fold it to a constant under the
// current matcher state and treat a non-zero value as true.
// NOTE(review): the signature and the declarations of `c` and `ty` are not
// present in this listing.
2901template<typename Pattern,
2902 typename = typename enable_if_pattern<Pattern>::type>
2906 p.make_folded_const(c, ty, state);
2907 // Overflow counts as a failed predicate
 // Any special_values_mask bit left in ty.lanes therefore forces false.
2908 return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2909}
2910
2911// #defines for testing
2912
2913// Print all successful or failed matches
2914#define HALIDE_DEBUG_MATCHED_RULES 0
2915#define HALIDE_DEBUG_UNMATCHED_RULES 0
2916
2917// Set to true if you want to fuzz test every rewrite passed to
2918// operator() to ensure the input and the output have the same value
2919// for lots of random values of the wildcards. Run
2920// correctness_simplify with this on.
2921#define HALIDE_FUZZ_TEST_RULES 0
2922
// Applies rewrite rules to a fixed `instance`. Every rule method below has
// the same shape: try to match `before` against the instance with no
// wildcards bound yet (match<0>); on success store the replacement in
// `result` and return true, otherwise return false.
// NOTE(review): the member declarations, the constructor and the signature
// line of every method are not present in this listing, so the comments
// below describe only the visible bodies.
2923template<typename Instance>
2924struct Rewriter {
2930
2935
 // Helper: materialize `after` under the current bindings into `result`.
2936 template<typename After>
2938#if HALIDE_DEBUG_MATCHED_RULES
2939 debug(0) << instance << " -> " << after << "\n";
2940#endif
2941 result = after.make(state, output_type);
2942 }
2943
 // Rule: pattern -> pattern, no predicate.
2944 template<typename Before,
2945 typename After,
2946 typename = typename enable_if_pattern<Before>::type,
2947 typename = typename enable_if_pattern<After>::type>
 // Every wildcard used by the replacement must be bound by the match.
2949 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2950 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2951 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2952#if HALIDE_FUZZ_TEST_RULES
2954#endif
2955 if (before.template match<0>(unwrap(instance), state)) {
2957#if HALIDE_DEBUG_MATCHED_RULES
2958 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2959#endif
2960 return true;
2961 } else {
2962#if HALIDE_DEBUG_UNMATCHED_RULES
2963 debug(0) << instance << " does not match " << before << "\n";
2964#endif
2965 return false;
2966 }
2967 }
2968
 // Rule: pattern -> already-built replacement, assigned to `result` as is.
2969 template<typename Before,
2970 typename = typename enable_if_pattern<Before>::type>
2972 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2973 if (before.template match<0>(unwrap(instance), state)) {
2974 result = after;
2975#if HALIDE_DEBUG_MATCHED_RULES
2976 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2977#endif
2978 return true;
2979 } else {
2980#if HALIDE_DEBUG_UNMATCHED_RULES
2981 debug(0) << instance << " does not match " << before << "\n";
2982#endif
2983 return false;
2984 }
2985 }
2986
 // Rule: pattern -> non-pattern replacement, no predicate (the replacement
 // step is on a line not present here).
2987 template<typename Before,
2988 typename = typename enable_if_pattern<Before>::type>
2990 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2991#if HALIDE_FUZZ_TEST_RULES
2993#endif
2994 if (before.template match<0>(unwrap(instance), state)) {
2996#if HALIDE_DEBUG_MATCHED_RULES
2997 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2998#endif
2999 return true;
3000 } else {
3001#if HALIDE_DEBUG_UNMATCHED_RULES
3002 debug(0) << instance << " does not match " << before << "\n";
3003#endif
3004 return false;
3005 }
3006 }
3007
 // Rule: pattern -> pattern, guarded by a foldable predicate.
3008 template<typename Before,
3009 typename After,
3010 typename Predicate,
3011 typename = typename enable_if_pattern<Before>::type,
3012 typename = typename enable_if_pattern<After>::type,
3013 typename = typename enable_if_pattern<Predicate>::type>
3015 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
 // Both the replacement and the predicate may only use wildcards that the
 // left-hand side binds.
3016 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
3017 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
3018 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
3019 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
3020
3021#if HALIDE_FUZZ_TEST_RULES
3023#endif
3024 if (before.template match<0>(unwrap(instance), state) &&
3027#if HALIDE_DEBUG_MATCHED_RULES
3028 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3029#endif
3030 return true;
3031 } else {
3032#if HALIDE_DEBUG_UNMATCHED_RULES
3033 debug(0) << instance << " does not match " << before << "\n";
3034#endif
3035 return false;
3036 }
3037 }
3038
 // Rule: pattern -> already-built replacement, guarded by a predicate.
3039 template<typename Before,
3040 typename Predicate,
3041 typename = typename enable_if_pattern<Before>::type,
3042 typename = typename enable_if_pattern<Predicate>::type>
3044 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
3045 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
3046
3047 if (before.template match<0>(unwrap(instance), state) &&
3049 result = after;
3050#if HALIDE_DEBUG_MATCHED_RULES
3051 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3052#endif
3053 return true;
3054 } else {
3055#if HALIDE_DEBUG_UNMATCHED_RULES
3056 debug(0) << instance << " does not match " << before << "\n";
3057#endif
3058 return false;
3059 }
3060 }
3061
 // Rule: pattern -> non-pattern replacement, guarded by a predicate (the
 // replacement step is on a line not present here).
3062 template<typename Before,
3063 typename Predicate,
3064 typename = typename enable_if_pattern<Before>::type,
3065 typename = typename enable_if_pattern<Predicate>::type>
3067 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
3068 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
3069#if HALIDE_FUZZ_TEST_RULES
3071#endif
3072 if (before.template match<0>(unwrap(instance), state) &&
3075#if HALIDE_DEBUG_MATCHED_RULES
3076 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3077#endif
3078 return true;
3079 } else {
3080#if HALIDE_DEBUG_UNMATCHED_RULES
3081 debug(0) << instance << " does not match " << before << "\n";
3082#endif
3083 return false;
3084 }
3085 }
3086};
3087
3088/** Construct a rewriter for the given instance, which may be a pattern
3089 * with concrete expressions as leaves, or just an expression. The
3090 * second optional argument (wildcard_type) is a hint as to what the
3091 * type of the wildcards is likely to be. If omitted it uses the same
3092 * type as the expression itself. They are not required to be this
3093 * type, but the rule will only be tested for wildcards of that type
3094 * when testing is enabled.
3095 *
3096 * The rewriter can be used to check to see if the instance is one of
3097 * some number of patterns and if so rewrite it into another form,
3098 * using its operator() method. See Simplify.cpp for a bunch of
3099 * example usage.
3100 *
3101 * Important: Any Exprs in patterns are captured by reference, not by
3102 * value, so ensure they outlive the rewriter.
3103 */
3104// @{
3105template<typename Instance,
3106 typename = typename enable_if_pattern<Instance>::type>
3107HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3108 return {pattern_arg(instance), output_type, wildcard_type};
3109}
3110
3111template<typename Instance,
3112 typename = typename enable_if_pattern<Instance>::type>
3113HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3114 return {pattern_arg(instance), output_type, output_type};
3115}
3116
// Overload for a plain Expr instance: the output type is taken from
// e.type() and the wildcard type hint is supplied by the caller.
3118auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3119 return {pattern_arg(e), e.type(), wildcard_type};
3120}
3121
// Overload for a plain Expr instance with no hint: e.type() is used both as
// the output type and as the wildcard type hint.
3123auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3124 return {pattern_arg(e), e.type(), e.type()};
3125}
3126// @}
3127
3128} // namespace IRMatcher
3129
3130} // namespace Internal
3131} // namespace Halide
3132
3133#endif
#define debug(n)
For optional debugging during codegen, use the debug macro as follows:
Definition Debug.h:52
#define internal_error
Definition Error.h:229
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Halide expressions.
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition IRMatch.h:3107
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition IRMatch.h:567
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1564
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition IRMatch.h:1260
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition IRMatch.h:1653
auto shift_right(A &&a, B &&b) noexcept -> Intrin< Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1572
auto widening_add(A &&a, B &&b) noexcept -> Intrin< Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1530
HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2538
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition IRMatch.h:2897
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1017
auto abs(A &&a) noexcept -> Intrin< Call::abs, decltype(pattern_arg(a))>
Definition IRMatch.h:1593
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition IRMatch.h:1235
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2584
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition IRMatch.h:2090
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1155
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:906
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2669
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition IRMatch.h:217
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition IRMatch.h:1286
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition IRMatch.h:2002
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition IRMatch.h:1135
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition IRMatch.h:2414
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1165
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:972
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition IRMatch.h:913
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1517
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition IRMatch.h:1012
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1521
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition IRMatch.h:979
auto absd(A &&a, B &&b) noexcept -> Intrin< Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1598
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:2247
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1926
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1005
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2183
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1046
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1291
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition IRMatch.h:559
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1538
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1130
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2136
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition IRMatch.h:2332
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< Call::saturating_cast, decltype(pattern_arg(a))>
Definition IRMatch.h:1550
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition IRMatch.h:576
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1032
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:953
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1576
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition IRMatch.h:2628
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition IRMatch.h:2295
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition IRMatch.h:1659
auto likely(A &&a) noexcept -> Intrin< Call::likely, decltype(pattern_arg(a))>
Definition IRMatch.h:1603
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1090
constexpr bool and_reduce()
Definition IRMatch.h:1315
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1255
constexpr int max_wild
Definition IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1230
auto halving_add(A &&a, B &&b) noexcept -> Intrin< Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1556
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1584
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition IRMatch.h:2499
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1534
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1180
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1105
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1281
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition IRMatch.h:2008
constexpr bool commutative(IRNodeType t)
Definition IRMatch.h:615
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition IRMatch.h:946
auto likely_if_innermost(A &&a) noexcept -> Intrin< Call::likely_if_innermost, decltype(pattern_arg(a))>
Definition IRMatch.h:1608
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition IRMatch.h:1996
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:1862
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1789
HALIDE_ALWAYS_INLINE auto neg(const Wild< i > &a) -> SimplifiedNegateOp< i >
Definition IRMatch.h:1720
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1546
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2712
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1068
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition IRMatch.h:2766
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1140
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:986
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1190
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:939
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1588
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1542
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition IRMatch.h:1160
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition IRMatch.h:1110
auto shift_left(A &&a, B &&b) noexcept -> Intrin< Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1568
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition IRMatch.h:2749
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1580
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1115
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition IRMatch.h:1990
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition IRMatch.h:1984
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1265
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition IRMatch.h:1306
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1215
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition IRMatch.h:149
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1560
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition IRMatch.h:1185
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1240
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition IRMatch.h:1039
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1205
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1525
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:920
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition IRMatch.h:2462
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition IRMatch.h:1210
T div_imp(T a, T b)
Definition IROperator.h:278
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b)
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b)
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:257
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
bool is_const(const Expr &e)
Is the expression either an IntImm, a FloatImm, a StringImm, or a Cast of the same,...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
@ C
No name mangling.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:327
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:321
The sum of two expressions.
Definition IR.h:56
Logical and - are both expressions true.
Definition IR.h:175
A base class for expression nodes.
Definition Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition IR.h:265
A function call.
Definition IR.h:490
bool is_intrinsic() const
Definition IR.h:737
static const IRNodeType _node_type
Definition IR.h:795
The actual IR nodes begin here.
Definition IR.h:30
static const IRNodeType _node_type
Definition IR.h:35
The ratio of two expressions.
Definition IR.h:83
Is the first expression equal to the second.
Definition IR.h:121
Floating point constants.
Definition Expr.h:236
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition IR.h:166
Is the first expression greater than the second.
Definition IR.h:157
static constexpr bool canonical
Definition IRMatch.h:641
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:664
static constexpr uint32_t binds
Definition IRMatch.h:633
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:645
static constexpr bool foldable
Definition IRMatch.h:661
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:707
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:655
static constexpr IRNodeType max_node_type
Definition IRMatch.h:636
static constexpr IRNodeType min_node_type
Definition IRMatch.h:635
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1804
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1828
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1822
static constexpr uint32_t binds
Definition IRMatch.h:1802
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1810
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1805
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1845
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2451
static constexpr uint32_t binds
Definition IRMatch.h:2441
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2444
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2445
static constexpr bool foldable
Definition IRMatch.h:2448
static constexpr bool canonical
Definition IRMatch.h:2446
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2104
static constexpr bool foldable
Definition IRMatch.h:2126
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2108
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2103
static constexpr uint32_t binds
Definition IRMatch.h:2101
static constexpr bool canonical
Definition IRMatch.h:2105
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2117
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2122
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:801
static constexpr IRNodeType max_node_type
Definition IRMatch.h:739
static constexpr uint32_t binds
Definition IRMatch.h:736
static constexpr bool canonical
Definition IRMatch.h:740
static constexpr bool foldable
Definition IRMatch.h:763
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:747
static constexpr IRNodeType min_node_type
Definition IRMatch.h:738
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:766
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:757
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2260
static constexpr uint32_t binds
Definition IRMatch.h:2257
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2259
static constexpr bool canonical
Definition IRMatch.h:2261
static constexpr bool foldable
Definition IRMatch.h:2286
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:2264
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2289
static constexpr IRNodeType max_node_type
Definition IRMatch.h:495
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:504
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition IRMatch.h:499
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition IRMatch.h:527
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:539
static constexpr IRNodeType min_node_type
Definition IRMatch.h:494
static constexpr bool canonical
Definition IRMatch.h:496
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:532
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition IRMatch.h:522
static constexpr uint32_t binds
Definition IRMatch.h:492
HALIDE_ALWAYS_INLINE Intrin(Args... args) noexcept
Definition IRMatch.h:1503
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition IRMatch.h:1395
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1370
static constexpr bool foldable
Definition IRMatch.h:1458
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1460
std::tuple< Args... > args
Definition IRMatch.h:1342
static constexpr uint32_t binds
Definition IRMatch.h:1349
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition IRMatch.h:1391
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1358
static constexpr bool canonical
Definition IRMatch.h:1353
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1365
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition IRMatch.h:1382
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1352
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1400
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1351
OptionalIntrinType< intrin > optional_type_hint
Definition IRMatch.h:1347
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2389
static constexpr bool canonical
Definition IRMatch.h:2391
static constexpr bool foldable
Definition IRMatch.h:2397
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2390
static constexpr uint32_t binds
Definition IRMatch.h:2386
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2400
static constexpr bool foldable
Definition IRMatch.h:2485
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2488
static constexpr bool canonical
Definition IRMatch.h:2483
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2481
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2482
static constexpr uint32_t binds
Definition IRMatch.h:2478
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2521
static constexpr bool foldable
Definition IRMatch.h:2524
static constexpr uint32_t binds
Definition IRMatch.h:2517
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2527
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2520
static constexpr bool canonical
Definition IRMatch.h:2522
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2647
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2648
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2654
static constexpr uint32_t binds
Definition IRMatch.h:2644
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2688
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2695
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2689
static constexpr uint32_t binds
Definition IRMatch.h:2685
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2611
static constexpr uint32_t binds
Definition IRMatch.h:2607
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2617
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2610
static constexpr bool foldable
Definition IRMatch.h:2614
static constexpr bool canonical
Definition IRMatch.h:2612
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2573
static constexpr bool foldable
Definition IRMatch.h:2570
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2566
static constexpr bool canonical
Definition IRMatch.h:2568
static constexpr uint32_t binds
Definition IRMatch.h:2563
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2567
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2732
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2738
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2731
static constexpr bool foldable
Definition IRMatch.h:2735
static constexpr uint32_t binds
Definition IRMatch.h:2728
static constexpr bool canonical
Definition IRMatch.h:2733
To save stack space, the matcher objects are largely stateless and immutable.
Definition IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition IRMatch.h:84
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition IRMatch.h:87
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2026
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2041
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition IRMatch.h:2036
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2050
static constexpr uint32_t binds
Definition IRMatch.h:2018
static constexpr bool canonical
Definition IRMatch.h:2023
static constexpr bool foldable
Definition IRMatch.h:2047
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2021
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2020
static constexpr uint32_t binds
Definition IRMatch.h:1617
static constexpr bool foldable
Definition IRMatch.h:1642
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1624
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1620
static constexpr bool canonical
Definition IRMatch.h:1621
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1633
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1638
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1645
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1619
static constexpr uint32_t binds
Definition IRMatch.h:2346
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2350
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2354
static constexpr bool canonical
Definition IRMatch.h:2351
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2363
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2371
static constexpr bool foldable
Definition IRMatch.h:2368
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2349
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2322
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2315
static constexpr uint32_t binds
Definition IRMatch.h:2311
static constexpr bool canonical
Definition IRMatch.h:2317
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2316
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1904
static constexpr bool canonical
Definition IRMatch.h:1879
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1877
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1876
static constexpr uint32_t binds
Definition IRMatch.h:1874
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1897
static constexpr bool foldable
Definition IRMatch.h:1916
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1882
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition IRMatch.h:2937
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition IRMatch.h:3014
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition IRMatch.h:2989
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition IRMatch.h:2932
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition IRMatch.h:3043
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition IRMatch.h:2971
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition IRMatch.h:3066
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition IRMatch.h:2948
static constexpr uint32_t binds
Definition IRMatch.h:1737
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1769
static constexpr bool foldable
Definition IRMatch.h:1766
static constexpr bool canonical
Definition IRMatch.h:1742
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition IRMatch.h:1755
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1745
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1762
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1740
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1739
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1679
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1685
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1680
static constexpr bool canonical
Definition IRMatch.h:2200
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2199
static constexpr bool foldable
Definition IRMatch.h:2229
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition IRMatch.h:2232
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2203
static constexpr uint32_t binds
Definition IRMatch.h:2196
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2217
static constexpr IRNodeType min_node_type
Definition IRMatch.h:198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:210
static constexpr IRNodeType max_node_type
Definition IRMatch.h:199
static constexpr uint32_t binds
Definition IRMatch.h:195
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1959
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1941
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1946
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1942
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1966
static constexpr uint32_t binds
Definition IRMatch.h:2146
static constexpr bool canonical
Definition IRMatch.h:2150
static constexpr bool foldable
Definition IRMatch.h:2173
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2167
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2149
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2162
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2148
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2153
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:352
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:373
static constexpr IRNodeType max_node_type
Definition IRMatch.h:348
static constexpr IRNodeType min_node_type
Definition IRMatch.h:347
static constexpr uint32_t binds
Definition IRMatch.h:345
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:383
static constexpr bool canonical
Definition IRMatch.h:403
static constexpr IRNodeType max_node_type
Definition IRMatch.h:402
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:431
static constexpr uint32_t binds
Definition IRMatch.h:399
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:406
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:441
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition IRMatch.h:425
static constexpr IRNodeType min_node_type
Definition IRMatch.h:401
static constexpr bool foldable
Definition IRMatch.h:438
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:267
static constexpr uint32_t binds
Definition IRMatch.h:226
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:277
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:233
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition IRMatch.h:254
static constexpr IRNodeType min_node_type
Definition IRMatch.h:228
static constexpr IRNodeType max_node_type
Definition IRMatch.h:229
static constexpr uint32_t binds
Definition IRMatch.h:292
static constexpr IRNodeType max_node_type
Definition IRMatch.h:295
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:299
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:330
static constexpr IRNodeType min_node_type
Definition IRMatch.h:294
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:320
static constexpr IRNodeType min_node_type
Definition IRMatch.h:459
static constexpr uint32_t binds
Definition IRMatch.h:457
static constexpr IRNodeType max_node_type
Definition IRMatch.h:460
static constexpr bool canonical
Definition IRMatch.h:461
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:473
static constexpr bool foldable
Definition IRMatch.h:477
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:464
static constexpr uint32_t mask
Definition IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
Integer constants.
Definition Expr.h:218
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition IR.h:148
Is the first expression less than the second.
Definition IR.h:139
The greater of two values.
Definition IR.h:112
The lesser of two values.
Definition IR.h:103
The remainder of a / b.
Definition IR.h:94
The product of two expressions.
Definition IR.h:74
Is the first expression not equal to the second.
Definition IR.h:130
Logical not - true if the expression is false.
Definition IR.h:193
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition IR.h:184
A linear ramp vector node.
Definition IR.h:247
static const IRNodeType _node_type
Definition IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:884
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition IR.h:885
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:938
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:935
The difference of two expressions.
Definition IR.h:65
static const IRNodeType _node_type
Definition IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:1012
static const IRNodeType _node_type
Definition IR.h:1031
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition Type.h:283
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
Definition Type.h:378
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:435
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition Type.h:355
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:349
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:423
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@2 u
A runtime tag for a type in the halide type system.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.