1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
139 typename =
typename std::remove_reference<T>::type::pattern_tag>
146 constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
166 const int lanes = scalar_type.
lanes;
167 scalar_type.
lanes = 1;
170 switch (scalar_type.
code) {
204 template<u
int32_t bound>
232 template<u
int32_t bound>
234 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
237 op = ((
const Broadcast *)op)->value.get();
246 state.get_bound_const(i, val, type);
249 state.set_bound_const(i, value, e.type);
253 template<u
int32_t bound>
255 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
259 state.get_bound_const(i, val, type);
260 return type == i64_type && value == val.
u.
i64;
262 state.set_bound_const(i, value, i64_type);
298 template<u
int32_t bound>
300 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
303 op = ((
const Broadcast *)op)->value.get();
312 state.get_bound_const(i, val, type);
315 state.set_bound_const(i, value, e.type);
331 state.get_bound_const(i, val, ty);
351 template<u
int32_t bound>
353 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
356 op = ((
const Broadcast *)op)->value.get();
361 double value = ((
const FloatImm *)op)->value;
365 state.get_bound_const(i, val, type);
368 state.set_bound_const(i, value, e.type);
384 state.get_bound_const(i, val, ty);
405 template<u
int32_t bound>
407 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
410 op = ((
const Broadcast *)op)->value.get();
424 template<u
int32_t bound>
426 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
442 state.get_bound_const(i, val, ty);
463 template<u
int32_t bound>
466 return equal(*state.get_binding(i), e);
468 state.set_binding(i, e);
503 template<u
int32_t bound>
507 op = ((
const Broadcast *)op)->value.get();
515 return ((
const FloatImm *)op)->value == (
double)
v;
521 template<u
int32_t bound>
526 template<u
int32_t bound>
550 val.u.f64 = (double)
v;
566 typename =
typename std::decay<T>::type::pattern_tag>
577 static_assert(!std::is_same_v<std::decay_t<T>,
Expr> || std::is_lvalue_reference_v<T>,
578 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
589 typename =
typename std::decay_t<T>::pattern_tag,
591 typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, SpecificExpr>>>
627template<
typename Op,
typename A,
typename B>
642 A::canonical && B::canonical && (!
commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
644 template<u
int32_t bound>
646 if (e.node_type != Op::_node_type) {
649 const Op &op = (
const Op &)e;
650 return (
a.template match<bound>(*op.a.get(), state) &&
654 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
656 return (std::is_same_v<Op, Op2> &&
657 a.template match<bound>(
unwrap(op.a), state) &&
661 constexpr static bool foldable = A::foldable && B::foldable;
666 if (std::is_same_v<A, IntLiteral>) {
667 b.make_folded_const(val_b, ty, state);
668 if ((std::is_same_v<Op, And> && val_b.
u.
u64 == 0) ||
669 (std::is_same_v<Op, Or> && val_b.
u.
u64 == 1)) {
675 a.make_folded_const(val_a, ty, state);
678 a.make_folded_const(val_a, ty, state);
679 if ((std::is_same_v<Op, And> && val_a.
u.
u64 == 0) ||
680 (std::is_same_v<Op, Or> && val_a.
u.
u64 == 1)) {
686 b.make_folded_const(val_b, ty, state);
691 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.
u.
i64, val_b.
u.
i64);
694 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.
u.
u64, val_b.
u.
u64);
698 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.
u.
f64, val_b.
u.
f64);
709 if (std::is_same_v<A, IntLiteral>) {
710 eb =
b.make(state, type_hint);
711 ea =
a.make(state, eb.
type());
713 ea =
a.make(state, type_hint);
714 eb =
b.make(state, ea.
type());
716 return Op::make(std::move(ea), std::move(eb));
730template<
typename Op,
typename A,
typename B>
742 (!
commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
746 template<u
int32_t bound>
748 if (e.node_type != Op::_node_type) {
751 const Op &op = (
const Op &)e;
752 return (
a.template match<bound>(*op.a.get(), state) &&
756 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
758 return (std::is_same_v<Op, Op2> &&
759 a.template match<bound>(
unwrap(op.a), state) &&
763 constexpr static bool foldable = A::foldable && B::foldable;
769 if (std::is_same_v<A, IntLiteral>) {
770 b.make_folded_const(val_b, ty, state);
772 a.make_folded_const(val_a, ty, state);
775 a.make_folded_const(val_a, ty, state);
777 b.make_folded_const(val_b, ty, state);
783 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
i64, val_b.
u.
i64);
786 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
u64, val_b.
u.
u64);
790 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
f64, val_b.
u.
f64);
804 if (std::is_same_v<A, IntLiteral>) {
805 eb =
b.make(state, {});
806 ea =
a.make(state, eb.
type());
808 ea =
a.make(state, {});
809 eb =
b.make(state, ea.
type());
811 return Op::make(std::move(ea), std::move(eb));
815template<
typename A,
typename B>
817 s <<
"(" << op.
a <<
" + " << op.
b <<
")";
821template<
typename A,
typename B>
823 s <<
"(" << op.
a <<
" - " << op.
b <<
")";
827template<
typename A,
typename B>
829 s <<
"(" << op.
a <<
" * " << op.
b <<
")";
833template<
typename A,
typename B>
835 s <<
"(" << op.
a <<
" / " << op.
b <<
")";
839template<
typename A,
typename B>
841 s <<
"(" << op.
a <<
" && " << op.
b <<
")";
845template<
typename A,
typename B>
847 s <<
"(" << op.
a <<
" || " << op.
b <<
")";
851template<
typename A,
typename B>
853 s <<
"min(" << op.
a <<
", " << op.
b <<
")";
857template<
typename A,
typename B>
859 s <<
"max(" << op.
a <<
", " << op.
b <<
")";
863template<
typename A,
typename B>
865 s <<
"(" << op.
a <<
" <= " << op.
b <<
")";
869template<
typename A,
typename B>
871 s <<
"(" << op.
a <<
" < " << op.
b <<
")";
875template<
typename A,
typename B>
877 s <<
"(" << op.
a <<
" >= " << op.
b <<
")";
881template<
typename A,
typename B>
883 s <<
"(" << op.
a <<
" > " << op.
b <<
")";
887template<
typename A,
typename B>
889 s <<
"(" << op.
a <<
" == " << op.
b <<
")";
893template<
typename A,
typename B>
895 s <<
"(" << op.
a <<
" != " << op.
b <<
")";
899template<
typename A,
typename B>
901 s <<
"(" << op.
a <<
" % " << op.
b <<
")";
905template<
typename A,
typename B>
907 assert_is_lvalue_if_expr<A>();
908 assert_is_lvalue_if_expr<B>();
912template<
typename A,
typename B>
914 assert_is_lvalue_if_expr<A>();
915 assert_is_lvalue_if_expr<B>();
922 int dead_bits = 64 - t.bits;
930 return (a + b) & (ones >> (64 - t.bits));
938template<
typename A,
typename B>
940 assert_is_lvalue_if_expr<A>();
941 assert_is_lvalue_if_expr<B>();
945template<
typename A,
typename B>
947 assert_is_lvalue_if_expr<A>();
948 assert_is_lvalue_if_expr<B>();
956 int dead_bits = 64 - t.bits;
963 return (a - b) & (ones >> (64 - t.bits));
971template<
typename A,
typename B>
973 assert_is_lvalue_if_expr<A>();
974 assert_is_lvalue_if_expr<B>();
978template<
typename A,
typename B>
980 assert_is_lvalue_if_expr<A>();
981 assert_is_lvalue_if_expr<B>();
988 int dead_bits = 64 - t.bits;
996 return (a * b) & (ones >> (64 - t.bits));
1004template<
typename A,
typename B>
1006 assert_is_lvalue_if_expr<A>();
1007 assert_is_lvalue_if_expr<B>();
1011template<
typename A,
typename B>
1031template<
typename A,
typename B>
1033 assert_is_lvalue_if_expr<A>();
1034 assert_is_lvalue_if_expr<B>();
1038template<
typename A,
typename B>
1040 assert_is_lvalue_if_expr<A>();
1041 assert_is_lvalue_if_expr<B>();
1060template<
typename A,
typename B>
1062 assert_is_lvalue_if_expr<A>();
1063 assert_is_lvalue_if_expr<B>();
1082template<
typename A,
typename B>
1084 assert_is_lvalue_if_expr<A>();
1085 assert_is_lvalue_if_expr<B>();
1104template<
typename A,
typename B>
1109template<
typename A,
typename B>
1129template<
typename A,
typename B>
1134template<
typename A,
typename B>
1154template<
typename A,
typename B>
1159template<
typename A,
typename B>
1179template<
typename A,
typename B>
1184template<
typename A,
typename B>
1204template<
typename A,
typename B>
1209template<
typename A,
typename B>
1229template<
typename A,
typename B>
1234template<
typename A,
typename B>
1254template<
typename A,
typename B>
1259template<
typename A,
typename B>
1280template<
typename A,
typename B>
1285template<
typename A,
typename B>
1310template<
typename... Args>
1319template<
typename... Args>
1324template<Call::IntrinsicOp
intrin>
1335 return t ==
Type(type);
1357 typename = std::enable_if_t<(i <
sizeof...(Args))>>
1359 using T =
decltype(std::get<i>(
args));
1360 return (std::get<i>(
args).template match<bound>(*c.args[i].get(), state) &&
1364 template<
int i, u
int32_t binds>
1369 template<u
int32_t bound>
1377 match_args<0, bound>(0, c, state));
1381 typename = std::enable_if_t<(i <
sizeof...(Args))>>
1383 s << std::get<i>(
args);
1384 if (i + 1 <
sizeof...(Args)) {
1387 print_args<i + 1>(0, s);
1396 print_args<0>(0, s);
1401 Expr arg0 = std::get<0>(
args).make(state, type_hint);
1403 return likely(std::move(arg0));
1407 return abs(std::move(arg0));
1412 Expr arg1 = std::get<std::min<size_t>(1,
sizeof...(Args) - 1)>(
args).
make(state, type_hint);
1414 return absd(std::move(arg0), std::move(arg1));
1432 return halving_add(std::move(arg0), std::move(arg1));
1434 return halving_sub(std::move(arg0), std::move(arg1));
1438 return std::move(arg0) << std::move(arg1);
1440 return std::move(arg0) >> std::move(arg1);
1447 Expr arg2 = std::get<std::min<size_t>(2,
sizeof...(Args) - 1)>(
args).
make(state, type_hint);
1449 return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1454 internal_error <<
"Unhandled intrinsic in IRMatcher: " << intrin;
1466 std::get<0>(
args).make_folded_const(val, ty, state);
1471 std::get<1>(
args).make_folded_const(arg1, signed_ty, state);
1474 if (arg1.
u.
i64 < 0) {
1477 val.u.i64 >>= -arg1.
u.
i64;
1480 val.u.u64 >>= -arg1.
u.
i64;
1483 val.u.u64 <<= arg1.
u.
i64;
1486 if (arg1.
u.
i64 > 0) {
1489 val.u.i64 >>= arg1.
u.
i64;
1492 val.u.u64 >>= arg1.
u.
i64;
1495 val.u.u64 <<= -arg1.
u.
i64;
1498 internal_error <<
"Folding not implemented for intrinsic: " << intrin;
1516template<
typename A,
typename B>
1520template<
typename A,
typename B>
1524template<
typename A,
typename B>
1529template<
typename A,
typename B>
1533template<
typename A,
typename B>
1537template<
typename A,
typename B>
1541template<
typename A,
typename B>
1545template<
typename A,
typename B>
1552 p.optional_type_hint.type = t;
1555template<
typename A,
typename B>
1559template<
typename A,
typename B>
1563template<
typename A,
typename B>
1567template<
typename A,
typename B>
1571template<
typename A,
typename B>
1575template<
typename A,
typename B>
1579template<
typename A,
typename B>
1583template<
typename A,
typename B,
typename C>
1587template<
typename A,
typename B,
typename C>
1597template<
typename A,
typename B>
1623 template<u
int32_t bound>
1628 const Not &op = (
const Not &)e;
1629 return (
a.template match<bound>(*op.
a.
get(), state));
1632 template<u
int32_t bound,
typename A2>
1634 return a.template match<bound>(
unwrap(op.a), state);
1644 template<
typename A1 = A>
1646 a.make_folded_const(val, ty, state);
1647 val.u.u64 = ~val.u.u64;
1654 assert_is_lvalue_if_expr<A>();
1660 assert_is_lvalue_if_expr<A>();
1666 s <<
"!(" << op.
a <<
")";
1684 template<u
int32_t bound>
1686 static_assert(bound &
Wild<i>::binds,
"neg must be applied to an already-bound expr");
1692 ((
equal(*((
const NE &)e).a.get(), *((
const EQ &)b).a.get()) &&
1693 equal(*((
const NE &)e).b.get(), *((
const EQ &)b).
b.
get())) ||
1694 (
equal(*((
const NE &)e).a.get(), *((
const EQ &)b).b.get()) &&
1695 equal(*((
const NE &)e).b.get(), *((
const EQ &)b).
a.
get()))));
1698 ((
equal(*((
const EQ &)e).a.get(), *((
const NE &)b).a.get()) &&
1699 equal(*((
const EQ &)e).b.get(), *((
const NE &)b).
b.
get())) ||
1700 (
equal(*((
const EQ &)e).a.get(), *((
const NE &)b).b.get()) &&
1701 equal(*((
const EQ &)e).b.get(), *((
const NE &)b).
a.
get()))));
1704 equal(*((
const LE &)e).a.get(), *((
const LT &)b).b.get()) &&
1708 equal(*((
const LT &)e).a.get(), *((
const LE &)b).b.get()) &&
1711 return equal(e, *((
const Not &)b).a.get());
1714 equal(*((
const Not &)e).a.get(), b));
1726 s <<
"neg(" <<
Wild<i>{} <<
")";
1730template<
typename C,
typename T,
typename F>
1742 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1744 template<u
int32_t bound>
1750 return (
c.template match<bound>(*op.
condition.
get(), state) &&
1754 template<u
int32_t bound,
typename C2,
typename T2,
typename F2>
1756 return (
c.template match<bound>(
unwrap(instance.c), state) &&
1763 return Select::make(
c.make(state, {}),
t.make(state, type_hint),
f.make(state, type_hint));
1766 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1768 template<
typename C1 = C>
1772 c.make_folded_const(c_val, c_ty, state);
1773 if ((c_val.
u.
u64 & 1) == 1) {
1774 t.make_folded_const(val, ty, state);
1776 f.make_folded_const(val, ty, state);
1782template<
typename C,
typename T,
typename F>
1784 s <<
"select(" << op.
c <<
", " << op.
t <<
", " << op.
f <<
")";
1788template<
typename C,
typename T,
typename F>
1790 assert_is_lvalue_if_expr<C>();
1791 assert_is_lvalue_if_expr<T>();
1792 assert_is_lvalue_if_expr<F>();
1796template<
typename A,
typename B>
1807 constexpr static bool canonical = A::canonical && B::canonical;
1809 template<u
int32_t bound>
1813 if (
a.template match<bound>(*op.
value.
get(), state) &&
1814 lanes.template match<bound>(op.
lanes, state)) {
1821 template<u
int32_t bound,
typename A2,
typename B2>
1823 return (
a.template match<bound>(
unwrap(op.a), state) &&
1831 lanes.make_folded_const(lanes_val, ty, state);
1833 type_hint.
lanes /= l;
1834 Expr val =
a.make(state, type_hint);
1844 template<
typename A1 = A>
1848 lanes.make_folded_const(lanes_val, lanes_ty, state);
1850 a.make_folded_const(val, ty, state);
1855template<
typename A,
typename B>
1857 s <<
"broadcast(" << op.
a <<
", " << op.
lanes <<
")";
1861template<
typename A,
typename B>
1863 assert_is_lvalue_if_expr<A>();
1867template<
typename A,
typename B,
typename C>
1879 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1881 template<u
int32_t bound>
1887 if (
a.template match<bound>(*op.
base.
get(), state) &&
1896 template<u
int32_t bound,
typename A2,
typename B2,
typename C2>
1898 return (
a.template match<bound>(
unwrap(op.a), state) &&
1907 lanes.make_folded_const(lanes_val, ty, state);
1909 type_hint.
lanes /= l;
1911 eb =
b.make(state, type_hint);
1912 ea =
a.make(state, eb.type());
1913 return Ramp::make(std::move(ea), std::move(eb), l);
1919template<
typename A,
typename B,
typename C>
1921 s <<
"ramp(" << op.
a <<
", " << op.
b <<
", " << op.
lanes <<
")";
1925template<
typename A,
typename B,
typename C>
1927 assert_is_lvalue_if_expr<A>();
1928 assert_is_lvalue_if_expr<B>();
1929 assert_is_lvalue_if_expr<C>();
1933template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1945 template<u
int32_t bound>
1949 if (op.
op == reduce_op &&
1950 a.template match<bound>(*op.
value.
get(), state) &&
1958 template<u
int32_t bound,
typename A2,
typename B2, VectorReduce::Operator reduce_op_2>
1960 return (reduce_op == reduce_op_2 &&
1961 a.template match<bound>(
unwrap(op.a), state) &&
1969 lanes.make_folded_const(lanes_val, ty, state);
1970 int l = (int)lanes_val.
u.
i64;
1977template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1979 s <<
"vector_reduce(" << reduce_op <<
", " << op.
a <<
", " << op.
lanes <<
")";
1983template<
typename A,
typename B>
1985 assert_is_lvalue_if_expr<A>();
1989template<
typename A,
typename B>
1991 assert_is_lvalue_if_expr<A>();
1995template<
typename A,
typename B>
1997 assert_is_lvalue_if_expr<A>();
2001template<
typename A,
typename B>
2003 assert_is_lvalue_if_expr<A>();
2007template<
typename A,
typename B>
2009 assert_is_lvalue_if_expr<A>();
2025 template<u
int32_t bound>
2030 const Sub &op = (
const Sub &)e;
2031 return (
a.template match<bound>(*op.
b.
get(), state) &&
2035 template<u
int32_t bound,
typename A2>
2037 return a.template match<bound>(
unwrap(p.a), state);
2042 Expr ea =
a.make(state, type_hint);
2044 return Sub::make(std::move(z), std::move(ea));
2049 template<
typename A1 = A>
2051 a.make_folded_const(val, ty, state);
2052 int dead_bits = 64 - ty.bits;
2055 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
2064 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2068 val.u.f64 = -val.u.f64;
2085 assert_is_lvalue_if_expr<A>();
2091 assert_is_lvalue_if_expr<A>();
2107 template<u
int32_t bound>
2113 return (e.type ==
t &&
2114 a.template match<bound>(*op.
value.
get(), state));
2116 template<u
int32_t bound,
typename A2>
2118 return t == op.t &&
a.template match<bound>(
unwrap(op.a), state);
2123 return cast(
t,
a.make(state, {}));
2131 s <<
"cast(" << op.
t <<
", " << op.
a <<
")";
2137 assert_is_lvalue_if_expr<A>();
2152 template<u
int32_t bound>
2159 a.template match<bound>(*op.
value.
get(), state));
2161 template<u
int32_t bound,
typename A2>
2163 return a.template match<bound>(
unwrap(op.a), state);
2168 Expr e =
a.make(state, {});
2170 return cast(w, std::move(e));
2178 s <<
"widen(" << op.
a <<
")";
2184 assert_is_lvalue_if_expr<A>();
2188template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2196 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2200 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2202 template<u
int32_t bound>
2208 return v.
vectors.size() == 1 &&
2210 vec.template match<bound>(*v.
vectors[0].get(), state) &&
2220 base.make_folded_const(base_val, ty, state);
2221 int b = (int)base_val.
u.
i64;
2222 stride.make_folded_const(stride_val, ty, state);
2223 int s = (int)stride_val.
u.
i64;
2224 lanes.make_folded_const(lanes_val, ty, state);
2225 int l = (int)lanes_val.
u.
i64;
2234 static_assert(Base::foldable,
"Base of slice should consist only of operations that constant-fold");
2235 static_assert(Stride::foldable,
"Stride of slice should consist only of operations that constant-fold");
2236 static_assert(Lanes::foldable,
"Lanes of slice should consist only of operations that constant-fold");
2240template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2242 s <<
"slice(" << op.
vec <<
", " << op.
base <<
", " << op.
stride <<
", " << op.
lanes <<
")";
2246template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2267 a.make_folded_const(c, ty, state);
2273 if (type_hint.bits) {
2277 c.
u.
f64 = (double)x;
2279 ty.
code = type_hint.code;
2280 ty.
bits = type_hint.bits;
2288 template<
typename A1 = A>
2290 a.make_folded_const(val, ty, state);
2296 assert_is_lvalue_if_expr<A>();
2302 s <<
"fold(" << op.
a <<
")";
2321 template<
typename A1 = A>
2323 a.make_folded_const(val, ty, state);
2333 assert_is_lvalue_if_expr<A>();
2339 s <<
"overflows(" << op.
a <<
")";
2353 template<u
int32_t bound>
2399 template<
typename A1 = A>
2401 Expr e =
a.make(state, {});
2415 assert_is_lvalue_if_expr<A>();
2421 assert_is_lvalue_if_expr<A>();
2428 s <<
"is_const(" << op.
a <<
")";
2430 s <<
"is_const(" << op.
a <<
", " << op.
v <<
")";
2435template<
typename A,
typename Prover>
2452 Expr condition =
a.make(state, {});
2453 condition =
prover->mutate(condition,
nullptr);
2461template<
typename A,
typename Prover>
2463 assert_is_lvalue_if_expr<A>();
2467template<
typename A,
typename Prover>
2469 s <<
"can_prove(" << op.
a <<
")";
2490 Type t =
a.make(state, {}).type();
2500 assert_is_lvalue_if_expr<A>();
2506 s <<
"is_float(" << op.
a <<
")";
2529 Type t =
a.make(state, {}).type();
2539 assert_is_lvalue_if_expr<A>();
2545 s <<
"is_int(" << op.
a;
2547 s <<
", " << op.
bits;
2550 s <<
", " << op.
lanes;
2575 Type t =
a.make(state, {}).type();
2585 assert_is_lvalue_if_expr<A>();
2591 s <<
"is_uint(" << op.
a;
2593 s <<
", " << op.
bits;
2596 s <<
", " << op.
lanes;
2619 Type t =
a.make(state, {}).type();
2629 assert_is_lvalue_if_expr<A>();
2635 s <<
"is_scalar(" << op.
a <<
")";
2656 a.make_folded_const(val, ty, state);
2659 val.
u.
u64 = (val.
u.
u64 == max_bits);
2670 assert_is_lvalue_if_expr<A>();
2676 s <<
"is_max_value(" << op.
a <<
")";
2697 a.make_folded_const(val, ty, state);
2700 val.
u.
u64 = (val.
u.
u64 == min_bits);
2713 assert_is_lvalue_if_expr<A>();
2719 s <<
"is_min_value(" << op.
a <<
")";
2740 Type t =
a.make(state, {}).type();
2750 assert_is_lvalue_if_expr<A>();
2756 s <<
"lanes_of(" << op.
a <<
")";
2761template<
typename Before,
2764 typename = std::enable_if_t<std::decay_t<Before>::foldable &&
2765 std::decay_t<After>::foldable>>
2770 wildcard_type.lanes = output_type.lanes = 1;
2773 static std::set<uint32_t> tested;
2775 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2780 debug(0) <<
"validate('" << before <<
"', '" << after <<
"', '" << pred <<
"', " <<
Type(wildcard_type) <<
", " <<
Type(output_type) <<
")\n";
2785 static std::mt19937_64 rng(0);
2790 for (
int trials = 0; trials < 100; trials++) {
2794 int shift = (int)(rng() & (wildcard_type.bits - 1));
2796 for (
int i = 0; i <
max_wild; i++) {
2798 switch (wildcard_type.code) {
2818 double val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2820 val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2834 before.make_folded_const(val_before, type, state);
2836 after.make_folded_const(val_after, type, state);
2837 lanes |= type.
lanes;
2844 switch (output_type.code) {
2859 ok &= (error < 0.01 ||
2860 val_before.
u.
u64 == val_after.
u.
u64 ||
2861 std::isnan(val_before.
u.
f64));
2869 debug(0) <<
"Fails with values:\n";
2870 for (
int i = 0; i <
max_wild; i++) {
2875 for (
int i = 0; i <
max_wild; i++) {
2880 debug(0) << val_before.
u.
u64 <<
" " << val_after.
u.
u64 <<
"\n";
2886template<
typename Before,
2889 typename = std::enable_if_t<!(std::decay_t<Before>::foldable &&
2890 std::decay_t<After>::foldable)>>
2901template<
typename Pattern,
2902 typename =
typename enable_if_pattern<Pattern>::type>
2906 p.make_folded_const(c, ty, state);
2914#define HALIDE_DEBUG_MATCHED_RULES 0
2915#define HALIDE_DEBUG_UNMATCHED_RULES 0
2921#define HALIDE_FUZZ_TEST_RULES 0
2923template<
typename Instance>
2936 template<
typename After>
2938#if HALIDE_DEBUG_MATCHED_RULES
2944 template<
typename Before,
2949 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2950 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2951 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2952#if HALIDE_FUZZ_TEST_RULES
2957#if HALIDE_DEBUG_MATCHED_RULES
2962#if HALIDE_DEBUG_UNMATCHED_RULES
2963 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2969 template<
typename Before,
2972 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2975#if HALIDE_DEBUG_MATCHED_RULES
2980#if HALIDE_DEBUG_UNMATCHED_RULES
2981 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2987 template<
typename Before,
2990 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2991#if HALIDE_FUZZ_TEST_RULES
2996#if HALIDE_DEBUG_MATCHED_RULES
3001#if HALIDE_DEBUG_UNMATCHED_RULES
3002 debug(0) <<
instance <<
" does not match " << before <<
"\n";
3008 template<
typename Before,
3015 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
3016 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
3017 static_assert((Before::binds & Predicate::binds) == Predicate::binds,
"Rule predicate uses unbound values");
3018 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
3019 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
3021#if HALIDE_FUZZ_TEST_RULES
3027#if HALIDE_DEBUG_MATCHED_RULES
3028 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
3032#if HALIDE_DEBUG_UNMATCHED_RULES
3033 debug(0) <<
instance <<
" does not match " << before <<
"\n";
3039 template<
typename Before,
3044 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
3045 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
3050#if HALIDE_DEBUG_MATCHED_RULES
3051 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
3055#if HALIDE_DEBUG_UNMATCHED_RULES
3056 debug(0) <<
instance <<
" does not match " << before <<
"\n";
3062 template<
typename Before,
3067 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
3068 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
3069#if HALIDE_FUZZ_TEST_RULES
3075#if HALIDE_DEBUG_MATCHED_RULES
3076 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
3080#if HALIDE_DEBUG_UNMATCHED_RULES
3081 debug(0) <<
instance <<
" does not match " << before <<
"\n";
3105template<
typename Instance,
3106 typename =
typename enable_if_pattern<Instance>::type>
3108 return {
pattern_arg(instance), output_type, wildcard_type};
3111template<
typename Instance,
3112 typename =
typename enable_if_pattern<Instance>::type>
3114 return {
pattern_arg(instance), output_type, output_type};
#define debug(n)
For optional debugging during codegen, use the debug macro as follows:
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
auto shift_right(A &&a, B &&b) noexcept -> Intrin< Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widening_add(A &&a, B &&b) noexcept -> Intrin< Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
auto abs(A &&a) noexcept -> Intrin< Call::abs, decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto absd(A &&a, B &&b) noexcept -> Intrin< Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< Call::saturating_cast, decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
auto likely(A &&a) noexcept -> Intrin< Call::likely, decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr bool and_reduce()
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto halving_add(A &&a, B &&b) noexcept -> Intrin< Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
constexpr bool commutative(IRNodeType t)
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
auto likely_if_innermost(A &&a) noexcept -> Intrin< Call::likely_if_innermost, decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
HALIDE_ALWAYS_INLINE auto neg(const Wild< i > &a) -> SimplifiedNegateOp< i >
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
auto shift_left(A &&a, B &&b) noexcept -> Intrin< Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
constexpr uint32_t bitwise_or_reduce()
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
The sum of two expressions.
Logical and - are both expressions true.
A base class for expression nodes.
A vector with 'lanes' elements, in which every element is 'value'.
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
@ signed_integer_overflow
@ rounding_mul_shift_right
bool is_intrinsic() const
static const IRNodeType _node_type
The actual IR nodes begin here.
static const IRNodeType _node_type
The ratio of two expressions.
Is the first expression equal to the second.
Floating point constants.
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Is the first expression greater than the second.
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
static constexpr bool foldable
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
static constexpr IRNodeType max_node_type
static constexpr bool foldable
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE Intrin(Args... args) noexcept
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
std::tuple< Args... > args
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType min_node_type
OptionalIntrinType< intrin > optional_type_hint
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool canonical
static constexpr IRNodeType min_node_type
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
static constexpr bool foldable
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr IRNodeType min_node_type
static constexpr bool foldable
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool foldable
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr IRNodeType min_node_type
static constexpr bool foldable
static constexpr uint32_t binds
static constexpr bool canonical
To save stack space, the matcher objects are largely stateless and immutable.
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
static constexpr uint16_t special_values_mask
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
halide_type_t bound_const_type[max_wild]
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
HALIDE_ALWAYS_INLINE MatcherState() noexcept
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
halide_scalar_value_t bound_const[max_wild]
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
static constexpr uint16_t signed_integer_overflow
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
bool check(const Type &t) const
bool check(const Type &) const
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool foldable
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_NEVER_INLINE void build_replacement(After after)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
halide_type_t wildcard_type
halide_type_t output_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool foldable
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr IRNodeType max_node_type
const BaseExprNode & expr
static constexpr uint32_t binds
static constexpr bool foldable
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool foldable
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr bool foldable
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr bool foldable
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr uint32_t mask
IRNodeType node_type
Each IR node subclass has a unique identifier.
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Is the first expression less than the second.
The greater of two values.
The lesser of two values.
The product of two expressions.
Is the first expression not equal to the second.
Logical not - true if the expression false.
Logical or - is at least one of the expression true.
A linear ramp vector node.
static const IRNodeType _node_type
static Expr make(Expr base, Expr stride, int lanes)
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Construct a new vector by taking elements from another sequence of vectors.
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
The difference of two expressions.
static const IRNodeType _node_type
static Expr make(Expr a, Expr b)
Unsigned integer constants.
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
static const IRNodeType _node_type
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@2 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.