Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector given as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 expr_match(x + x, 3 + (2*k), result)
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 expr_match(x*(x + y), a*(a + b), result)
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
74constexpr int max_wild = 6;
75
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
144template<typename T>
145struct bindings {
146 constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
147};
148
150 const uint16_t flags = ty.lanes & MatcherState::special_values_mask;
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(std::move(e), lanes);
187 }
188 return e;
189}
190
191// A pattern that matches a specific expression
193 struct pattern_tag {};
194
195 constexpr static uint32_t binds = 0;
196
197 // What is the weakest and strongest IR node this could possibly be
200 constexpr static bool canonical = true;
201
203
204 template<uint32_t bound>
205 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
206 return equal(expr, e);
207 }
208
211 return Expr(&expr);
212 }
213
214 constexpr static bool foldable = false;
215};
216
217inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
218 s << Expr(&e.expr);
219 return s;
220}
221
222template<int i>
224 struct pattern_tag {};
225
226 constexpr static uint32_t binds = 1 << i;
227
230 constexpr static bool canonical = true;
231
232 template<uint32_t bound>
233 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
234 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
235 const BaseExprNode *op = &e;
236 if (op->node_type == IRNodeType::Broadcast) {
237 op = ((const Broadcast *)op)->value.get();
238 }
239 if (op->node_type != IRNodeType::IntImm) {
240 return false;
241 }
242 int64_t value = ((const IntImm *)op)->value;
243 if (bound & binds) {
245 halide_type_t type;
246 state.get_bound_const(i, val, type);
247 return (halide_type_t)e.type == type && value == val.u.i64;
248 }
249 state.set_bound_const(i, value, e.type);
250 return true;
251 }
252
253 template<uint32_t bound>
254 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
255 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
256 if (bound & binds) {
258 halide_type_t type;
259 state.get_bound_const(i, val, type);
260 return type == i64_type && value == val.u.i64;
261 }
262 state.set_bound_const(i, value, i64_type);
263 return true;
264 }
265
269 halide_type_t type;
270 state.get_bound_const(i, val, type);
271 return make_const_expr(val, type);
272 }
273
274 constexpr static bool foldable = true;
275
278 state.get_bound_const(i, val, ty);
279 }
280};
281
282template<int i>
283std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
284 s << "ci" << i;
285 return s;
286}
287
288template<int i>
290 struct pattern_tag {};
291
292 constexpr static uint32_t binds = 1 << i;
293
296 constexpr static bool canonical = true;
297
298 template<uint32_t bound>
299 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
300 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
301 const BaseExprNode *op = &e;
302 if (op->node_type == IRNodeType::Broadcast) {
303 op = ((const Broadcast *)op)->value.get();
304 }
305 if (op->node_type != IRNodeType::UIntImm) {
306 return false;
307 }
308 uint64_t value = ((const UIntImm *)op)->value;
309 if (bound & binds) {
311 halide_type_t type;
312 state.get_bound_const(i, val, type);
313 return (halide_type_t)e.type == type && value == val.u.u64;
314 }
315 state.set_bound_const(i, value, e.type);
316 return true;
317 }
318
322 halide_type_t type;
323 state.get_bound_const(i, val, type);
324 return make_const_expr(val, type);
325 }
326
327 constexpr static bool foldable = true;
328
331 state.get_bound_const(i, val, ty);
332 }
333};
334
335template<int i>
336std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
337 s << "cu" << i;
338 return s;
339}
340
341template<int i>
343 struct pattern_tag {};
344
345 constexpr static uint32_t binds = 1 << i;
346
349 constexpr static bool canonical = true;
350
351 template<uint32_t bound>
352 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
353 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
354 const BaseExprNode *op = &e;
355 if (op->node_type == IRNodeType::Broadcast) {
356 op = ((const Broadcast *)op)->value.get();
357 }
358 if (op->node_type != IRNodeType::FloatImm) {
359 return false;
360 }
361 double value = ((const FloatImm *)op)->value;
362 if (bound & binds) {
364 halide_type_t type;
365 state.get_bound_const(i, val, type);
366 return (halide_type_t)e.type == type && value == val.u.f64;
367 }
368 state.set_bound_const(i, value, e.type);
369 return true;
370 }
371
375 halide_type_t type;
376 state.get_bound_const(i, val, type);
377 return make_const_expr(val, type);
378 }
379
380 constexpr static bool foldable = true;
381
384 state.get_bound_const(i, val, ty);
385 }
386};
387
388template<int i>
389std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
390 s << "cf" << i;
391 return s;
392}
393
394// Matches and binds to any constant Expr. Does not support constant-folding.
395template<int i>
396struct WildConst {
397 struct pattern_tag {};
398
399 constexpr static uint32_t binds = 1 << i;
400
403 constexpr static bool canonical = true;
404
405 template<uint32_t bound>
406 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
407 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
408 const BaseExprNode *op = &e;
409 if (op->node_type == IRNodeType::Broadcast) {
410 op = ((const Broadcast *)op)->value.get();
411 }
412 switch (op->node_type) {
414 return WildConstInt<i>().template match<bound>(e, state);
416 return WildConstUInt<i>().template match<bound>(e, state);
418 return WildConstFloat<i>().template match<bound>(e, state);
419 default:
420 return false;
421 }
422 }
423
424 template<uint32_t bound>
425 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
426 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
427 return WildConstInt<i>().template match<bound>(e, state);
428 }
429
433 halide_type_t type;
434 state.get_bound_const(i, val, type);
435 return make_const_expr(val, type);
436 }
437
438 constexpr static bool foldable = true;
439
442 state.get_bound_const(i, val, ty);
443 }
444};
445
446template<int i>
447std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
448 s << "c" << i;
449 return s;
450}
451
452// Matches and binds to any Expr
453template<int i>
454struct Wild {
455 struct pattern_tag {};
456
457 constexpr static uint32_t binds = 1 << (i + 16);
458
461 constexpr static bool canonical = true;
462
463 template<uint32_t bound>
464 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
465 if (bound & binds) {
466 return equal(*state.get_binding(i), e);
467 }
468 state.set_binding(i, e);
469 return true;
470 }
471
474 return state.get_binding(i);
475 }
476
477 constexpr static bool foldable = false;
478};
479
480template<int i>
481std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
482 s << "_" << i;
483 return s;
484}
485
486// Matches a specific constant or broadcast of that constant. The
487// constant must be representable as an int64_t.
489 struct pattern_tag {};
491
492 constexpr static uint32_t binds = 0;
493
496 constexpr static bool canonical = true;
497
500 : v(v) {
501 }
502
503 template<uint32_t bound>
504 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
505 const BaseExprNode *op = &e;
506 if (e.node_type == IRNodeType::Broadcast) {
507 op = ((const Broadcast *)op)->value.get();
508 }
509 switch (op->node_type) {
511 return ((const IntImm *)op)->value == (int64_t)v;
513 return ((const UIntImm *)op)->value == (uint64_t)v;
515 return ((const FloatImm *)op)->value == (double)v;
516 default:
517 return false;
518 }
519 }
520
521 template<uint32_t bound>
522 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
523 return v == val;
524 }
525
526 template<uint32_t bound>
527 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
528 return v == b.v;
529 }
530
533 return make_const(type_hint, v);
534 }
535
536 constexpr static bool foldable = true;
537
540 // Assume type is already correct
541 switch (ty.code) {
542 case halide_type_int:
543 val.u.i64 = v;
544 break;
545 case halide_type_uint:
546 val.u.u64 = (uint64_t)v;
547 break;
550 val.u.f64 = (double)v;
551 break;
552 default:
553 // Unreachable
554 ;
555 }
556 }
557};
558
562
563// Convert a provided pattern, expr, or constant int into the internal
564// representation we use in the matcher trees.
565template<typename T,
566 typename = typename std::decay<T>::type::pattern_tag>
568 return t;
569}
572 return IntLiteral{x};
573}
574
575template<typename T>
577 static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
578 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
579}
580
582 return {*e.get()};
583}
584
585// Helpers to deref SpecificExprs to const BaseExprNode & rather than
586// passing them by value anywhere (incurring lots of refcounting)
587template<typename T,
588 // T must be a pattern node
589 typename = typename std::decay<T>::type::pattern_tag,
590 // But T may not be SpecificExpr
591 typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
593 return t;
594}
595
598 return e.expr;
599}
600
601inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
602 s << op.v;
603 return s;
604}
605
606template<typename Op>
608
609template<typename Op>
611
612template<typename Op>
613double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
614
615constexpr bool commutative(IRNodeType t) {
616 return (t == IRNodeType::Add ||
617 t == IRNodeType::Mul ||
618 t == IRNodeType::And ||
619 t == IRNodeType::Or ||
620 t == IRNodeType::Min ||
621 t == IRNodeType::Max ||
622 t == IRNodeType::EQ ||
623 t == IRNodeType::NE);
624}
625
626// Matches one of the binary operators
627template<typename Op, typename A, typename B>
628struct BinOp {
629 struct pattern_tag {};
632
634
635 constexpr static IRNodeType min_node_type = Op::_node_type;
636 constexpr static IRNodeType max_node_type = Op::_node_type;
637
638 // For commutative bin ops, we expect the weaker IR node type on
639 // the right. That is, for the rule to be canonical it must be
640 // possible that A is at least as strong as B.
641 constexpr static bool canonical =
642 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
643
644 template<uint32_t bound>
645 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
646 if (e.node_type != Op::_node_type) {
647 return false;
648 }
649 const Op &op = (const Op &)e;
650 return (a.template match<bound>(*op.a.get(), state) &&
651 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
652 }
653
654 template<uint32_t bound, typename Op2, typename A2, typename B2>
655 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
656 return (std::is_same<Op, Op2>::value &&
657 a.template match<bound>(unwrap(op.a), state) &&
658 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
659 }
660
661 constexpr static bool foldable = A::foldable && B::foldable;
662
666 if (std::is_same<A, IntLiteral>::value) {
667 b.make_folded_const(val_b, ty, state);
668 if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
669 (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
670 // Short circuit
671 val = val_b;
672 return;
673 }
674 const uint16_t l = ty.lanes;
675 a.make_folded_const(val_a, ty, state);
676 ty.lanes |= l; // Make sure the overflow bits are sticky
677 } else {
678 a.make_folded_const(val_a, ty, state);
679 if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
680 (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
681 // Short circuit
682 val = val_a;
683 return;
684 }
685 const uint16_t l = ty.lanes;
686 b.make_folded_const(val_b, ty, state);
687 ty.lanes |= l;
688 }
689 switch (ty.code) {
690 case halide_type_int:
691 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
692 break;
693 case halide_type_uint:
694 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
695 break;
698 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
699 break;
700 default:
701 // unreachable
702 ;
703 }
704 }
705
707 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
708 Expr ea, eb;
709 if (std::is_same<A, IntLiteral>::value) {
710 eb = b.make(state, type_hint);
711 ea = a.make(state, eb.type());
712 } else {
713 ea = a.make(state, type_hint);
714 eb = b.make(state, ea.type());
715 }
716 return Op::make(std::move(ea), std::move(eb));
717 }
718};
719
720template<typename Op>
722
723template<typename Op>
725
726template<typename Op>
727uint64_t constant_fold_cmp_op(double, double) noexcept;
728
729// Matches one of the comparison operators
730template<typename Op, typename A, typename B>
731struct CmpOp {
732 struct pattern_tag {};
735
737
738 constexpr static IRNodeType min_node_type = Op::_node_type;
739 constexpr static IRNodeType max_node_type = Op::_node_type;
740 constexpr static bool canonical = (A::canonical &&
741 B::canonical &&
742 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
743 (Op::_node_type != IRNodeType::GE) &&
744 (Op::_node_type != IRNodeType::GT));
745
746 template<uint32_t bound>
747 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
748 if (e.node_type != Op::_node_type) {
749 return false;
750 }
751 const Op &op = (const Op &)e;
752 return (a.template match<bound>(*op.a.get(), state) &&
753 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
754 }
755
756 template<uint32_t bound, typename Op2, typename A2, typename B2>
757 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
758 return (std::is_same<Op, Op2>::value &&
759 a.template match<bound>(unwrap(op.a), state) &&
760 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
761 }
762
763 constexpr static bool foldable = A::foldable && B::foldable;
764
768 // If one side is an untyped const, evaluate the other side first to get a type hint.
769 if (std::is_same<A, IntLiteral>::value) {
770 b.make_folded_const(val_b, ty, state);
771 const uint16_t l = ty.lanes;
772 a.make_folded_const(val_a, ty, state);
773 ty.lanes |= l;
774 } else {
775 a.make_folded_const(val_a, ty, state);
776 const uint16_t l = ty.lanes;
777 b.make_folded_const(val_b, ty, state);
778 ty.lanes |= l;
779 }
780 switch (ty.code) {
781 case halide_type_int:
782 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
783 break;
784 case halide_type_uint:
785 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
786 break;
789 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
790 break;
791 default:
792 // unreachable
793 ;
794 }
795 ty.code = halide_type_uint;
796 ty.bits = 1;
797 }
798
801 // If one side is an untyped const, evaluate the other side first to get a type hint.
802 Expr ea, eb;
803 if (std::is_same<A, IntLiteral>::value) {
804 eb = b.make(state, {});
805 ea = a.make(state, eb.type());
806 } else {
807 ea = a.make(state, {});
808 eb = b.make(state, ea.type());
809 }
810 return Op::make(std::move(ea), std::move(eb));
811 }
812};
813
814template<typename A, typename B>
815std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
816 s << "(" << op.a << " + " << op.b << ")";
817 return s;
818}
819
820template<typename A, typename B>
821std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
822 s << "(" << op.a << " - " << op.b << ")";
823 return s;
824}
825
826template<typename A, typename B>
827std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
828 s << "(" << op.a << " * " << op.b << ")";
829 return s;
830}
831
832template<typename A, typename B>
833std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
834 s << "(" << op.a << " / " << op.b << ")";
835 return s;
836}
837
838template<typename A, typename B>
839std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
840 s << "(" << op.a << " && " << op.b << ")";
841 return s;
842}
843
844template<typename A, typename B>
845std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
846 s << "(" << op.a << " || " << op.b << ")";
847 return s;
848}
849
850template<typename A, typename B>
851std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
852 s << "min(" << op.a << ", " << op.b << ")";
853 return s;
854}
855
856template<typename A, typename B>
857std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
858 s << "max(" << op.a << ", " << op.b << ")";
859 return s;
860}
861
862template<typename A, typename B>
863std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
864 s << "(" << op.a << " <= " << op.b << ")";
865 return s;
866}
867
868template<typename A, typename B>
869std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
870 s << "(" << op.a << " < " << op.b << ")";
871 return s;
872}
873
874template<typename A, typename B>
875std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
876 s << "(" << op.a << " >= " << op.b << ")";
877 return s;
878}
879
880template<typename A, typename B>
881std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
882 s << "(" << op.a << " > " << op.b << ")";
883 return s;
884}
885
886template<typename A, typename B>
887std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
888 s << "(" << op.a << " == " << op.b << ")";
889 return s;
890}
891
892template<typename A, typename B>
893std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
894 s << "(" << op.a << " != " << op.b << ")";
895 return s;
896}
897
898template<typename A, typename B>
899std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
900 s << "(" << op.a << " % " << op.b << ")";
901 return s;
902}
903
904template<typename A, typename B>
905HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
908 return {pattern_arg(a), pattern_arg(b)};
909}
910
911template<typename A, typename B>
917
918template<>
920 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
921 int dead_bits = 64 - t.bits;
922 // Drop the high bits then sign-extend them back
923 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
924}
925
926template<>
928 uint64_t ones = (uint64_t)(-1);
929 return (a + b) & (ones >> (64 - t.bits));
930}
931
932template<>
933HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
934 return a + b;
935}
936
937template<typename A, typename B>
938HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
941 return {pattern_arg(a), pattern_arg(b)};
942}
943
944template<typename A, typename B>
950
951template<>
953 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
954 // Drop the high bits then sign-extend them back
955 int dead_bits = 64 - t.bits;
956 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
957}
958
959template<>
961 uint64_t ones = (uint64_t)(-1);
962 return (a - b) & (ones >> (64 - t.bits));
963}
964
965template<>
966HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
967 return a - b;
968}
969
970template<typename A, typename B>
971HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
974 return {pattern_arg(a), pattern_arg(b)};
975}
976
977template<typename A, typename B>
983
984template<>
986 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
987 int dead_bits = 64 - t.bits;
988 // Drop the high bits then sign-extend them back
989 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
990}
991
992template<>
994 uint64_t ones = (uint64_t)(-1);
995 return (a * b) & (ones >> (64 - t.bits));
996}
997
998template<>
999HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1000 return a * b;
1001}
1002
1003template<typename A, typename B>
1004HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1007 return {pattern_arg(a), pattern_arg(b)};
1008}
1009
1010template<typename A, typename B>
1011HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1012 return IRMatcher::operator/(a, b);
1013}
1014
1015template<>
1019
1020template<>
1024
1025template<>
1026HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1027 return div_imp(a, b);
1028}
1029
1030template<typename A, typename B>
1031HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1034 return {pattern_arg(a), pattern_arg(b)};
1035}
1036
1037template<typename A, typename B>
1038HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1041 return IRMatcher::operator%(a, b);
1042}
1043
1044template<>
1048
1049template<>
1053
1054template<>
1055HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1056 return mod_imp(a, b);
1057}
1058
1059template<typename A, typename B>
1060HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1063 return {pattern_arg(a), pattern_arg(b)};
1064}
1065
1066template<>
1068 return std::min(a, b);
1069}
1070
1071template<>
1073 return std::min(a, b);
1074}
1075
1076template<>
1077HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1078 return std::min(a, b);
1079}
1080
1081template<typename A, typename B>
1082HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1085 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1086}
1087
1088template<>
1090 return std::max(a, b);
1091}
1092
1093template<>
1095 return std::max(a, b);
1096}
1097
1098template<>
1099HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1100 return std::max(a, b);
1101}
1102
1103template<typename A, typename B>
1104HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1105 return {pattern_arg(a), pattern_arg(b)};
1106}
1107
1108template<typename A, typename B>
1109HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1110 return IRMatcher::operator<(a, b);
1111}
1112
1113template<>
1117
1118template<>
1122
1123template<>
1125 return a < b;
1126}
1127
1128template<typename A, typename B>
1129HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1130 return {pattern_arg(a), pattern_arg(b)};
1131}
1132
1133template<typename A, typename B>
1134HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1135 return IRMatcher::operator>(a, b);
1136}
1137
1138template<>
1142
1143template<>
1147
1148template<>
1150 return a > b;
1151}
1152
1153template<typename A, typename B>
1154HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1155 return {pattern_arg(a), pattern_arg(b)};
1156}
1157
1158template<typename A, typename B>
1159HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1160 return IRMatcher::operator<=(a, b);
1161}
1162
1163template<>
1165 return a <= b;
1166}
1167
1168template<>
1172
1173template<>
1175 return a <= b;
1176}
1177
1178template<typename A, typename B>
1179HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1180 return {pattern_arg(a), pattern_arg(b)};
1181}
1182
1183template<typename A, typename B>
1184HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1185 return IRMatcher::operator>=(a, b);
1186}
1187
1188template<>
1190 return a >= b;
1191}
1192
1193template<>
1197
1198template<>
1200 return a >= b;
1201}
1202
1203template<typename A, typename B>
1204HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1205 return {pattern_arg(a), pattern_arg(b)};
1206}
1207
1208template<typename A, typename B>
1209HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1210 return IRMatcher::operator==(a, b);
1211}
1212
1213template<>
1215 return a == b;
1216}
1217
1218template<>
1222
1223template<>
1225 return a == b;
1226}
1227
1228template<typename A, typename B>
1229HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1230 return {pattern_arg(a), pattern_arg(b)};
1231}
1232
1233template<typename A, typename B>
1234HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1235 return IRMatcher::operator!=(a, b);
1236}
1237
1238template<>
1240 return a != b;
1241}
1242
1243template<>
1247
1248template<>
1250 return a != b;
1251}
1252
1253template<typename A, typename B>
1254HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1255 return {pattern_arg(a), pattern_arg(b)};
1256}
1257
1258template<typename A, typename B>
1259HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1260 return IRMatcher::operator||(a, b);
1261}
1262
1263template<>
1265 return (a | b) & 1;
1266}
1267
1268template<>
1270 return (a | b) & 1;
1271}
1272
1273template<>
1274HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1275 // Unreachable, as it would be a type mismatch.
1276 return 0;
1277}
1278
1279template<typename A, typename B>
1280HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1281 return {pattern_arg(a), pattern_arg(b)};
1282}
1283
1284template<typename A, typename B>
1285HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1286 return IRMatcher::operator&&(a, b);
1287}
1288
1289template<>
1291 return a & b & 1;
1292}
1293
1294template<>
1298
1299template<>
1300HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1301 // Unreachable
1302 return 0;
1303}
1304
1306 return 0;
1307}
1308
/** Compile-time bitwise OR of all arguments. Each trailing argument is
 * converted to uint32_t by the recursive call. The recursion is terminated
 * by a zero-argument bitwise_or_reduce() overload, which must be declared
 * before this template. */
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return bitwise_or_reduce(rest...) | first;
}
1313
/** Compile-time logical AND of a list of bools. It is used, for example,
 * to combine the `canonical` flags of a pattern's children.
 * Base case: the conjunction of no values is true. */
