1#ifndef HALIDE_IR_VISITOR_H
2#define HALIDE_IR_VISITOR_H
84template<
typename... Lambdas>
87 : handlers(std::move(lambdas)...) {
100 auto visit_impl(
const T *op) {
101 if constexpr (std::is_invocable_v<
decltype(handlers),
LambdaVisitor *,
const T *>) {
102 return handlers(
this, op);
110 this->visit_impl(op);
113 this->visit_impl(op);
116 this->visit_impl(op);
119 this->visit_impl(op);
122 this->visit_impl(op);
125 this->visit_impl(op);
128 this->visit_impl(op);
131 this->visit_impl(op);
134 this->visit_impl(op);
137 this->visit_impl(op);
140 this->visit_impl(op);
143 this->visit_impl(op);
146 this->visit_impl(op);
149 this->visit_impl(op);
152 this->visit_impl(op);
155 this->visit_impl(op);
158 this->visit_impl(op);
161 this->visit_impl(op);
164 this->visit_impl(op);
167 this->visit_impl(op);
170 this->visit_impl(op);
173 this->visit_impl(op);
176 this->visit_impl(op);
179 this->visit_impl(op);
182 this->visit_impl(op);
185 this->visit_impl(op);
188 this->visit_impl(op);
191 this->visit_impl(op);
194 this->visit_impl(op);
197 this->visit_impl(op);
200 this->visit_impl(op);
203 this->visit_impl(op);
206 this->visit_impl(op);
209 this->visit_impl(op);
212 this->visit_impl(op);
215 this->visit_impl(op);
218 this->visit_impl(op);
221 this->visit_impl(op);
224 this->visit_impl(op);
227 this->visit_impl(op);
230 this->visit_impl(op);
233 this->visit_impl(op);
236 this->visit_impl(op);
239 this->visit_impl(op);
242 this->visit_impl(op);
245 this->visit_impl(op);
248 this->visit_impl(op);
251 this->visit_impl(op);
255template<
typename... Lambdas>
258 constexpr bool all_take_two_args =
259 (std::is_invocable_v<Lambdas,
decltype(&visitor),
decltype(
nullptr)> && ...);
260 static_assert(all_take_two_args);
264template<
typename... Lambdas>
287 std::set<const IRNode *> visited;
349template<
typename T,
typename ExprRet,
typename StmtRet>
352 template<
typename... Args>
353 ExprRet dispatch_expr(
const BaseExprNode *node, Args &&...args) {
354 if (node ==
nullptr) {
359 return ((T *)
this)->visit((
const IntImm *)node, std::forward<Args>(args)...);
361 return ((T *)
this)->visit((
const UIntImm *)node, std::forward<Args>(args)...);
363 return ((T *)
this)->visit((
const FloatImm *)node, std::forward<Args>(args)...);
365 return ((T *)
this)->visit((
const StringImm *)node, std::forward<Args>(args)...);
367 return ((T *)
this)->visit((
const Broadcast *)node, std::forward<Args>(args)...);
369 return ((T *)
this)->visit((
const Cast *)node, std::forward<Args>(args)...);
371 return ((T *)
this)->visit((
const Reinterpret *)node, std::forward<Args>(args)...);
373 return ((T *)
this)->visit((
const Variable *)node, std::forward<Args>(args)...);
375 return ((T *)
this)->visit((
const Add *)node, std::forward<Args>(args)...);
377 return ((T *)
this)->visit((
const Sub *)node, std::forward<Args>(args)...);
379 return ((T *)
this)->visit((
const Mod *)node, std::forward<Args>(args)...);
381 return ((T *)
this)->visit((
const Mul *)node, std::forward<Args>(args)...);
383 return ((T *)
this)->visit((
const Div *)node, std::forward<Args>(args)...);
385 return ((T *)
this)->visit((
const Min *)node, std::forward<Args>(args)...);
387 return ((T *)
this)->visit((
const Max *)node, std::forward<Args>(args)...);
389 return ((T *)
this)->visit((
const EQ *)node, std::forward<Args>(args)...);
391 return ((T *)
this)->visit((
const NE *)node, std::forward<Args>(args)...);
393 return ((T *)
this)->visit((
const LT *)node, std::forward<Args>(args)...);
395 return ((T *)
this)->visit((
const LE *)node, std::forward<Args>(args)...);
397 return ((T *)
this)->visit((
const GT *)node, std::forward<Args>(args)...);
399 return ((T *)
this)->visit((
const GE *)node, std::forward<Args>(args)...);
401 return ((T *)
this)->visit((
const And *)node, std::forward<Args>(args)...);
403 return ((T *)
this)->visit((
const Or *)node, std::forward<Args>(args)...);
405 return ((T *)
this)->visit((
const Not *)node, std::forward<Args>(args)...);
407 return ((T *)
this)->visit((
const Select *)node, std::forward<Args>(args)...);
409 return ((T *)
this)->visit((
const Load *)node, std::forward<Args>(args)...);
411 return ((T *)
this)->visit((
const Ramp *)node, std::forward<Args>(args)...);
413 return ((T *)
this)->visit((
const Call *)node, std::forward<Args>(args)...);
415 return ((T *)
this)->visit((
const Let *)node, std::forward<Args>(args)...);
417 return ((T *)
this)->visit((
const Shuffle *)node, std::forward<Args>(args)...);
419 return ((T *)
this)->visit((
const VectorReduce *)node, std::forward<Args>(args)...);
445 template<
typename... Args>
446 StmtRet dispatch_stmt(
const BaseStmtNode *node, Args &&...args) {
447 if (node ==
nullptr) {
450 switch (node->node_type) {
485 return ((T *)
this)->visit((
const LetStmt *)node, std::forward<Args>(args)...);
487 return ((T *)
this)->visit((
const AssertStmt *)node, std::forward<Args>(args)...);
489 return ((T *)
this)->visit((
const ProducerConsumer *)node, std::forward<Args>(args)...);
491 return ((T *)
this)->visit((
const For *)node, std::forward<Args>(args)...);
493 return ((T *)
this)->visit((
const Acquire *)node, std::forward<Args>(args)...);
495 return ((T *)
this)->visit((
const Store *)node, std::forward<Args>(args)...);
497 return ((T *)
this)->visit((
const Provide *)node, std::forward<Args>(args)...);
499 return ((T *)
this)->visit((
const Allocate *)node, std::forward<Args>(args)...);
501 return ((T *)
this)->visit((
const Free *)node, std::forward<Args>(args)...);
503 return ((T *)
this)->visit((
const Realize *)node, std::forward<Args>(args)...);
505 return ((T *)
this)->visit((
const Block *)node, std::forward<Args>(args)...);
507 return ((T *)
this)->visit((
const Fork *)node, std::forward<Args>(args)...);
509 return ((T *)
this)->visit((
const IfThenElse *)node, std::forward<Args>(args)...);
511 return ((T *)
this)->visit((
const Evaluate *)node, std::forward<Args>(args)...);
513 return ((T *)
this)->visit((
const Prefetch *)node, std::forward<Args>(args)...);
515 return ((T *)
this)->visit((
const Atomic *)node, std::forward<Args>(args)...);
517 return ((T *)
this)->visit((
const HoistedStorage *)node, std::forward<Args>(args)...);
523 template<
typename... Args>
525 return dispatch_stmt(s.
get(), std::forward<Args>(args)...);
528 template<
typename... Args>
530 return dispatch_stmt(s.get(), std::forward<Args>(args)...);
533 template<
typename... Args>
535 return dispatch_expr(e.
get(), std::forward<Args>(args)...);
538 template<
typename... Args>
540 return dispatch_expr(e.get(), std::forward<Args>(args)...);
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
A base class for algorithms that walk recursively over the IR without visiting the same node twice.
void visit(const Div *) override
void visit(const Shuffle *) override
void visit(const NE *) override
void visit(const Block *) override
void visit(const EQ *) override
void visit(const Let *) override
void visit(const Provide *) override
void visit(const StringImm *) override
virtual void include(const Expr &)
By default these methods add the node to the visited set, and return whether or not it was already th...
void visit(const For *) override
void visit(const HoistedStorage *) override
void visit(const Ramp *) override
void visit(const Or *) override
void visit(const UIntImm *) override
void visit(const Mul *) override
void visit(const AssertStmt *) override
void visit(const GE *) override
void visit(const Min *) override
void visit(const Free *) override
void visit(const Add *) override
void visit(const Acquire *) override
void visit(const Store *) override
void visit(const Max *) override
void visit(const IntImm *) override
These methods should call 'include' on the children to only visit them if they haven't been visited a...
void visit(const IfThenElse *) override
void visit(const LT *) override
void visit(const VectorReduce *) override
void visit(const Atomic *) override
void visit(const Sub *) override
void visit(const Not *) override
void visit(const Mod *) override
void visit(const ProducerConsumer *) override
void visit(const LetStmt *) override
void visit(const LE *) override
void visit(const Allocate *) override
void visit(const Load *) override
virtual void include(const Stmt &)
void visit(const Realize *) override
void visit(const Prefetch *) override
void visit(const FloatImm *) override
void visit(const Fork *) override
void visit(const Call *) override
void visit(const Reinterpret *) override
void visit(const And *) override
void visit(const Variable *) override
void visit(const Evaluate *) override
void visit(const Broadcast *) override
void visit(const GT *) override
void visit(const Cast *) override
void visit(const Select *) override
A base class for algorithms that need to recursively walk over the IR.
virtual void visit(const NE *)
virtual void visit(const Mul *)
virtual void visit(const Max *)
virtual void visit(const Select *)
virtual void visit(const Load *)
virtual void visit(const Div *)
virtual void visit(const Fork *)
virtual void visit(const Sub *)
virtual void visit(const LE *)
virtual ~IRVisitor()=default
virtual void visit(const ProducerConsumer *)
virtual void visit(const VectorReduce *)
virtual void visit(const GE *)
virtual void visit(const StringImm *)
virtual void visit(const Allocate *)
virtual void visit(const IfThenElse *)
virtual void visit(const For *)
virtual void visit(const Prefetch *)
virtual void visit(const Block *)
virtual void visit(const UIntImm *)
virtual void visit(const HoistedStorage *)
virtual void visit(const FloatImm *)
virtual void visit(const GT *)
virtual void visit(const Mod *)
virtual void visit(const Acquire *)
virtual void visit(const Atomic *)
virtual void visit(const Ramp *)
virtual void visit(const Free *)
virtual void visit(const IntImm *)
virtual void visit(const Or *)
virtual void visit(const EQ *)
virtual void visit(const Broadcast *)
virtual void visit(const Call *)
virtual void visit(const Min *)
virtual void visit(const Variable *)
virtual void visit(const Realize *)
virtual void visit(const Add *)
virtual void visit(const Shuffle *)
virtual void visit(const Reinterpret *)
virtual void visit(const Evaluate *)
virtual void visit(const AssertStmt *)
virtual void visit(const And *)
virtual void visit(const LetStmt *)
virtual void visit(const Store *)
virtual void visit(const Provide *)
virtual void visit(const LT *)
virtual void visit(const Cast *)
virtual void visit(const Not *)
virtual void visit(const Let *)
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args)
HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args)
HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args)
HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args)
void visit_with(const IRNode *ir, Lambdas &&...lambdas)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
A fragment of Halide syntax.
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
The sum of two expressions.
Allocate a scratch area called with the given name, type, and size.
Logical and - are both expressions true.
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Lock all the Store nodes in the body statement.
A base class for expression nodes.
A sequence of statements to be executed in-order.
A vector with 'lanes' elements, in which every element is 'value'.
The actual IR nodes begin here.
The ratio of two expressions.
Is the first expression equal to the second.
Evaluate and discard an expression, presumably because it has some side-effect.
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Floating point constants.
A pair of statements executed concurrently.
Free the resources associated with the given buffer.
Is the first expression greater than or equal to the second.
Is the first expression greater than the second.
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
IR nodes are passed around opaque handles to them.
The abstract base classes for a node in the Halide IR.
virtual void accept(IRVisitor *v) const =0
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
IRNodeType node_type
Each IR node subclass has a unique identifier.
T * get() const
Access the raw pointer in a variety of ways.
Is the first expression less than or equal to the second.
Is the first expression less than the second.
A lambda-based IR visitor that accepts multiple lambdas for different node types.
void visit(const Atomic *op) override
void visit(const IntImm *op) override
void visit(const Mod *op) override
void visit(const HoistedStorage *op) override
void visit(const Min *op) override
void visit(const Max *op) override
void visit(const GE *op) override
void visit(const Variable *op) override
void visit(const Not *op) override
void visit(const Realize *op) override
void visit(const LT *op) override
void visit(const Reinterpret *op) override
void visit(const Prefetch *op) override
LambdaVisitor(Lambdas... lambdas)
void visit(const Fork *op) override
void visit(const Mul *op) override
void visit(const EQ *op) override
void visit(const Div *op) override
void visit(const Sub *op) override
void visit(const StringImm *op) override
void visit_base(const T *op)
Public helper to call the base visitor from lambdas.
void visit(const NE *op) override
void visit(const IfThenElse *op) override
void visit(const Provide *op) override
void visit(const Or *op) override
void visit(const LetStmt *op) override
void visit(const VectorReduce *op) override
void visit(const Free *op) override
void visit(const And *op) override
void visit(const Acquire *op) override
void visit(const Let *op) override
void visit(const For *op) override
void visit(const Allocate *op) override
void visit(const Shuffle *op) override
void visit(const ProducerConsumer *op) override
void visit(const LE *op) override
void visit(const Ramp *op) override
void visit(const Store *op) override
void visit(const Load *op) override
void visit(const AssertStmt *op) override
void visit(const GT *op) override
void visit(const FloatImm *op) override
void visit(const Evaluate *op) override
void visit(const Add *op) override
void visit(const Call *op) override
void visit(const Cast *op) override
void visit(const Select *op) override
void visit(const UIntImm *op) override
void visit(const Broadcast *op) override
void visit(const Block *op) override
A let expression, like you might find in a functional language.
The statement form of a let node.
Load a value from a named symbol if predicate is true.
The greater of two values.
The lesser of two values.
The product of two expressions.
Is the first expression not equal to the second.
Logical not - true if the expression false.
Logical or - is at least one of the expression true.
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
This node is a helpful annotation to do with permissions.
This defines the value of a function at a multi-dimensional location.
A linear ramp vector node.
Allocate a multi-dimensional buffer of the given type and size.
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Construct a new vector by taking elements from another sequence of vectors.
A reference-counted handle to a statement node.
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
The difference of two expressions.
Unsigned integer constants.
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...