Halide 22.0.0
Halide compiler and libraries
IRMutator.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MUTATOR_H
2#define HALIDE_IR_MUTATOR_H
3
4/** \file
5 * Defines a base class for passes over the IR that modify it
6 */
7
8#include <map>
9#include <type_traits>
10#include <utility>
11
12#include "IR.h"
13
14namespace Halide {
15namespace Internal {
16
17/** A base class for passes over the IR which modify it
18 * (e.g. replacing a variable with a value (Substitute.h), or
19 * constant-folding).
20 *
21 * Your mutator should override the visit() methods you care about and return
22 * the new expression or stmt. The default implementations recursively
23 * mutate their children. To mutate sub-expressions and sub-statements you
24 * should override the mutate() method, which will dispatch to
25 * the appropriate visit() method and then return the value of expr or
26 * stmt after the call to visit.
27 */
28class IRMutator {
29public:
30 IRMutator() = default;
31 virtual ~IRMutator() = default;
32
33 /** This is the main interface for using a mutator. Also call
34 * these in your subclass to mutate sub-expressions and
35 * sub-statements.
36 */
37 virtual Expr mutate(const Expr &expr);
38 virtual Stmt mutate(const Stmt &stmt);
39
40 // Mutate all the Exprs and return the new list in ret, along with
41 // a flag that is true iff at least one item in the list changed.
42 std::pair<std::vector<Expr>, bool> mutate_with_changes(const std::vector<Expr> &);
43
44 // Like mutate_with_changes, but discard the changes flag.
45 std::vector<Expr> mutate(const std::vector<Expr> &exprs) {
46 return mutate_with_changes(exprs).first;
47 }
48
49protected:
50 // ExprNode<> and StmtNode<> are allowed to call visit (to implement mutate_expr/mutate_stmt())
51 template<typename T>
52 friend struct ExprNode;
53 template<typename T>
54 friend struct StmtNode;
55
56 virtual Expr visit(const IntImm *);
57 virtual Expr visit(const UIntImm *);
58 virtual Expr visit(const FloatImm *);
59 virtual Expr visit(const StringImm *);
60 virtual Expr visit(const Cast *);
61 virtual Expr visit(const Reinterpret *);
62 virtual Expr visit(const Add *);
63 virtual Expr visit(const Sub *);
64 virtual Expr visit(const Mul *);
65 virtual Expr visit(const Div *);
66 virtual Expr visit(const Mod *);
67 virtual Expr visit(const Min *);
68 virtual Expr visit(const Max *);
69 virtual Expr visit(const EQ *);
70 virtual Expr visit(const NE *);
71 virtual Expr visit(const LT *);
72 virtual Expr visit(const LE *);
73 virtual Expr visit(const GT *);
74 virtual Expr visit(const GE *);
75 virtual Expr visit(const And *);
76 virtual Expr visit(const Or *);
77 virtual Expr visit(const Not *);
78 virtual Expr visit(const Select *);
79 virtual Expr visit(const Load *);
80 virtual Expr visit(const Ramp *);
81 virtual Expr visit(const Broadcast *);
82 virtual Expr visit(const Let *);
83 virtual Stmt visit(const LetStmt *);
84 virtual Stmt visit(const AssertStmt *);
85 virtual Stmt visit(const ProducerConsumer *);
86 virtual Stmt visit(const Store *);
87 virtual Stmt visit(const Provide *);
88 virtual Stmt visit(const Allocate *);
89 virtual Stmt visit(const Free *);
90 virtual Stmt visit(const Realize *);
91 virtual Stmt visit(const Block *);
92 virtual Stmt visit(const Fork *);
93 virtual Stmt visit(const IfThenElse *);
94 virtual Stmt visit(const Evaluate *);
95 virtual Expr visit(const Call *);
96 virtual Expr visit(const Variable *);
97 virtual Stmt visit(const For *);
98 virtual Stmt visit(const Acquire *);
99 virtual Expr visit(const Shuffle *);
100 virtual Stmt visit(const Prefetch *);
101 virtual Stmt visit(const HoistedStorage *);
102 virtual Stmt visit(const Atomic *);
103 virtual Expr visit(const VectorReduce *);
104};
105
106/** A mutator that caches and reapplies previously done mutations so
107 * that it can handle graphs of IR that have not had CSE done to
108 * them. */
109class IRGraphMutator : public IRMutator {
110protected:
111 std::map<Expr, Expr, ExprCompare> expr_replacements;
112 std::map<Stmt, Stmt, Stmt::Compare> stmt_replacements;
113
114public:
115 using IRMutator::mutate;
116 Stmt mutate(const Stmt &s) override;
117 Expr mutate(const Expr &e) override;
118};
119
120/** A lambda-based IR mutator that accepts multiple lambdas for different
121 * node types. */
122template<typename... Lambdas>
123struct LambdaMutator final : IRMutator {
124 explicit LambdaMutator(Lambdas... lambdas)
125 : handlers(std::move(lambdas)...) {
126 }
127
128 /** Public helper to call the base visitor from lambdas. */
129 template<typename T>
130 auto visit_base(const T *op) {
131 return IRMutator::visit(op);
132 }
133
134private:
135 LambdaOverloads<Lambdas...> handlers;
136
137 template<typename T>
138 auto visit_impl(const T *op) {
139 if constexpr (std::is_invocable_v<decltype(handlers), LambdaMutator *, const T *>) {
140 return handlers(this, op);
141 } else {
142 return this->visit_base(op);
143 }
144 }
145
146protected:
147 Expr visit(const IntImm *op) override {
148 return this->visit_impl(op);
149 }
150 Expr visit(const UIntImm *op) override {
151 return this->visit_impl(op);
152 }
153 Expr visit(const FloatImm *op) override {
154 return this->visit_impl(op);
155 }
156 Expr visit(const StringImm *op) override {
157 return this->visit_impl(op);
158 }
159 Expr visit(const Cast *op) override {
160 return this->visit_impl(op);
161 }
162 Expr visit(const Reinterpret *op) override {
163 return this->visit_impl(op);
164 }
165 Expr visit(const Add *op) override {
166 return this->visit_impl(op);
167 }
168 Expr visit(const Sub *op) override {
169 return this->visit_impl(op);
170 }
171 Expr visit(const Mul *op) override {
172 return this->visit_impl(op);
173 }
174 Expr visit(const Div *op) override {
175 return this->visit_impl(op);
176 }
177 Expr visit(const Mod *op) override {
178 return this->visit_impl(op);
179 }
180 Expr visit(const Min *op) override {
181 return this->visit_impl(op);
182 }
183 Expr visit(const Max *op) override {
184 return this->visit_impl(op);
185 }
186 Expr visit(const EQ *op) override {
187 return this->visit_impl(op);
188 }
189 Expr visit(const NE *op) override {
190 return this->visit_impl(op);
191 }
192 Expr visit(const LT *op) override {
193 return this->visit_impl(op);
194 }
195 Expr visit(const LE *op) override {
196 return this->visit_impl(op);
197 }
198 Expr visit(const GT *op) override {
199 return this->visit_impl(op);
200 }
201 Expr visit(const GE *op) override {
202 return this->visit_impl(op);
203 }
204 Expr visit(const And *op) override {
205 return this->visit_impl(op);
206 }
207 Expr visit(const Or *op) override {
208 return this->visit_impl(op);
209 }
210 Expr visit(const Not *op) override {
211 return this->visit_impl(op);
212 }
213 Expr visit(const Select *op) override {
214 return this->visit_impl(op);
215 }
216 Expr visit(const Load *op) override {
217 return this->visit_impl(op);
218 }
219 Expr visit(const Ramp *op) override {
220 return this->visit_impl(op);
221 }
222 Expr visit(const Broadcast *op) override {
223 return this->visit_impl(op);
224 }
225 Expr visit(const Let *op) override {
226 return this->visit_impl(op);
227 }
228 Stmt visit(const LetStmt *op) override {
229 return this->visit_impl(op);
230 }
231 Stmt visit(const AssertStmt *op) override {
232 return this->visit_impl(op);
233 }
234 Stmt visit(const ProducerConsumer *op) override {
235 return this->visit_impl(op);
236 }
237 Stmt visit(const Store *op) override {
238 return this->visit_impl(op);
239 }
240 Stmt visit(const Provide *op) override {
241 return this->visit_impl(op);
242 }
243 Stmt visit(const Allocate *op) override {
244 return this->visit_impl(op);
245 }
246 Stmt visit(const Free *op) override {
247 return this->visit_impl(op);
248 }
249 Stmt visit(const Realize *op) override {
250 return this->visit_impl(op);
251 }
252 Stmt visit(const Block *op) override {
253 return this->visit_impl(op);
254 }
255 Stmt visit(const Fork *op) override {
256 return this->visit_impl(op);
257 }
258 Stmt visit(const IfThenElse *op) override {
259 return this->visit_impl(op);
260 }
261 Stmt visit(const Evaluate *op) override {
262 return this->visit_impl(op);
263 }
264 Expr visit(const Call *op) override {
265 return this->visit_impl(op);
266 }
267 Expr visit(const Variable *op) override {
268 return this->visit_impl(op);
269 }
270 Stmt visit(const For *op) override {
271 return this->visit_impl(op);
272 }
273 Stmt visit(const Acquire *op) override {
274 return this->visit_impl(op);
275 }
276 Expr visit(const Shuffle *op) override {
277 return this->visit_impl(op);
278 }
279 Stmt visit(const Prefetch *op) override {
280 return this->visit_impl(op);
281 }
282 Stmt visit(const HoistedStorage *op) override {
283 return this->visit_impl(op);
284 }
285 Stmt visit(const Atomic *op) override {
286 return this->visit_impl(op);
287 }
288 Expr visit(const VectorReduce *op) override {
289 return this->visit_impl(op);
290 }
291};
292
293/** A lambda-based IR mutator that accepts multiple lambdas for overloading
294 * the base mutate() method. */
295template<typename... Lambdas>
297 explicit LambdaMutatorGeneric(Lambdas... lambdas)
298 : handlers(std::move(lambdas)...) {
299 }
300
301 /** Public helper to call the base mutator from lambdas. */
302 // Note: C++26 introduces variadic friends: https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2024/p2893r3.html
303 // So the mutate_base API could be replaced with:
304 // friend Lambdas...;
305 template<typename T>
306 auto mutate_base(const T &op) {
307 return IRMutator::mutate(op);
308 }
309
310 Expr mutate(const Expr &e) override {
311 if constexpr (std::is_invocable_v<decltype(handlers), LambdaMutatorGeneric *, const Expr &>) {
312 return handlers(this, e);
313 } else {
314 return this->mutate_base(e);
315 }
316 }
317
318 Stmt mutate(const Stmt &e) override {
319 if constexpr (std::is_invocable_v<decltype(handlers), LambdaMutatorGeneric *, const Stmt &>) {
320 return handlers(this, e);
321 } else {
322 return this->mutate_base(e);
323 }
324 }
325
326private:
327 LambdaOverloads<Lambdas...> handlers;
328};
329
330template<typename T, typename... Lambdas>
331auto mutate_with(const T &ir, Lambdas &&...lambdas) {
332 using Overloads = LambdaOverloads<Lambdas...>;
333 using Generic = LambdaMutatorGeneric<Overloads>;
334 if constexpr (std::is_invocable_v<Overloads, Generic *, const Expr &> ||
335 std::is_invocable_v<Overloads, Generic *, const Stmt &>) {
336 return LambdaMutatorGeneric{std::forward<Lambdas>(lambdas)...}.mutate(ir);
337 } else {
338 LambdaMutator mutator{std::forward<Lambdas>(lambdas)...};
339 constexpr bool all_take_two_args =
340 (std::is_invocable_v<Lambdas, decltype(&mutator), decltype(nullptr)> && ...);
341 static_assert(all_take_two_args);
342 return mutator.mutate(ir);
343 }
344}
345
346/** A helper function for mutator-like things to mutate regions */
347template<typename Mutator, typename... Args>
348std::pair<Region, bool> mutate_region(Mutator *mutator, const Region &bounds, Args &&...args) {
349 Region new_bounds(bounds.size());
350 bool bounds_changed = false;
351
352 for (size_t i = 0; i < bounds.size(); i++) {
353 Expr old_min = bounds[i].min;
354 Expr old_extent = bounds[i].extent;
355 Expr new_min = mutator->mutate(old_min, args...);
356 Expr new_extent = mutator->mutate(old_extent, args...);
357 if (!new_min.same_as(old_min)) {
358 bounds_changed = true;
359 }
360 if (!new_extent.same_as(old_extent)) {
361 bounds_changed = true;
362 }
363 new_bounds[i] = Range(new_min, new_extent);
364 }
365 return {new_bounds, bounds_changed};
366}
367
368} // namespace Internal
369} // namespace Halide
370
371#endif
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
A mutator that caches and reapplies previously done mutations so that it can handle graphs of IR that...
Definition: IRMutator.h:109
std::map< Expr, Expr, ExprCompare > expr_replacements
Definition: IRMutator.h:111
Expr mutate(const Expr &e) override
This is the main interface for using a mutator.
Stmt mutate(const Stmt &s) override
std::map< Stmt, Stmt, Stmt::Compare > stmt_replacements
Definition: IRMutator.h:112
A base class for passes over the IR which modify it (e.g. replacing a variable with a value, or constant-folding).
Definition: IRMutator.h:28
std::vector< Expr > mutate(const std::vector< Expr > &exprs)
Definition: IRMutator.h:45
virtual Stmt visit(const Free *)
virtual Expr visit(const Broadcast *)
virtual Expr visit(const VectorReduce *)
virtual Expr visit(const Call *)
virtual Stmt visit(const Atomic *)
virtual Expr visit(const Let *)
virtual Stmt visit(const Provide *)
virtual Expr visit(const Ramp *)
virtual Expr visit(const Mod *)
virtual Expr visit(const LE *)
virtual Expr visit(const Reinterpret *)
virtual Stmt visit(const Store *)
virtual Expr visit(const Max *)
virtual Stmt visit(const HoistedStorage *)
virtual Expr visit(const Cast *)
virtual Stmt mutate(const Stmt &stmt)
virtual Stmt visit(const Block *)
virtual Expr visit(const Load *)
virtual Stmt visit(const Fork *)
virtual Stmt visit(const Allocate *)
virtual Stmt visit(const LetStmt *)
virtual Expr visit(const Or *)
virtual Expr visit(const And *)
virtual Stmt visit(const Acquire *)
std::pair< std::vector< Expr >, bool > mutate_with_changes(const std::vector< Expr > &)
virtual Stmt visit(const ProducerConsumer *)
virtual Expr visit(const Variable *)
virtual Expr visit(const Shuffle *)
virtual Expr visit(const LT *)
virtual Stmt visit(const For *)
virtual ~IRMutator()=default
virtual Expr visit(const Min *)
virtual Stmt visit(const AssertStmt *)
virtual Expr visit(const EQ *)
virtual Expr visit(const GT *)
virtual Expr visit(const Not *)
virtual Expr visit(const Add *)
virtual Expr visit(const NE *)
virtual Expr visit(const Div *)
virtual Expr visit(const Sub *)
virtual Stmt visit(const Evaluate *)
virtual Stmt visit(const Prefetch *)
virtual Stmt visit(const Realize *)
virtual Expr visit(const IntImm *)
virtual Expr mutate(const Expr &expr)
This is the main interface for using a mutator.
virtual Expr visit(const GE *)
virtual Expr visit(const UIntImm *)
virtual Stmt visit(const IfThenElse *)
virtual Expr visit(const FloatImm *)
virtual Expr visit(const Select *)
virtual Expr visit(const Mul *)
virtual Expr visit(const StringImm *)
auto mutate_with(const T &ir, Lambdas &&...lambdas)
Definition: IRMutator.h:331
std::pair< Region, bool > mutate_region(Mutator *mutator, const Region &bounds, Args &&...args)
A helper function for mutator-like things to mutate regions.
Definition: IRMutator.h:348
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:350
A fragment of Halide syntax.
Definition: Expr.h:258
The sum of two expressions.
Definition: IR.h:66
Allocate a scratch area with the given name, type, and size.
Definition: IR.h:381
Logical and - are both expressions true.
Definition: IR.h:185
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition: IR.h:304
Lock all the Store nodes in the body statement.
Definition: IR.h:1008
A sequence of statements to be executed in-order.
Definition: IR.h:452
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:269
A function call.
Definition: IR.h:500
The actual IR nodes begin here.
Definition: IR.h:40
The ratio of two expressions.
Definition: IR.h:93
Is the first expression equal to the second.
Definition: IR.h:131
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:486
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition: Expr.h:158
Floating point constants.
Definition: Expr.h:236
A for loop.
Definition: IR.h:858
A pair of statements executed concurrently.
Definition: IR.h:467
Free the resources associated with the given buffer.
Definition: IR.h:423
Is the first expression greater than or equal to the second.
Definition: IR.h:176
Is the first expression greater than the second.
Definition: IR.h:167
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
Definition: IR.h:992
An if-then-else block.
Definition: IR.h:476
Integer constants.
Definition: Expr.h:218
HALIDE_ALWAYS_INLINE bool same_as(const IntrusivePtr &other) const
Definition: IntrusivePtr.h:171
Is the first expression less than or equal to the second.
Definition: IR.h:158
Is the first expression less than the second.
Definition: IR.h:149
A lambda-based IR mutator that accepts multiple lambdas for overloading the base mutate() method.
Definition: IRMutator.h:296
LambdaMutatorGeneric(Lambdas... lambdas)
Definition: IRMutator.h:297
auto mutate_base(const T &op)
Public helper to call the base mutator from lambdas.
Definition: IRMutator.h:306
Stmt mutate(const Stmt &e) override
Definition: IRMutator.h:318
Expr mutate(const Expr &e) override
This is the main interface for using a mutator.
Definition: IRMutator.h:310
A lambda-based IR mutator that accepts multiple lambdas for different node types.
Definition: IRMutator.h:123
Expr visit(const UIntImm *op) override
Definition: IRMutator.h:150
Stmt visit(const Block *op) override
Definition: IRMutator.h:252
Expr visit(const Shuffle *op) override
Definition: IRMutator.h:276
Expr visit(const Call *op) override
Definition: IRMutator.h:264
Expr visit(const Cast *op) override
Definition: IRMutator.h:159
Expr visit(const And *op) override
Definition: IRMutator.h:204
Stmt visit(const Acquire *op) override
Definition: IRMutator.h:273
LambdaMutator(Lambdas... lambdas)
Definition: IRMutator.h:124
Expr visit(const Max *op) override
Definition: IRMutator.h:183
Stmt visit(const Atomic *op) override
Definition: IRMutator.h:285
Stmt visit(const HoistedStorage *op) override
Definition: IRMutator.h:282
Stmt visit(const Evaluate *op) override
Definition: IRMutator.h:261
Stmt visit(const Allocate *op) override
Definition: IRMutator.h:243
Expr visit(const Load *op) override
Definition: IRMutator.h:216
Stmt visit(const Provide *op) override
Definition: IRMutator.h:240
Stmt visit(const Realize *op) override
Definition: IRMutator.h:249
Expr visit(const Reinterpret *op) override
Definition: IRMutator.h:162
Expr visit(const NE *op) override
Definition: IRMutator.h:189
Expr visit(const GE *op) override
Definition: IRMutator.h:201
Expr visit(const GT *op) override
Definition: IRMutator.h:198
Expr visit(const Not *op) override
Definition: IRMutator.h:210
Expr visit(const Variable *op) override
Definition: IRMutator.h:267
Expr visit(const VectorReduce *op) override
Definition: IRMutator.h:288
Stmt visit(const Prefetch *op) override
Definition: IRMutator.h:279
Expr visit(const Sub *op) override
Definition: IRMutator.h:168
Stmt visit(const Fork *op) override
Definition: IRMutator.h:255
Stmt visit(const Store *op) override
Definition: IRMutator.h:237
Expr visit(const LE *op) override
Definition: IRMutator.h:195
Stmt visit(const LetStmt *op) override
Definition: IRMutator.h:228
Expr visit(const Ramp *op) override
Definition: IRMutator.h:219
Expr visit(const FloatImm *op) override
Definition: IRMutator.h:153
Stmt visit(const Free *op) override
Definition: IRMutator.h:246
Expr visit(const Min *op) override
Definition: IRMutator.h:180
Expr visit(const IntImm *op) override
Definition: IRMutator.h:147
Stmt visit(const IfThenElse *op) override
Definition: IRMutator.h:258
Stmt visit(const For *op) override
Definition: IRMutator.h:270
Stmt visit(const ProducerConsumer *op) override
Definition: IRMutator.h:234
Stmt visit(const AssertStmt *op) override
Definition: IRMutator.h:231
Expr visit(const EQ *op) override
Definition: IRMutator.h:186
Expr visit(const Mul *op) override
Definition: IRMutator.h:171
Expr visit(const Or *op) override
Definition: IRMutator.h:207
Expr visit(const Select *op) override
Definition: IRMutator.h:213
auto visit_base(const T *op)
Public helper to call the base visitor from lambdas.
Definition: IRMutator.h:130
Expr visit(const Mod *op) override
Definition: IRMutator.h:177
Expr visit(const Div *op) override
Definition: IRMutator.h:174
Expr visit(const Add *op) override
Definition: IRMutator.h:165
Expr visit(const Broadcast *op) override
Definition: IRMutator.h:222
Expr visit(const LT *op) override
Definition: IRMutator.h:192
Expr visit(const Let *op) override
Definition: IRMutator.h:225
Expr visit(const StringImm *op) override
Definition: IRMutator.h:156
A let expression, like you might find in a functional language.
Definition: IR.h:281
The statement form of a let node.
Definition: IR.h:292
Load a value from a named symbol if predicate is true.
Definition: IR.h:227
The greater of two values.
Definition: IR.h:122
The lesser of two values.
Definition: IR.h:113
The remainder of a / b.
Definition: IR.h:104
The product of two expressions.
Definition: IR.h:84
Is the first expression not equal to the second.
Definition: IR.h:140
Logical not - true if the expression is false.
Definition: IR.h:203
Logical or - is at least one of the expressions true.
Definition: IR.h:194
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:970
This node is a helpful annotation to do with permissions.
Definition: IR.h:325
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:364
A linear ramp vector node.
Definition: IR.h:257
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:437
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition: IR.h:57
A ternary operator.
Definition: IR.h:214
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:898
A reference-counted handle to a statement node.
Definition: Expr.h:427
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:343
String constants.
Definition: Expr.h:245
The difference of two expressions.
Definition: IR.h:75
Unsigned integer constants.
Definition: Expr.h:227
A named variable.
Definition: IR.h:813
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:1026
A single-dimensional span.
Definition: Expr.h:342