constexpr bool and_reduce() {
    return true;
}

/** Recursive step: stop at the first false value, otherwise keep reducing
 * the remaining arguments. */
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return first ? and_reduce(rest...) : false;
}

/** constexpr minimum of two ints. It is usable in template arguments, e.g.
 * to clamp a std::get<> index to the last valid tuple element.
 * NOTE: std::min is constexpr since C++14 and could replace this helper. */
constexpr int const_min(int a, int b) {
    return (b < a) ? b : a;
}
1327
1328template<typename... Args>
1329struct Intrin {
1330 struct pattern_tag {};
1332 std::tuple<Args...> args;
1333 // The type of the output of the intrinsic node.
1334 // Only necessary in cases where it can't be inferred
1335 // from the input types (e.g. saturating_cast).
1337
1339
1342 constexpr static bool canonical = and_reduce((Args::canonical)...);
1343
1344 template<int i,
1345 uint32_t bound,
1346 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1347 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1348 using T = decltype(std::get<i>(args));
1349 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1351 }
1352
1353 template<int i, uint32_t binds>
1354 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1355 return true;
1356 }
1357
1358 template<uint32_t bound>
1359 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1360 if (e.node_type != IRNodeType::Call) {
1361 return false;
1362 }
1363 const Call &c = (const Call &)e;
1364 return (c.is_intrinsic(intrin) &&
1365 ((optional_type_hint == Type()) || optional_type_hint == e.type) &&
1366 match_args<0, bound>(0, c, state));
1367 }
1368
1369 template<int i,
1370 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1371 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1373 if (i + 1 < sizeof...(Args)) {
1374 s << ", ";
1375 }
1376 print_args<i + 1>(0, s);
1377 }
1378
1379 template<int i>
1380 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1381 }
1382
1384 void print_args(std::ostream &s) const {
1385 print_args<0>(0, s);
1386 }
1387
1390 Expr arg0 = std::get<0>(args).make(state, type_hint);
1391 if (intrin == Call::likely) {
1392 return likely(std::move(arg0));
1393 } else if (intrin == Call::likely_if_innermost) {
1394 return likely_if_innermost(std::move(arg0));
1395 } else if (intrin == Call::abs) {
1396 return abs(std::move(arg0));
1397 } else if (intrin == Call::saturating_cast) {
1398 return saturating_cast(optional_type_hint, std::move(arg0));
1399 }
1400
1401 Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1402 if (intrin == Call::absd) {
1403 return absd(std::move(arg0), std::move(arg1));
1404 } else if (intrin == Call::widen_right_add) {
1405 return widen_right_add(std::move(arg0), std::move(arg1));
1406 } else if (intrin == Call::widen_right_mul) {
1407 return widen_right_mul(std::move(arg0), std::move(arg1));
1408 } else if (intrin == Call::widen_right_sub) {
1409 return widen_right_sub(std::move(arg0), std::move(arg1));
1410 } else if (intrin == Call::widening_add) {
1411 return widening_add(std::move(arg0), std::move(arg1));
1412 } else if (intrin == Call::widening_sub) {
1413 return widening_sub(std::move(arg0), std::move(arg1));
1414 } else if (intrin == Call::widening_mul) {
1415 return widening_mul(std::move(arg0), std::move(arg1));
1416 } else if (intrin == Call::saturating_add) {
1417 return saturating_add(std::move(arg0), std::move(arg1));
1418 } else if (intrin == Call::saturating_sub) {
1419 return saturating_sub(std::move(arg0), std::move(arg1));
1420 } else if (intrin == Call::halving_add) {
1421 return halving_add(std::move(arg0), std::move(arg1));
1422 } else if (intrin == Call::halving_sub) {
1423 return halving_sub(std::move(arg0), std::move(arg1));
1424 } else if (intrin == Call::rounding_halving_add) {
1425 return rounding_halving_add(std::move(arg0), std::move(arg1));
1426 } else if (intrin == Call::shift_left) {
1427 return std::move(arg0) << std::move(arg1);
1428 } else if (intrin == Call::shift_right) {
1429 return std::move(arg0) >> std::move(arg1);
1430 } else if (intrin == Call::rounding_shift_left) {
1431 return rounding_shift_left(std::move(arg0), std::move(arg1));
1432 } else if (intrin == Call::rounding_shift_right) {
1433 return rounding_shift_right(std::move(arg0), std::move(arg1));
1434 }
1435
1436 Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1438 return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1439 } else if (intrin == Call::rounding_mul_shift_right) {
1440 return rounding_mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1441 }
1442
1443 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1444 return Expr();
1445 }
1446
1447 constexpr static bool foldable = true;
1448
1451 // Assuming the args have the same type as the intrinsic is incorrect in
1452 // general. But for the intrinsics we can fold (just shifts), the LHS
1453 // has the same type as the intrinsic, and we can always treat the RHS
1454 // as a signed int, because we're using 64 bits for it.
1455 std::get<0>(args).make_folded_const(val, ty, state);
1458 // We can just directly get the second arg here, because we only want to
1459 // instantiate this method for shifts, which have two args.
1460 std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1461
1462 if (intrin == Call::shift_left) {
1463 if (arg1.u.i64 < 0) {
1464 if (ty.code == halide_type_int) {
1465 // Arithmetic shift
1466 val.u.i64 >>= -arg1.u.i64;
1467 } else {
1468 // Logical shift
1469 val.u.u64 >>= -arg1.u.i64;
1470 }
1471 } else {
1472 val.u.u64 <<= arg1.u.i64;
1473 }
1474 } else if (intrin == Call::shift_right) {
1475 if (arg1.u.i64 > 0) {
1476 if (ty.code == halide_type_int) {
1477 // Arithmetic shift
1478 val.u.i64 >>= arg1.u.i64;
1479 } else {
1480 // Logical shift
1481 val.u.u64 >>= arg1.u.i64;
1482 }
1483 } else {
1484 val.u.u64 <<= -arg1.u.i64;
1485 }
1486 } else {
1487 internal_error << "Folding not implemented for intrinsic: " << intrin;
1488 }
1489 }
1490
1495};
1496
1497template<typename... Args>
1498std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1499 s << op.intrin << "(";
1500 op.print_args(s);
1501 s << ")";
1502 return s;
1503}
1504
1505template<typename... Args>
1506HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1507 return {intrinsic_op, pattern_arg(args)...};
1508}
1509
1510template<typename A, typename B>
1511auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1513}
1514template<typename A, typename B>
1515auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1517}
1518template<typename A, typename B>
1519auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1521}
1522
1523template<typename A, typename B>
1524auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1526}
1527template<typename A, typename B>
1528auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1530}
1531template<typename A, typename B>
1532auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1534}
1535template<typename A, typename B>
1536auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1538}
1539template<typename A, typename B>
1540auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1542}
1543template<typename A>
1544auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1547 return p;
1548}
1549template<typename A, typename B>
1550auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1551 return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1552}
1553template<typename A, typename B>
1554auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1555 return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1556}
1557template<typename A, typename B>
1558auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1560}
1561template<typename A, typename B>
1562auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1563 return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1564}
1565template<typename A, typename B>
1566auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1567 return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1568}
1569template<typename A, typename B>
1570auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1572}
1573template<typename A, typename B>
1574auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1576}
1577template<typename A, typename B, typename C>
1578auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1580}
1581template<typename A, typename B, typename C>
1582auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1584}
1585
1586template<typename A>
1587struct NotOp {
1588 struct pattern_tag {};
1590
1592
1595 constexpr static bool canonical = A::canonical;
1596
1597 template<uint32_t bound>
1598 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1599 if (e.node_type != IRNodeType::Not) {
1600 return false;
1601 }
1602 const Not &op = (const Not &)e;
1603 return (a.template match<bound>(*op.a.get(), state));
1604 }
1605
1606 template<uint32_t bound, typename A2>
1607 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1608 return a.template match<bound>(unwrap(op.a), state);
1609 }
1610
1613 return Not::make(a.make(state, type_hint));
1614 }
1615
1616 constexpr static bool foldable = A::foldable;
1617
1618 template<typename A1 = A>
1620 a.make_folded_const(val, ty, state);
1621 val.u.u64 = ~val.u.u64;
1622 val.u.u64 &= 1;
1623 }
1624};
1625
1626template<typename A>
1627HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1629 return {pattern_arg(a)};
1630}
1631
1632template<typename A>
1637
1638template<typename A>
1639inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1640 s << "!(" << op.a << ")";
1641 return s;
1642}
1643
1644template<typename C, typename T, typename F>
1645struct SelectOp {
1646 struct pattern_tag {};
1648 T t;
1650
1652
1655
1656 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1657
1658 template<uint32_t bound>
1659 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1660 if (e.node_type != Select::_node_type) {
1661 return false;
1662 }
1663 const Select &op = (const Select &)e;
1664 return (c.template match<bound>(*op.condition.get(), state) &&
1665 t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1666 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1667 }
1668 template<uint32_t bound, typename C2, typename T2, typename F2>
1669 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1670 return (c.template match<bound>(unwrap(instance.c), state) &&
1671 t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1672 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1673 }
1674
1677 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1678 }
1679
1680 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1681
1682 template<typename C1 = C>
1686 c.make_folded_const(c_val, c_ty, state);
1687 if ((c_val.u.u64 & 1) == 1) {
1688 t.make_folded_const(val, ty, state);
1689 } else {
1690 f.make_folded_const(val, ty, state);
1691 }
1693 }
1694};
1695
1696template<typename C, typename T, typename F>
1697std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1698 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1699 return s;
1700}
1701
1702template<typename C, typename T, typename F>
1703HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1707 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1708}
1709
1710template<typename A, typename B>
1712 struct pattern_tag {};
1715
1717
1720
1721 constexpr static bool canonical = A::canonical && B::canonical;
1722
1723 template<uint32_t bound>
1724 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1725 if (e.node_type == Broadcast::_node_type) {
1726 const Broadcast &op = (const Broadcast &)e;
1727 if (a.template match<bound>(*op.value.get(), state) &&
1728 lanes.template match<bound>(op.lanes, state)) {
1729 return true;
1730 }
1731 }
1732 return false;
1733 }
1734
1735 template<uint32_t bound, typename A2, typename B2>
1736 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1737 return (a.template match<bound>(unwrap(op.a), state) &&
1738 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1739 }
1740
1745 lanes.make_folded_const(lanes_val, ty, state);
1746 int32_t l = (int32_t)lanes_val.u.i64;
1747 type_hint.lanes /= l;
1748 Expr val = a.make(state, type_hint);
1749 if (l == 1) {
1750 return val;
1751 } else {
1752 return Broadcast::make(std::move(val), l);
1753 }
1754 }
1755
1756 constexpr static bool foldable = false;
1757
1758 template<typename A1 = A>
1762 lanes.make_folded_const(lanes_val, lanes_ty, state);
1763 uint16_t l = (uint16_t)lanes_val.u.i64;
1764 a.make_folded_const(val, ty, state);
1765 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1766 }
1767};
1768
1769template<typename A, typename B>
1770inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1771 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1772 return s;
1773}
1774
1775template<typename A, typename B>
1776HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1778 return {pattern_arg(a), pattern_arg(lanes)};
1779}
1780
1781template<typename A, typename B, typename C>
1782struct RampOp {
1783 struct pattern_tag {};
1787
1789
1792
1793 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1794
1795 template<uint32_t bound>
1796 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1797 if (e.node_type != Ramp::_node_type) {
1798 return false;
1799 }
1800 const Ramp &op = (const Ramp &)e;
1801 if (a.template match<bound>(*op.base.get(), state) &&
1802 b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1803 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1804 return true;
1805 } else {
1806 return false;
1807 }
1808 }
1809
1810 template<uint32_t bound, typename A2, typename B2, typename C2>
1811 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1812 return (a.template match<bound>(unwrap(op.a), state) &&
1813 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1814 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1815 }
1816
1821 lanes.make_folded_const(lanes_val, ty, state);
1822 int32_t l = (int32_t)lanes_val.u.i64;
1823 type_hint.lanes /= l;
1824 Expr ea, eb;
1825 eb = b.make(state, type_hint);
1826 ea = a.make(state, eb.type());
1827 return Ramp::make(std::move(ea), std::move(eb), l);
1828 }
1829
1830 constexpr static bool foldable = false;
1831};
1832
1833template<typename A, typename B, typename C>
1834std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1835 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1836 return s;
1837}
1838
1839template<typename A, typename B, typename C>
1840HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1844 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1845}
1846
1847template<typename A, typename B, VectorReduce::Operator reduce_op>
1849 struct pattern_tag {};
1852
1854
1857 constexpr static bool canonical = A::canonical;
1858
1859 template<uint32_t bound>
1860 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1861 if (e.node_type == VectorReduce::_node_type) {
1862 const VectorReduce &op = (const VectorReduce &)e;
1863 if (op.op == reduce_op &&
1864 a.template match<bound>(*op.value.get(), state) &&
1865 lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1866 return true;
1867 }
1868 }
1869 return false;
1870 }
1871
1872 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1874 return (reduce_op == reduce_op_2 &&
1875 a.template match<bound>(unwrap(op.a), state) &&
1876 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1877 }
1878
1883 lanes.make_folded_const(lanes_val, ty, state);
1884 int l = (int)lanes_val.u.i64;
1885 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1886 }
1887
1888 constexpr static bool foldable = false;
1889};
1890
1891template<typename A, typename B, VectorReduce::Operator reduce_op>
1892inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1893 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1894 return s;
1895}
1896
1897template<typename A, typename B>
1898HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1900 return {pattern_arg(a), pattern_arg(lanes)};
1901}
1902
1903template<typename A, typename B>
1904HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1906 return {pattern_arg(a), pattern_arg(lanes)};
1907}
1908
1909template<typename A, typename B>
1910HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1912 return {pattern_arg(a), pattern_arg(lanes)};
1913}
1914
1915template<typename A, typename B>
1916HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1918 return {pattern_arg(a), pattern_arg(lanes)};
1919}
1920
1921template<typename A, typename B>
1922HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1924 return {pattern_arg(a), pattern_arg(lanes)};
1925}
1926
1927template<typename A>
1928struct NegateOp {
1929 struct pattern_tag {};
1931
1933
1936
1937 constexpr static bool canonical = A::canonical;
1938
1939 template<uint32_t bound>
1940 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1941 if (e.node_type != Sub::_node_type) {
1942 return false;
1943 }
1944 const Sub &op = (const Sub &)e;
1945 return (a.template match<bound>(*op.b.get(), state) &&
1946 is_const_zero(op.a));
1947 }
1948
1949 template<uint32_t bound, typename A2>
1950 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1951 return a.template match<bound>(unwrap(p.a), state);
1952 }
1953
1956 Expr ea = a.make(state, type_hint);
1957 Expr z = make_zero(ea.type());
1958 return Sub::make(std::move(z), std::move(ea));
1959 }
1960
1961 constexpr static bool foldable = A::foldable;
1962
1963 template<typename A1 = A>
1965 a.make_folded_const(val, ty, state);
1966 int dead_bits = 64 - ty.bits;
1967 switch (ty.code) {
1968 case halide_type_int:
1969 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1970 // Trying to negate the most negative signed int for a no-overflow type.
1972 } else {
1973 // Negate, drop the high bits, and then sign-extend them back
1974 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
1975 }
1976 break;
1977 case halide_type_uint:
1978 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1979 break;
1980 case halide_type_float:
1981 case halide_type_bfloat:
1982 val.u.f64 = -val.u.f64;
1983 break;
1984 default:
1985 // unreachable
1986 ;
1987 }
1988 }
1989};
1990
1991template<typename A>
1992std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
1993 s << "-" << op.a;
1994 return s;
1995}
1996
1997template<typename A>
1998HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2000 return {pattern_arg(a)};
2001}
2002
2003template<typename A>
2008
2009template<typename A>
2010struct CastOp {
2011 struct pattern_tag {};
2014
2016
2019 constexpr static bool canonical = A::canonical;
2020
2021 template<uint32_t bound>
2022 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2023 if (e.node_type != Cast::_node_type) {
2024 return false;
2025 }
2026 const Cast &op = (const Cast &)e;
2027 return (e.type == t &&
2028 a.template match<bound>(*op.value.get(), state));
2029 }
2030 template<uint32_t bound, typename A2>
2031 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2032 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2033 }
2034
2037 return cast(t, a.make(state, {}));
2038 }
2039
2040 constexpr static bool foldable = false;
2041};
2042
2043template<typename A>
2044std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2045 s << "cast(" << op.t << ", " << op.a << ")";
2046 return s;
2047}
2048
2049template<typename A>
2050HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2052 return {t, pattern_arg(a)};
2053}
2054
2055template<typename A>
2056struct WidenOp {
2057 struct pattern_tag {};
2059
2061
2064 constexpr static bool canonical = A::canonical;
2065
2066 template<uint32_t bound>
2067 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2068 if (e.node_type != Cast::_node_type) {
2069 return false;
2070 }
2071 const Cast &op = (const Cast &)e;
2072 return (e.type == op.value.type().widen() &&
2073 a.template match<bound>(*op.value.get(), state));
2074 }
2075 template<uint32_t bound, typename A2>
2076 HALIDE_ALWAYS_INLINE bool match(const WidenOp<A2> &op, MatcherState &state) const noexcept {
2077 return a.template match<bound>(unwrap(op.a), state);
2078 }
2079
2082 Expr e = a.make(state, {});
2083 Type w = e.type().widen();
2084 return cast(w, std::move(e));
2085 }
2086
2087 constexpr static bool foldable = false;
2088};
2089
2090template<typename A>
2091std::ostream &operator<<(std::ostream &s, const WidenOp<A> &op) {
2092 s << "widen(" << op.a << ")";
2093 return s;
2094}
2095
2096template<typename A>
2097HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp<decltype(pattern_arg(a))> {
2099 return {pattern_arg(a)};
2100}
2101
2102template<typename Vec, typename Base, typename Stride, typename Lanes>
2103struct SliceOp {
2104 struct pattern_tag {};
2109
2110 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2111
2114 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2115
2116 template<uint32_t bound>
2117 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2118 if (e.node_type != IRNodeType::Shuffle) {
2119 return false;
2120 }
2121 const Shuffle &v = (const Shuffle &)e;
2122 return v.vectors.size() == 1 &&
2123 v.is_slice() &&
2124 vec.template match<bound>(*v.vectors[0].get(), state) &&
2125 base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2126 stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2128 }
2129
2134 base.make_folded_const(base_val, ty, state);
2135 int b = (int)base_val.u.i64;
2136 stride.make_folded_const(stride_val, ty, state);
2137 int s = (int)stride_val.u.i64;
2138 lanes.make_folded_const(lanes_val, ty, state);
2139 int l = (int)lanes_val.u.i64;
2140 return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2141 }
2142
2143 constexpr static bool foldable = false;
2144
2147 : vec(v), base(b), stride(s), lanes(l) {
2148 static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2149 static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2150 static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2151 }
2152};
2153
2154template<typename Vec, typename Base, typename Stride, typename Lanes>
2155std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2156 s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2157 return s;
2158}
2159
2160template<typename Vec, typename Base, typename Stride, typename Lanes>
2161HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2162 -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2163 return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2164}
2165
2166template<typename A>
2167struct Fold {
2168 struct pattern_tag {};
2170
2172
2175 constexpr static bool canonical = true;
2176
2181 a.make_folded_const(c, ty, state);
2182
2183 // The result of the fold may have an underspecified type
2184 // (e.g. because it's from an int literal). Make the type code
2185 // and bits match the required type, if there is one (we can
2186 // tell from the bits field).
2187 if (type_hint.bits) {
2188 if (((int)ty.code == (int)halide_type_int) &&
2189 ((int)type_hint.code == (int)halide_type_float)) {
2190 int64_t x = c.u.i64;
2191 c.u.f64 = (double)x;
2192 }
2193 ty.code = type_hint.code;
2194 ty.bits = type_hint.bits;
2195 }
2196
2197 return make_const_expr(c, ty);
2198 }
2199
2200 constexpr static bool foldable = A::foldable;
2201
2202 template<typename A1 = A>
2204 a.make_folded_const(val, ty, state);
2205 }
2206};
2207
2208template<typename A>
2209HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2211 return {pattern_arg(a)};
2212}
2213
2214template<typename A>
2215std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2216 s << "fold(" << op.a << ")";
2217 return s;
2218}
2219
2220template<typename A>
2222 struct pattern_tag {};
2224
2226
2227 // This rule is a predicate, so it always evaluates to a boolean,
2228 // which has IRNodeType UIntImm
2231 constexpr static bool canonical = true;
2232
2233 constexpr static bool foldable = A::foldable;
2234
2235 template<typename A1 = A>
2237 a.make_folded_const(val, ty, state);
2238 ty.code = halide_type_uint;
2239 ty.bits = 64;
2240 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2241 ty.lanes = 1;
2242 }
2243};
2244
2245template<typename A>
2246HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2248 return {pattern_arg(a)};
2249}
2250
2251template<typename A>
2252std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2253 s << "overflows(" << op.a << ")";
2254 return s;
2255}
2256
2257struct Overflow {
2258 struct pattern_tag {};
2259
2260 constexpr static uint32_t binds = 0;
2261
2262 // Overflow is an intrinsic, represented as a Call node
2265 constexpr static bool canonical = true;
2266
2267 template<uint32_t bound>
2268 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2269 if (e.node_type != Call::_node_type) {
2270 return false;
2271 }
2272 const Call &op = (const Call &)e;
2274 }
2275
2281
2282 constexpr static bool foldable = true;
2283
2286 val.u.u64 = 0;
2288 }
2289};
2290
2291inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2292 s << "overflow()";
2293 return s;
2294}
2295
2296template<typename A>
2297struct IsConst {
2298 struct pattern_tag {};
2299
2301
2302 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2305 constexpr static bool canonical = true;
2306
2310
2311 constexpr static bool foldable = true;
2312
2313 template<typename A1 = A>
2315 Expr e = a.make(state, {});
2316 ty.code = halide_type_uint;
2317 ty.bits = 64;
2318 ty.lanes = 1;
2319 if (check_v) {
2320 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2321 } else {
2322 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2323 }
2324 }
2325};
2326
2327template<typename A>
2328HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2330 return {pattern_arg(a), false, 0};
2331}
2332
2333template<typename A>
2334HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2336 return {pattern_arg(a), true, value};
2337}
2338
2339template<typename A>
2340std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2341 if (op.check_v) {
2342 s << "is_const(" << op.a << ")";
2343 } else {
2344 s << "is_const(" << op.a << ", " << op.v << ")";
2345 }
2346 return s;
2347}
2348
2349template<typename A, typename Prover>
2350struct CanProve {
2351 struct pattern_tag {};
2353 Prover *prover; // An existing simplifying mutator
2354
2356
2357 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2360 constexpr static bool canonical = true;
2361
2362 constexpr static bool foldable = true;
2363
2364 // Includes a raw call to an inlined make method, so don't inline.
2366 Expr condition = a.make(state, {});
2367 condition = prover->mutate(condition, nullptr);
2368 val.u.u64 = is_const_one(condition);
2369 ty.code = halide_type_uint;
2370 ty.bits = 1;
2371 ty.lanes = condition.type().lanes();
2372 }
2373};
2374
2375template<typename A, typename Prover>
2376HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2378 return {pattern_arg(a), p};
2379}
2380
2381template<typename A, typename Prover>
2382std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2383 s << "can_prove(" << op.a << ")";
2384 return s;
2385}
2386
2387template<typename A>
2388struct IsFloat {
2389 struct pattern_tag {};
2391
2393
2394 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2397 constexpr static bool canonical = true;
2398
2399 constexpr static bool foldable = true;
2400
2403 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2404 Type t = a.make(state, {}).type();
2405 val.u.u64 = t.is_float();
2406 ty.code = halide_type_uint;
2407 ty.bits = 1;
2408 ty.lanes = t.lanes();
2409 }
2410};
2411
2412template<typename A>
2413HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2415 return {pattern_arg(a)};
2416}
2417
2418template<typename A>
2419std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2420 s << "is_float(" << op.a << ")";
2421 return s;
2422}
2423
2424template<typename A>
2425struct IsInt {
2426 struct pattern_tag {};
2429
2431
2432 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2435 constexpr static bool canonical = true;
2436
2437 constexpr static bool foldable = true;
2438
2441 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2442 Type t = a.make(state, {}).type();
2443 val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2444 ty.code = halide_type_uint;
2445 ty.bits = 1;
2446 ty.lanes = t.lanes();
2447 }
2448};
2449
2450template<typename A>
2451HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2453 return {pattern_arg(a), bits, lanes};
2454}
2455
2456template<typename A>
2457std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2458 s << "is_int(" << op.a;
2459 if (op.bits > 0) {
2460 s << ", " << op.bits;
2461 }
2462 if (op.lanes > 0) {
2463 s << ", " << op.lanes;
2464 }
2465 s << ")";
2466 return s;
2467}
2468
2469template<typename A>
2470struct IsUInt {
2471 struct pattern_tag {};
2474
2476
2477 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2480 constexpr static bool canonical = true;
2481
2482 constexpr static bool foldable = true;
2483
2486 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2487 Type t = a.make(state, {}).type();
2488 val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2489 ty.code = halide_type_uint;
2490 ty.bits = 1;
2491 ty.lanes = t.lanes();
2492 }
2493};
2494
2495template<typename A>
2496HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2498 return {pattern_arg(a), bits, lanes};
2499}
2500
2501template<typename A>
2502std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2503 s << "is_uint(" << op.a;
2504 if (op.bits > 0) {
2505 s << ", " << op.bits;
2506 }
2507 if (op.lanes > 0) {
2508 s << ", " << op.lanes;
2509 }
2510 s << ")";
2511 return s;
2512}
2513
2514template<typename A>
2515struct IsScalar {
2516 struct pattern_tag {};
2518
2520
2521 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2524 constexpr static bool canonical = true;
2525
2526 constexpr static bool foldable = true;
2527
2530 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2531 Type t = a.make(state, {}).type();
2532 val.u.u64 = t.is_scalar();
2533 ty.code = halide_type_uint;
2534 ty.bits = 1;
2535 ty.lanes = t.lanes();
2536 }
2537};
2538
2539template<typename A>
2540HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2542 return {pattern_arg(a)};
2543}
2544
2545template<typename A>
2546std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2547 s << "is_scalar(" << op.a << ")";
2548 return s;
2549}
2550
2551template<typename A>
2553 struct pattern_tag {};
2555
2557
2558 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2561 constexpr static bool canonical = true;
2562
2563 constexpr static bool foldable = true;
2564
2567 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2568 a.make_folded_const(val, ty, state);
2569 const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2570 if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2571 val.u.u64 = (val.u.u64 == max_bits);
2572 } else {
2573 val.u.u64 = 0;
2574 }
2575 ty.code = halide_type_uint;
2576 ty.bits = 1;
2577 }
2578};
2579
2580template<typename A>
2581HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2583 return {pattern_arg(a)};
2584}
2585
2586template<typename A>
2587std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2588 s << "is_max_value(" << op.a << ")";
2589 return s;
2590}
2591
2592template<typename A>
2594 struct pattern_tag {};
2596
2598
2599 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2602 constexpr static bool canonical = true;
2603
2604 constexpr static bool foldable = true;
2605
2608 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2609 a.make_folded_const(val, ty, state);
2610 if (ty.code == halide_type_int) {
2611 const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2612 val.u.u64 = (val.u.u64 == min_bits);
2613 } else if (ty.code == halide_type_uint) {
2614 val.u.u64 = (val.u.u64 == 0);
2615 } else {
2616 val.u.u64 = 0;
2617 }
2618 ty.code = halide_type_uint;
2619 ty.bits = 1;
2620 }
2621};
2622
2623template<typename A>
2624HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2626 return {pattern_arg(a)};
2627}
2628
2629template<typename A>
2630std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2631 s << "is_min_value(" << op.a << ")";
2632 return s;
2633}
2634
2635template<typename A>
2636struct LanesOf {
2637 struct pattern_tag {};
2639
2641
2642 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2645 constexpr static bool canonical = true;
2646
2647 constexpr static bool foldable = true;
2648
2651 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2652 Type t = a.make(state, {}).type();
2653 val.u.u64 = t.lanes();
2654 ty.code = halide_type_uint;
2655 ty.bits = 32;
2656 ty.lanes = 1;
2657 }
2658};
2659
2660template<typename A>
2661HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2663 return {pattern_arg(a)};
2664}
2665
2666template<typename A>
2667std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2668 s << "lanes_of(" << op.a << ")";
2669 return s;
2670}
2671
2672// Verify properties of each rewrite rule. Currently just fuzz tests them.
2673template<typename Before,
2674 typename After,
2675 typename Predicate,
2676 typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2677 std::decay<After>::type::foldable>::type>
2679 halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2680
2681 // We only validate the rules in the scalar case
2682 wildcard_type.lanes = output_type.lanes = 1;
2683
2684 // Track which types this rule has been tested for before
2685 static std::set<uint32_t> tested;
2686
2687 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2688 return;
2689 }
2690
2691 // Print it in a form where it can be piped into a python/z3 validator
2692 debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2693
2694 // Substitute some random constants into the before and after
2695 // expressions and see if the rule holds true. This should catch
2696 // silly errors, but not necessarily corner cases.
2697 static std::mt19937_64 rng(0);
2698 MatcherState state;
2699
2700 Expr exprs[max_wild];
2701
2702 for (int trials = 0; trials < 100; trials++) {
2703 // We want to test small constants more frequently than
2704 // large ones, otherwise we'll just get coverage of
2705 // overflow rules.
2706 int shift = (int)(rng() & (wildcard_type.bits - 1));
2707
2708 for (int i = 0; i < max_wild; i++) {
2709 // Bind all the exprs and constants
2710 switch (wildcard_type.code) {
2711 case halide_type_uint: {
2712 // Normalize to the type's range by adding zero
2713 uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2714 state.set_bound_const(i, val, wildcard_type);
2715 val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2716 exprs[i] = make_const(wildcard_type, val);
2717 state.set_binding(i, *exprs[i].get());
2718 } break;
2719 case halide_type_int: {
2720 int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2721 state.set_bound_const(i, val, wildcard_type);
2722 val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2723 exprs[i] = make_const(wildcard_type, val);
2724 } break;
2725 case halide_type_float:
2726 case halide_type_bfloat: {
2727 // Use a very narrow range of precise floats, so
2728 // that none of the rules a human is likely to
2729 // write have instabilities.
2730 double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2731 state.set_bound_const(i, val, wildcard_type);
2732 val = ((int64_t)(rng() & 15) - 8) / 2.0;
2733 exprs[i] = make_const(wildcard_type, val);
2734 } break;
2735 default:
2736 return; // Don't care about handles
2737 }
2738 state.set_binding(i, *exprs[i].get());
2739 }
2740
2742 halide_type_t type = output_type;
2743 if (!evaluate_predicate(pred, state)) {
2744 continue;
2745 }
2746 before.make_folded_const(val_before, type, state);
2747 uint16_t lanes = type.lanes;
2748 after.make_folded_const(val_after, type, state);
2749 lanes |= type.lanes;
2750
2752 continue;
2753 }
2754
2755 bool ok = true;
2756 switch (output_type.code) {
2757 case halide_type_uint:
2758 // Compare normalized representations
2759 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2760 constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2761 break;
2762 case halide_type_int:
2763 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2764 constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2765 break;
2766 case halide_type_float:
2767 case halide_type_bfloat: {
2768 double error = std::abs(val_before.u.f64 - val_after.u.f64);
2769 // We accept an equal bit pattern (e.g. inf vs inf),
2770 // a small floating point difference, or turning a nan into not-a-nan.
2771 ok &= (error < 0.01 ||
2772 val_before.u.u64 == val_after.u.u64 ||
2773 std::isnan(val_before.u.f64));
2774 break;
2775 }
2776 default:
2777 return;
2778 }
2779
2780 if (!ok) {
2781 debug(0) << "Fails with values:\n";
2782 for (int i = 0; i < max_wild; i++) {
2784 state.get_bound_const(i, val, wildcard_type);
2785 debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2786 }
2787 for (int i = 0; i < max_wild; i++) {
2788 debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2789 }
2790 debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2791 debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2792 debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2794 }
2795 }
2796}
2797
2798template<typename Before,
2799 typename After,
2800 typename Predicate,
2801 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2802 std::decay<After>::type::foldable)>::type>
2804 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2805 // We can't verify rewrite rules that can't be constant-folded.
2806}
2807
2809bool evaluate_predicate(bool x, MatcherState &) noexcept {
2810 return x;
2811}
2812
2813template<typename Pattern,
2814 typename = typename enable_if_pattern<Pattern>::type>
2818 p.make_folded_const(c, ty, state);
2819 // Overflow counts as a failed predicate
2820 return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2821}
2822
2823// #defines for testing
2824
2825// Print all successful or failed matches
2826#define HALIDE_DEBUG_MATCHED_RULES 0
2827#define HALIDE_DEBUG_UNMATCHED_RULES 0
2828
2829// Set to true if you want to fuzz test every rewrite passed to
2830// operator() to ensure the input and the output have the same value
2831// for lots of random values of the wildcards. Run
2832// correctness_simplify with this on.
2833#define HALIDE_FUZZ_TEST_RULES 0
2834
2835template<typename Instance>
2836struct Rewriter {
2842
2847
2848 template<typename After>
2850#if HALIDE_DEBUG_MATCHED_RULES
2851 debug(0) << instance << " -> " << after << "\n";
2852#endif
2853 result = after.make(state, output_type);
2854 }
2855
2856 template<typename Before,
2857 typename After,
2858 typename = typename enable_if_pattern<Before>::type,
2859 typename = typename enable_if_pattern<After>::type>
2861 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2862 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2863 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2864#if HALIDE_FUZZ_TEST_RULES
2866#endif
2867 if (before.template match<0>(unwrap(instance), state)) {
2869#if HALIDE_DEBUG_MATCHED_RULES
2870 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2871#endif
2872 return true;
2873 } else {
2874#if HALIDE_DEBUG_UNMATCHED_RULES
2875 debug(0) << instance << " does not match " << before << "\n";
2876#endif
2877 return false;
2878 }
2879 }
2880
2881 template<typename Before,
2882 typename = typename enable_if_pattern<Before>::type>
2884 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2885 if (before.template match<0>(unwrap(instance), state)) {
2886 result = after;
2887#if HALIDE_DEBUG_MATCHED_RULES
2888 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2889#endif
2890 return true;
2891 } else {
2892#if HALIDE_DEBUG_UNMATCHED_RULES
2893 debug(0) << instance << " does not match " << before << "\n";
2894#endif
2895 return false;
2896 }
2897 }
2898
2899 template<typename Before,
2900 typename = typename enable_if_pattern<Before>::type>
2902 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2903#if HALIDE_FUZZ_TEST_RULES
2905#endif
2906 if (before.template match<0>(unwrap(instance), state)) {
2908#if HALIDE_DEBUG_MATCHED_RULES
2909 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2910#endif
2911 return true;
2912 } else {
2913#if HALIDE_DEBUG_UNMATCHED_RULES
2914 debug(0) << instance << " does not match " << before << "\n";
2915#endif
2916 return false;
2917 }
2918 }
2919
2920 template<typename Before,
2921 typename After,
2922 typename Predicate,
2923 typename = typename enable_if_pattern<Before>::type,
2924 typename = typename enable_if_pattern<After>::type,
2925 typename = typename enable_if_pattern<Predicate>::type>
2927 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2928 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2929 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2930 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2931 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2932
2933#if HALIDE_FUZZ_TEST_RULES
2935#endif
2936 if (before.template match<0>(unwrap(instance), state) &&
2939#if HALIDE_DEBUG_MATCHED_RULES
2940 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2941#endif
2942 return true;
2943 } else {
2944#if HALIDE_DEBUG_UNMATCHED_RULES
2945 debug(0) << instance << " does not match " << before << "\n";
2946#endif
2947 return false;
2948 }
2949 }
2950
2951 template<typename Before,
2952 typename Predicate,
2953 typename = typename enable_if_pattern<Before>::type,
2954 typename = typename enable_if_pattern<Predicate>::type>
2956 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2957 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2958
2959 if (before.template match<0>(unwrap(instance), state) &&
2961 result = after;
2962#if HALIDE_DEBUG_MATCHED_RULES
2963 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2964#endif
2965 return true;
2966 } else {
2967#if HALIDE_DEBUG_UNMATCHED_RULES
2968 debug(0) << instance << " does not match " << before << "\n";
2969#endif
2970 return false;
2971 }
2972 }
2973
2974 template<typename Before,
2975 typename Predicate,
2976 typename = typename enable_if_pattern<Before>::type,
2977 typename = typename enable_if_pattern<Predicate>::type>
2979 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2980 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2981#if HALIDE_FUZZ_TEST_RULES
2983#endif
2984 if (before.template match<0>(unwrap(instance), state) &&
2987#if HALIDE_DEBUG_MATCHED_RULES
2988 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2989#endif
2990 return true;
2991 } else {
2992#if HALIDE_DEBUG_UNMATCHED_RULES
2993 debug(0) << instance << " does not match " << before << "\n";
2994#endif
2995 return false;
2996 }
2997 }
2998};
2999
3000/** Construct a rewriter for the given instance, which may be a pattern
3001 * with concrete expressions as leaves, or just an expression. The
3002 * optional wildcard_type argument (always the last argument) is a hint as
3003 * to what the type of the wildcards is likely to be. If omitted it uses the
3004 * output type (for an Expr, the Expr's own type). They are not required to be this
3005 * type, but the rule will only be tested for wildcards of that type
3006 * when testing is enabled.
3007 *
3008 * The rewriter can be used to check to see if the instance is one of
3009 * some number of patterns and if so rewrite it into another form,
3010 * using its operator() method. See Simplify.cpp for a bunch of
3011 * example usage.
3012 *
3013 * Important: Any Exprs in patterns are captured by reference, not by
3014 * value, so ensure they outlive the rewriter.
3015 */
3016// @{
3017template<typename Instance,
3018 typename = typename enable_if_pattern<Instance>::type>
3019HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3020 return {pattern_arg(instance), output_type, wildcard_type};
3021}
3022
3023template<typename Instance,
3024 typename = typename enable_if_pattern<Instance>::type>
3025HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3026 return {pattern_arg(instance), output_type, output_type};
3027}
3028
3030auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3031 return {pattern_arg(e), e.type(), wildcard_type};
3032}
3033
3035auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3036 return {pattern_arg(e), e.type(), e.type()};
3037}
3038// @}
3039
3040} // namespace IRMatcher
3041
3042} // namespace Internal
3043} // namespace Halide
3044
3045#endif
#define internal_error
Definition Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition Debug.h:49
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1570
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1562
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition IRMatch.h:3019
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition IRMatch.h:567
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1511
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition IRMatch.h:1259
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition IRMatch.h:1627
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition IRMatch.h:2809
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1016
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition IRMatch.h:1234
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition IRMatch.h:2004
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1154
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:905
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2581
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition IRMatch.h:217
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition IRMatch.h:1285
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition IRMatch.h:1916
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition IRMatch.h:1134
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition IRMatch.h:2328
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition IRMatch.h:1506
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1164
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:971
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1558
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1574
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1519
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition IRMatch.h:912
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition IRMatch.h:1011
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1536
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition IRMatch.h:978
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:2161
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1840
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1004
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2097
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1532
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1045
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1290
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition IRMatch.h:559
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1129
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2050
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition IRMatch.h:2246
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1524
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition IRMatch.h:576
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1031
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:952
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition IRMatch.h:2540
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition IRMatch.h:2209
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition IRMatch.h:1633
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1550
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1089
constexpr bool and_reduce()
Definition IRMatch.h:1314
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1254
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1528
constexpr int max_wild
Definition IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1229
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition IRMatch.h:2413
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1179
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1104
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1280
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2496
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition IRMatch.h:1922
constexpr bool commutative(IRNodeType t)
Definition IRMatch.h:615
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1515
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition IRMatch.h:945
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition IRMatch.h:1910
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:1776
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2451
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1703
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2624
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1067
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition IRMatch.h:2678
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1139
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1554
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1540
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:985
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1578
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1566
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1189
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:938
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition IRMatch.h:1159
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition IRMatch.h:1109
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition IRMatch.h:2661
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1114
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition IRMatch.h:1904
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition IRMatch.h:1898
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1264
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition IRMatch.h:1305
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1582
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1214
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition IRMatch.h:1184
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Definition IRMatch.h:1544
constexpr int const_min(int a, int b)
Definition IRMatch.h:1324
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1239
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition IRMatch.h:1038
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1204
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:919
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition IRMatch.h:2376
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition IRMatch.h:1209
T div_imp(T a, T b)
Definition IROperator.h:268
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
ConstantInterval min(const ConstantInterval &a, const ConstantInterval &b)
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b)
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:247
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
ConstantInterval abs(const ConstantInterval &a)
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
bool is_const(const Expr &e)
Is the expression either an IntImm, a FloatImm, a StringImm, or a Cast of the same,...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:327
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:321
The sum of two expressions.
Definition IR.h:56
Logical and - are both expressions true.
Definition IR.h:175
A base class for expression nodes.
Definition Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition IR.h:265
A function call.
Definition IR.h:490
bool is_intrinsic() const
Definition IR.h:721
static const IRNodeType _node_type
Definition IR.h:766
The actual IR nodes begin here.
Definition IR.h:30
static const IRNodeType _node_type
Definition IR.h:35
The ratio of two expressions.
Definition IR.h:83
Is the first expression equal to the second.
Definition IR.h:121
Floating point constants.
Definition Expr.h:236
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition IR.h:166
Is the first expression greater than the second.
Definition IR.h:157
static constexpr bool canonical
Definition IRMatch.h:641
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:664
static constexpr uint32_t binds
Definition IRMatch.h:633
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:645
static constexpr bool foldable
Definition IRMatch.h:661
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:707
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:655
static constexpr IRNodeType max_node_type
Definition IRMatch.h:636
static constexpr IRNodeType min_node_type
Definition IRMatch.h:635
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1718
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1742
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1736
static constexpr uint32_t binds
Definition IRMatch.h:1716
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1724
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1719
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1759
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2365
static constexpr uint32_t binds
Definition IRMatch.h:2355
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2358
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2359
static constexpr bool foldable
Definition IRMatch.h:2362
static constexpr bool canonical
Definition IRMatch.h:2360
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2018
static constexpr bool foldable
Definition IRMatch.h:2040
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2022
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2017
static constexpr uint32_t binds
Definition IRMatch.h:2015
static constexpr bool canonical
Definition IRMatch.h:2019
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2031
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2036
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:800
static constexpr IRNodeType max_node_type
Definition IRMatch.h:739
static constexpr uint32_t binds
Definition IRMatch.h:736
static constexpr bool canonical
Definition IRMatch.h:740
static constexpr bool foldable
Definition IRMatch.h:763
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:747
static constexpr IRNodeType min_node_type
Definition IRMatch.h:738
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:766
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:757
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2174
static constexpr uint32_t binds
Definition IRMatch.h:2171
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2173
static constexpr bool canonical
Definition IRMatch.h:2175
static constexpr bool foldable
Definition IRMatch.h:2200
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:2178
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2203
static constexpr IRNodeType max_node_type
Definition IRMatch.h:495
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:504
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition IRMatch.h:499
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition IRMatch.h:527
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:539
static constexpr IRNodeType min_node_type
Definition IRMatch.h:494
static constexpr bool canonical
Definition IRMatch.h:496
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:532
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition IRMatch.h:522
static constexpr uint32_t binds
Definition IRMatch.h:492
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1354
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1341
static constexpr bool canonical
Definition IRMatch.h:1342
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1389
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition IRMatch.h:1384
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1449
static constexpr uint32_t binds
Definition IRMatch.h:1338
static constexpr bool foldable
Definition IRMatch.h:1447
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1347
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition IRMatch.h:1371
std::tuple< Args... > args
Definition IRMatch.h:1332
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1359
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition IRMatch.h:1380
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition IRMatch.h:1492
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1340
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2303
static constexpr bool canonical
Definition IRMatch.h:2305
static constexpr bool foldable
Definition IRMatch.h:2311
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2304
static constexpr uint32_t binds
Definition IRMatch.h:2300
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2314
static constexpr bool foldable
Definition IRMatch.h:2399
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2402
static constexpr bool canonical
Definition IRMatch.h:2397
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2395
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2396
static constexpr uint32_t binds
Definition IRMatch.h:2392
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2434
static constexpr bool foldable
Definition IRMatch.h:2437
static constexpr uint32_t binds
Definition IRMatch.h:2430
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2440
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2433
static constexpr bool canonical
Definition IRMatch.h:2435
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2559
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2560
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2566
static constexpr uint32_t binds
Definition IRMatch.h:2556
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2600
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2607
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2601
static constexpr uint32_t binds
Definition IRMatch.h:2597
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2523
static constexpr uint32_t binds
Definition IRMatch.h:2519
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2529
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2522
static constexpr bool foldable
Definition IRMatch.h:2526
static constexpr bool canonical
Definition IRMatch.h:2524
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2485
static constexpr bool foldable
Definition IRMatch.h:2482
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2478
static constexpr bool canonical
Definition IRMatch.h:2480
static constexpr uint32_t binds
Definition IRMatch.h:2475
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2479
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2644
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2650
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2643
static constexpr bool foldable
Definition IRMatch.h:2647
static constexpr uint32_t binds
Definition IRMatch.h:2640
static constexpr bool canonical
Definition IRMatch.h:2645
To save stack space, the matcher objects are largely stateless and immutable.
Definition IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition IRMatch.h:84
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition IRMatch.h:87
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1940
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1955
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition IRMatch.h:1950
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1964
static constexpr uint32_t binds
Definition IRMatch.h:1932
static constexpr bool canonical
Definition IRMatch.h:1937
static constexpr bool foldable
Definition IRMatch.h:1961
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1935
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1934
static constexpr uint32_t binds
Definition IRMatch.h:1591
static constexpr bool foldable
Definition IRMatch.h:1616
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1598
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1594
static constexpr bool canonical
Definition IRMatch.h:1595
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1607
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1612
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1619
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1593
static constexpr uint32_t binds
Definition IRMatch.h:2260
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2264
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2268
static constexpr bool canonical
Definition IRMatch.h:2265
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2277
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2285
static constexpr bool foldable
Definition IRMatch.h:2282
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2263
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2236
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2229
static constexpr uint32_t binds
Definition IRMatch.h:2225
static constexpr bool canonical
Definition IRMatch.h:2231
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2230
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1818
static constexpr bool canonical
Definition IRMatch.h:1793
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1791
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1790
static constexpr uint32_t binds
Definition IRMatch.h:1788
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1811
static constexpr bool foldable
Definition IRMatch.h:1830
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1796
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition IRMatch.h:2849
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition IRMatch.h:2926
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition IRMatch.h:2901
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition IRMatch.h:2844
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition IRMatch.h:2955
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition IRMatch.h:2883
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition IRMatch.h:2978
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition IRMatch.h:2860
static constexpr uint32_t binds
Definition IRMatch.h:1651
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1683
static constexpr bool foldable
Definition IRMatch.h:1680
static constexpr bool canonical
Definition IRMatch.h:1656
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition IRMatch.h:1669
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1659
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1676
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1654
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1653
static constexpr bool canonical
Definition IRMatch.h:2114
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2113
static constexpr bool foldable
Definition IRMatch.h:2143
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition IRMatch.h:2146
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2112
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2117
static constexpr uint32_t binds
Definition IRMatch.h:2110
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2131
static constexpr IRNodeType min_node_type
Definition IRMatch.h:198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:210
static constexpr IRNodeType max_node_type
Definition IRMatch.h:199
static constexpr uint32_t binds
Definition IRMatch.h:195
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1873
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1855
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1860
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1856
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1880
static constexpr uint32_t binds
Definition IRMatch.h:2060
static constexpr bool canonical
Definition IRMatch.h:2064
static constexpr bool foldable
Definition IRMatch.h:2087
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2081
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2063
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2076
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2062
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2067
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:352
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:373
static constexpr IRNodeType max_node_type
Definition IRMatch.h:348
static constexpr IRNodeType min_node_type
Definition IRMatch.h:347
static constexpr uint32_t binds
Definition IRMatch.h:345
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:383
static constexpr bool canonical
Definition IRMatch.h:403
static constexpr IRNodeType max_node_type
Definition IRMatch.h:402
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:431
static constexpr uint32_t binds
Definition IRMatch.h:399
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:406
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:441
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition IRMatch.h:425
static constexpr IRNodeType min_node_type
Definition IRMatch.h:401
static constexpr bool foldable
Definition IRMatch.h:438
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:267
static constexpr uint32_t binds
Definition IRMatch.h:226
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:277
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:233
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition IRMatch.h:254
static constexpr IRNodeType min_node_type
Definition IRMatch.h:228
static constexpr IRNodeType max_node_type
Definition IRMatch.h:229
static constexpr uint32_t binds
Definition IRMatch.h:292
static constexpr IRNodeType max_node_type
Definition IRMatch.h:295
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:299
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:330
static constexpr IRNodeType min_node_type
Definition IRMatch.h:294
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:320
static constexpr IRNodeType min_node_type
Definition IRMatch.h:459
static constexpr uint32_t binds
Definition IRMatch.h:457
static constexpr IRNodeType max_node_type
Definition IRMatch.h:460
static constexpr bool canonical
Definition IRMatch.h:461
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:473
static constexpr bool foldable
Definition IRMatch.h:477
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:464
static constexpr uint32_t mask
Definition IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
Integer constants.
Definition Expr.h:218
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition IR.h:148
Is the first expression less than the second.
Definition IR.h:139
The greater of two values.
Definition IR.h:112
The lesser of two values.
Definition IR.h:103
The remainder of a / b.
Definition IR.h:94
The product of two expressions.
Definition IR.h:74
Is the first expression not equal to the second.
Definition IR.h:130
Logical not - true if the expression is false.
Definition IR.h:193
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition IR.h:184
A linear ramp vector node.
Definition IR.h:247
static const IRNodeType _node_type
Definition IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:855
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition IR.h:856
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:909
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:906
The difference of two expressions.
Definition IR.h:65
static const IRNodeType _node_type
Definition IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:979
static const IRNodeType _node_type
Definition IR.h:998
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition Type.h:283
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
Definition Type.h:378
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:435
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition Type.h:355
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:349
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:423
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.