1#ifndef HALIDE_IR_VISITOR_H
2#define HALIDE_IR_VISITOR_H
84template<
typename... Lambdas>
87 : handlers(std::move(lambdas)...) {
100 auto visit_impl(
const T *op) {
101 if constexpr (std::is_invocable_v<
decltype(handlers),
LambdaVisitor *,
const T *>) {
102 return handlers(
this, op);
110 this->visit_impl(op);
113 this->visit_impl(op);
116 this->visit_impl(op);
119 this->visit_impl(op);
122 this->visit_impl(op);
125 this->visit_impl(op);
128 this->visit_impl(op);
131 this->visit_impl(op);
134 this->visit_impl(op);
137 this->visit_impl(op);
140 this->visit_impl(op);
143 this->visit_impl(op);
146 this->visit_impl(op);
149 this->visit_impl(op);
152 this->visit_impl(op);
155 this->visit_impl(op);
158 this->visit_impl(op);
161 this->visit_impl(op);
164 this->visit_impl(op);
167 this->visit_impl(op);
170 this->visit_impl(op);
173 this->visit_impl(op);
176 this->visit_impl(op);
179 this->visit_impl(op);
182 this->visit_impl(op);
185 this->visit_impl(op);
188 this->visit_impl(op);
191 this->visit_impl(op);
194 this->visit_impl(op);
197 this->visit_impl(op);
200 this->visit_impl(op);
203 this->visit_impl(op);
206 this->visit_impl(op);
209 this->visit_impl(op);
212 this->visit_impl(op);
215 this->visit_impl(op);
218 this->visit_impl(op);
221 this->visit_impl(op);
224 this->visit_impl(op);
227 this->visit_impl(op);
230 this->visit_impl(op);
233 this->visit_impl(op);
236 this->visit_impl(op);
239 this->visit_impl(op);
242 this->visit_impl(op);
245 this->visit_impl(op);
248 this->visit_impl(op);
251 this->visit_impl(op);
255template<
typename T,
typename... Lambdas>
279 std::set<const IRNode *> visited;
341template<
typename T,
typename ExprRet,
typename StmtRet>
344 template<
typename... Args>
345 ExprRet dispatch_expr(
const BaseExprNode *node, Args &&...args) {
346 if (node ==
nullptr) {
351 return ((T *)
this)->visit((
const IntImm *)node, std::forward<Args>(args)...);
353 return ((T *)
this)->visit((
const UIntImm *)node, std::forward<Args>(args)...);
355 return ((T *)
this)->visit((
const FloatImm *)node, std::forward<Args>(args)...);
357 return ((T *)
this)->visit((
const StringImm *)node, std::forward<Args>(args)...);
359 return ((T *)
this)->visit((
const Broadcast *)node, std::forward<Args>(args)...);
361 return ((T *)
this)->visit((
const Cast *)node, std::forward<Args>(args)...);
363 return ((T *)
this)->visit((
const Reinterpret *)node, std::forward<Args>(args)...);
365 return ((T *)
this)->visit((
const Variable *)node, std::forward<Args>(args)...);
367 return ((T *)
this)->visit((
const Add *)node, std::forward<Args>(args)...);
369 return ((T *)
this)->visit((
const Sub *)node, std::forward<Args>(args)...);
371 return ((T *)
this)->visit((
const Mod *)node, std::forward<Args>(args)...);
373 return ((T *)
this)->visit((
const Mul *)node, std::forward<Args>(args)...);
375 return ((T *)
this)->visit((
const Div *)node, std::forward<Args>(args)...);
377 return ((T *)
this)->visit((
const Min *)node, std::forward<Args>(args)...);
379 return ((T *)
this)->visit((
const Max *)node, std::forward<Args>(args)...);
381 return ((T *)
this)->visit((
const EQ *)node, std::forward<Args>(args)...);
383 return ((T *)
this)->visit((
const NE *)node, std::forward<Args>(args)...);
385 return ((T *)
this)->visit((
const LT *)node, std::forward<Args>(args)...);
387 return ((T *)
this)->visit((
const LE *)node, std::forward<Args>(args)...);
389 return ((T *)
this)->visit((
const GT *)node, std::forward<Args>(args)...);
391 return ((T *)
this)->visit((
const GE *)node, std::forward<Args>(args)...);
393 return ((T *)
this)->visit((
const And *)node, std::forward<Args>(args)...);
395 return ((T *)
this)->visit((
const Or *)node, std::forward<Args>(args)...);
397 return ((T *)
this)->visit((
const Not *)node, std::forward<Args>(args)...);
399 return ((T *)
this)->visit((
const Select *)node, std::forward<Args>(args)...);
401 return ((T *)
this)->visit((
const Load *)node, std::forward<Args>(args)...);
403 return ((T *)
this)->visit((
const Ramp *)node, std::forward<Args>(args)...);
405 return ((T *)
this)->visit((
const Call *)node, std::forward<Args>(args)...);
407 return ((T *)
this)->visit((
const Let *)node, std::forward<Args>(args)...);
409 return ((T *)
this)->visit((
const Shuffle *)node, std::forward<Args>(args)...);
411 return ((T *)
this)->visit((
const VectorReduce *)node, std::forward<Args>(args)...);
437 template<
typename... Args>
438 StmtRet dispatch_stmt(
const BaseStmtNode *node, Args &&...args) {
439 if (node ==
nullptr) {
442 switch (node->node_type) {
477 return ((T *)
this)->visit((
const LetStmt *)node, std::forward<Args>(args)...);
479 return ((T *)
this)->visit((
const AssertStmt *)node, std::forward<Args>(args)...);
481 return ((T *)
this)->visit((
const ProducerConsumer *)node, std::forward<Args>(args)...);
483 return ((T *)
this)->visit((
const For *)node, std::forward<Args>(args)...);
485 return ((T *)
this)->visit((
const Acquire *)node, std::forward<Args>(args)...);
487 return ((T *)
this)->visit((
const Store *)node, std::forward<Args>(args)...);
489 return ((T *)
this)->visit((
const Provide *)node, std::forward<Args>(args)...);
491 return ((T *)
this)->visit((
const Allocate *)node, std::forward<Args>(args)...);
493 return ((T *)
this)->visit((
const Free *)node, std::forward<Args>(args)...);
495 return ((T *)
this)->visit((
const Realize *)node, std::forward<Args>(args)...);
497 return ((T *)
this)->visit((
const Block *)node, std::forward<Args>(args)...);
499 return ((T *)
this)->visit((
const Fork *)node, std::forward<Args>(args)...);
501 return ((T *)
this)->visit((
const IfThenElse *)node, std::forward<Args>(args)...);
503 return ((T *)
this)->visit((
const Evaluate *)node, std::forward<Args>(args)...);
505 return ((T *)
this)->visit((
const Prefetch *)node, std::forward<Args>(args)...);
507 return ((T *)
this)->visit((
const Atomic *)node, std::forward<Args>(args)...);
509 return ((T *)
this)->visit((
const HoistedStorage *)node, std::forward<Args>(args)...);
515 template<
typename... Args>
517 return dispatch_stmt(s.
get(), std::forward<Args>(args)...);
520 template<
typename... Args>
522 return dispatch_stmt(s.get(), std::forward<Args>(args)...);
525 template<
typename... Args>
527 return dispatch_expr(e.
get(), std::forward<Args>(args)...);
530 template<
typename... Args>
532 return dispatch_expr(e.get(), std::forward<Args>(args)...);
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
A base class for algorithms that walk recursively over the IR without visiting the same node twice.
void visit(const Div *) override
void visit(const Shuffle *) override
void visit(const NE *) override
void visit(const Block *) override
void visit(const EQ *) override
void visit(const Let *) override
void visit(const Provide *) override
void visit(const StringImm *) override
virtual void include(const Expr &)
By default these methods add the node to the visited set, and return whether or not it was already th...
void visit(const For *) override
void visit(const HoistedStorage *) override
void visit(const Ramp *) override
void visit(const Or *) override
void visit(const UIntImm *) override
void visit(const Mul *) override
void visit(const AssertStmt *) override
void visit(const GE *) override
void visit(const Min *) override
void visit(const Free *) override
void visit(const Add *) override
void visit(const Acquire *) override
void visit(const Store *) override
void visit(const Max *) override
void visit(const IntImm *) override
These methods should call 'include' on the children to only visit them if they haven't been visited a...
void visit(const IfThenElse *) override
void visit(const LT *) override
void visit(const VectorReduce *) override
void visit(const Atomic *) override
void visit(const Sub *) override
void visit(const Not *) override
void visit(const Mod *) override
void visit(const ProducerConsumer *) override
void visit(const LetStmt *) override
void visit(const LE *) override
void visit(const Allocate *) override
void visit(const Load *) override
virtual void include(const Stmt &)
void visit(const Realize *) override
void visit(const Prefetch *) override
void visit(const FloatImm *) override
void visit(const Fork *) override
void visit(const Call *) override
void visit(const Reinterpret *) override
void visit(const And *) override
void visit(const Variable *) override
void visit(const Evaluate *) override
void visit(const Broadcast *) override
void visit(const GT *) override
void visit(const Cast *) override
void visit(const Select *) override
A base class for algorithms that need to recursively walk over the IR.
virtual void visit(const NE *)
virtual void visit(const Mul *)
virtual void visit(const Max *)
virtual void visit(const Select *)
virtual void visit(const Load *)
virtual void visit(const Div *)
virtual void visit(const Fork *)
virtual void visit(const Sub *)
virtual void visit(const LE *)
virtual ~IRVisitor()=default
virtual void visit(const ProducerConsumer *)
virtual void visit(const VectorReduce *)
virtual void visit(const GE *)
virtual void visit(const StringImm *)
virtual void visit(const Allocate *)
virtual void visit(const IfThenElse *)
virtual void visit(const For *)
virtual void visit(const Prefetch *)
virtual void visit(const Block *)
virtual void visit(const UIntImm *)
virtual void visit(const HoistedStorage *)
virtual void visit(const FloatImm *)
virtual void visit(const GT *)
virtual void visit(const Mod *)
virtual void visit(const Acquire *)
virtual void visit(const Atomic *)
virtual void visit(const Ramp *)
virtual void visit(const Free *)
virtual void visit(const IntImm *)
virtual void visit(const Or *)
virtual void visit(const EQ *)
virtual void visit(const Broadcast *)
virtual void visit(const Call *)
virtual void visit(const Min *)
virtual void visit(const Variable *)
virtual void visit(const Realize *)
virtual void visit(const Add *)
virtual void visit(const Shuffle *)
virtual void visit(const Reinterpret *)
virtual void visit(const Evaluate *)
virtual void visit(const AssertStmt *)
virtual void visit(const And *)
virtual void visit(const LetStmt *)
virtual void visit(const Store *)
virtual void visit(const Provide *)
virtual void visit(const LT *)
virtual void visit(const Cast *)
virtual void visit(const Not *)
virtual void visit(const Let *)
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args)
HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args)
HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args)
HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args)
void visit_with(const T &ir, Lambdas &&...lambdas)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
A fragment of Halide syntax.
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
The sum of two expressions.
Allocate a scratch area with the given name, type, and size.
Logical and - are both expressions true.
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Lock all the Store nodes in the body statement.
A base class for expression nodes.
A sequence of statements to be executed in-order.
A vector with 'lanes' elements, in which every element is 'value'.
The actual IR nodes begin here.
The ratio of two expressions.
Is the first expression equal to the second.
Evaluate and discard an expression, presumably because it has some side-effect.
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Floating point constants.
A pair of statements executed concurrently.
Free the resources associated with the given buffer.
Is the first expression greater than or equal to the second.
Is the first expression greater than the second.
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
IRNodeType node_type
Each IR node subclass has a unique identifier.
Is the first expression less than or equal to the second.
Is the first expression less than the second.
A lambda-based IR visitor that accepts multiple lambdas for different node types.
void visit(const Atomic *op) override
void visit(const IntImm *op) override
void visit(const Mod *op) override
void visit(const HoistedStorage *op) override
void visit(const Min *op) override
void visit(const Max *op) override
void visit(const GE *op) override
void visit(const Variable *op) override
void visit(const Not *op) override
void visit(const Realize *op) override
void visit(const LT *op) override
void visit(const Reinterpret *op) override
void visit(const Prefetch *op) override
LambdaVisitor(Lambdas... lambdas)
void visit(const Fork *op) override
void visit(const Mul *op) override
void visit(const EQ *op) override
void visit(const Div *op) override
void visit(const Sub *op) override
void visit(const StringImm *op) override
void visit_base(const T *op)
Public helper to call the base visitor from lambdas.
void visit(const NE *op) override
void visit(const IfThenElse *op) override
void visit(const Provide *op) override
void visit(const Or *op) override
void visit(const LetStmt *op) override
void visit(const VectorReduce *op) override
void visit(const Free *op) override
void visit(const And *op) override
void visit(const Acquire *op) override
void visit(const Let *op) override
void visit(const For *op) override
void visit(const Allocate *op) override
void visit(const Shuffle *op) override
void visit(const ProducerConsumer *op) override
void visit(const LE *op) override
void visit(const Ramp *op) override
void visit(const Store *op) override
void visit(const Load *op) override
void visit(const AssertStmt *op) override
void visit(const GT *op) override
void visit(const FloatImm *op) override
void visit(const Evaluate *op) override
void visit(const Add *op) override
void visit(const Call *op) override
void visit(const Cast *op) override
void visit(const Select *op) override
void visit(const UIntImm *op) override
void visit(const Broadcast *op) override
void visit(const Block *op) override
A let expression, like you might find in a functional language.
The statement form of a let node.
Load a value from a named symbol if predicate is true.
The greater of two values.
The lesser of two values.
The product of two expressions.
Is the first expression not equal to the second.
Logical not - true if the expression is false.
Logical or - is at least one of the expressions true.
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
This node is a helpful annotation to do with permissions.
This defines the value of a function at a multi-dimensional location.
A linear ramp vector node.
Allocate a multi-dimensional buffer of the given type and size.
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Construct a new vector by taking elements from another sequence of vectors.
A reference-counted handle to a statement node.
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
The difference of two expressions.
Unsigned integer constants.
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...