6#include "halide_thread_pool.h"
19 return Internal::Call::make(t,
"input", {arg}, Internal::Call::Extern);
22 return input(
Float(16), arg);
25 return input(
BFloat(16), arg);
28 return input(
Float(32), arg);
31 return input(
Float(64), arg);
34 return input(
Int(8), arg);
37 return input(
Int(16), arg);
40 return input(
Int(32), arg);
43 return input(
Int(64), arg);
46 return input(
UInt(8), arg);
49 return input(
UInt(16), arg);
52 return input(
UInt(32), arg);
55 return input(
UInt(64), arg);
76 static constexpr int max_i32 = 0x7fffffff;
113 bool can_run_the_code =
152 can_run_the_code =
false;
155 return can_run_the_code;
159 const std::string &op,
160 const std::string &name,
162 const std::vector<Argument> &arg_types,
163 std::ostringstream &error_msg) {
164 std::string fn_name =
"test_" + name;
168 std::map<OutputFileType, std::string> outputs = {
175 std::ifstream asm_file;
176 asm_file.open(file_name +
".s");
178 bool found_it =
false;
180 std::ostringstream msg;
181 msg << op <<
" did not generate for target=" <<
get_run_target().
to_string() <<
" vector_width=" << vector_width <<
". Instead we got:\n";
184 while (getline(asm_file, line)) {
192 error_msg <<
"Failed: " << msg.str() <<
"\n";
201 while (*p && *str && *p == *str && *p !=
'*') {
208 }
else if (*p ==
'*') {
215 }
else if (*p ==
' ') {
220 }
else if (*str ==
' ') {
246 std::ostringstream error_msg;
250 std::vector<ImageParam> image_params{
264 for (
auto &p : image_params) {
266 p.set_host_alignment(alignment_bytes);
267 const int alignment = alignment_bytes / p.type().bytes();
268 p.dim(0).set_min((p.dim(0).min() / alignment) * alignment);
271 std::vector<Argument> arg_types(image_params.begin(), image_params.end());
277 if (op->
name ==
"input") {
278 for (
auto &p : image_params) {
279 if (p.type() == op->
type) {
280 return p(mutate(op->
args[0]));
289 const std::vector<ImageParam> &image_params;
292 HookUpImageParams(
const std::vector<ImageParam> &image_params)
293 : image_params(image_params) {
295 } hook_up_image_params(image_params);
296 e = hook_up_image_params.mutate(e);
305 inline_reduction = f;
309 IRVisitor::visit(op);
315 } has_inline_reduction;
316 e.
accept(&has_inline_reduction);
328 if (has_inline_reduction.result) {
333 Func g{has_inline_reduction.inline_reduction};
340 .split(x, xo, xi, vector_width)
342 .vectorize(g.rvars()[0])
351 arg_types.push_back(rows);
354 RDom r_check(0,
W, 0, rows);
356 error() = Halide::cast<double>(
maximum(
absd(f(r_check.
x, r_check.
y), f_scalar(r_check.
x, r_check.
y))));
361 if (can_run_the_code) {
365 std::vector<Runtime::Buffer<>> inputs(image_params.size());
367 std::vector<Argument> args(image_params.size() + 1);
368 for (
size_t i = 0; i < image_params.size(); i++) {
369 args[i] = image_params[i];
380 assert(inputs.size() == 12);
381 (void)callable(inputs[0], inputs[1], inputs[2], inputs[3],
382 inputs[4], inputs[5], inputs[6], inputs[7],
383 inputs[8], inputs[9], inputs[10], inputs[11],
390 for (
size_t i = 0; i < inputs.size(); i++) {
391 if (inputs[i].size_in_bytes()) {
392 inputs[i].allocate();
394 Type t = inputs[i].type();
398 if (t ==
Float(32)) {
399 inputs[i].as<
float>().for_each_value([&](
float &f) { f = (rng() & 0xfff) / 8.0f - 0xff; });
400 }
else if (t ==
Float(64)) {
401 inputs[i].as<
double>().for_each_value([&](
double &f) { f = (rng() & 0xfff) / 8.0 - 0xff; });
402 }
else if (t ==
Float(16)) {
407 ptr != (
uint32_t *)inputs[i].data() + inputs[i].size_in_bytes() / 4;
411 *ptr = ((
uint32_t)rng()) & 0x0fffffff;
418 (void)callable(inputs[0], inputs[1], inputs[2], inputs[3],
419 inputs[4], inputs[5], inputs[6], inputs[7],
420 inputs[8], inputs[9], inputs[10], inputs[11],
423 double e = output(0);
429 error_msg <<
"The vector and scalar versions of " << name <<
" disagree. Maximum error: " << e <<
"\n";
434 std::ifstream error_file;
435 error_file.open(error_filename);
437 error_msg <<
"Error assembly: \n";
439 while (getline(error_file, line)) {
440 error_msg << line <<
"\n";
447 return {op, error_msg.str()};
452 std::string name =
"op_" + op;
453 for (
size_t i = 0; i < name.size(); i++) {
454 if (!isalnum(name[i])) name[i] =
'_';
457 name +=
"_" + std::to_string(
tasks.size());
464 tasks.emplace_back(
Task{op, name, vector_width, e});
481 const std::string run_target_str = run_target.
to_string();
485 Halide::Tools::ThreadPool<TestResult> pool(
487 Halide::Tools::ThreadPool<TestResult>::num_processors_online() :
489 std::vector<std::future<TestResult>> futures;
491 for (
size_t t = 0; t <
tasks.size(); t++) {
493 const auto &task =
tasks.at(t);
494 futures.push_back(pool.async([&]() {
495 return check_one(task.op, task.name, task.vector_width, task.expr);
499 for (
auto &f : futures) {
500 auto result = f.get();
501 constexpr int tabstop = 32;
502 const int spaces =
std::max(1, tabstop - (
int)result.op.size());
503 std::cout << result.op << std::string(spaces,
' ') <<
"(" << run_target_str <<
")\n";
504 if (!result.error_msg.empty()) {
505 std::cerr << result.error_msg;
516 template<
typename SIMDOpCheckT>
517 static int main(
int argc,
char **argv,
const std::vector<Target> &targets_to_test) {
519 std::cout <<
"host is: " << host <<
"\n";
521 const int seed = argc > 2 ?
atoi(argv[2]) : time(
nullptr);
522 std::cout <<
"simd_op_check test seed: " << seed <<
"\n";
524 for (
const auto &t : targets_to_test) {
525 if (!t.supported()) {
526 std::cout <<
"[SKIP] Unsupported target: " << t <<
"\n";
529 SIMDOpCheckT test(t);
531 if (!t.supported()) {
532 std::cout <<
"Halide was compiled without support for " << t.to_string() <<
". Skipping.\n";
537 test.filter = argv[1];
540 if (
getenv(
"HL_SIMD_OP_CHECK_FILTER")) {
541 test.filter =
getenv(
"HL_SIMD_OP_CHECK_FILTER");
553 test.output_directory = argv[2];
556 bool success = test.test_all();
566 std::cout <<
"Success!\n";
void compile_to_assembly(const std::string &filename, const std::vector< Argument > &, const std::string &fn_name, const Target &target=get_target_from_environment())
Statically compile this function to text assembly equivalent to the object file generated by compile_...
Stage update(int idx=0)
Get a handle on an update step for the purposes of scheduling it.
Func & compute_root()
Compute all of this function once ahead of time.
Callable compile_to_callable(const std::vector< Argument > &args, const Target &target=get_jit_target_from_environment())
Eagerly jit compile the function to machine code and return a callable struct that behaves like a fun...
void compile_to(const std::map< OutputFileType, std::string > &output_files, const std::vector< Argument > &args, const std::string &fn_name, const Target &target=get_target_from_environment())
Compile and generate multiple target files with single call.
Func clone_in(const Func &f)
Similar to Func::in; however, instead of replacing the call to this Func with an identity Func that r...
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Func & bound(const Var &var, Expr min, Expr extent)
Statically declare that the range over which a function should be evaluated is given by the second an...
Func & compute_at(const Func &f, const Var &var)
Compute this function as needed for each unique value of the given var for the given calling function...
An Image parameter to a halide pipeline.
const StageSchedule & schedule() const
Get the default (no-specialization) stage-specific schedule associated with this definition.
A reference-counted handle to Halide's internal representation of a function.
bool has_update_definition() const
Does this function have an update definition?
void mutate(IRMutator *mutator)
Accept a mutator to mutate all of the definitions and arguments of this function.
Definition & update(int idx=0)
Get a mutable handle to this function's update definition at index 'idx'.
A base class for passes over the IR which modify it (e.g.
virtual Expr visit(const IntImm *)
A base class for algorithms that need to recursively walk over the IR.
virtual void visit(const IntImm *)
const std::vector< ReductionVariable > & rvars() const
RVars of reduction domain associated with this schedule if there is any.
bool should_run(size_t task_index) const
A scalar parameter to a halide pipeline.
HALIDE_NO_USER_CODE_INLINE void set(const SOME_TYPE &val)
Set the current value of this parameter.
A multi-dimensional domain over which to iterate.
RVar x
Direct access to the first four dimensions of the reduction domain.
A reduction variable represents a single dimension of a reduction domain (RDom).
A templated Buffer class that wraps halide_buffer_t and adds functionality.
static Buffer< T, Dims, InClassDimStorage > make_scalar()
Make a zero-dimensional Buffer.
static constexpr int max_u8
std::string output_directory
static constexpr int max_i32
virtual bool use_multiple_threads() const
virtual void add_tests()=0
virtual int image_param_alignment()
bool wildcard_match(const std::string &p, const std::string &str) const
static constexpr int max_i8
static constexpr int max_u16
bool wildcard_search(const std::string &p, const std::string &str) const
bool wildcard_match(const char *p, const char *str) const
virtual ~SimdOpCheckTest()=default
SimdOpCheckTest(const Target t, int w, int h)
void check(std::string op, int vector_width, Expr e)
Target get_run_target() const
static int main(int argc, char **argv, const std::vector< Target > &targets_to_test)
static constexpr int max_i16
TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e)
virtual bool can_run_code() const
std::vector< Task > tasks
virtual void compile_and_check(Func error, const std::string &op, const std::string &name, int vector_width, const std::vector< Argument > &arg_types, std::ostringstream &error_msg)
A Halide variable, to be used when defining functions.
ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b)
std::map< OutputFileType, const OutputInfo > get_output_info(const Target &target)
std::string get_test_tmp_dir()
Return the path to a directory that can be safely written to when running tests; the contents directo...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Target get_host_target()
Return the target corresponding to the host machine.
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Expr maximum(Expr, const std::string &s="maximum")
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
void compile_standalone_runtime(const std::string &object_filename, const Target &t)
Create an object file containing the Halide runtime for a given target.
unsigned __INT32_TYPE__ uint32_t
char * getenv(const char *)
A fragment of Halide syntax.
@ Halide
A call to a Func.
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
static bool can_jit_target(const Target &target)
If the given target can be executed via the wasm executor, return true.
A struct representing a target machine and os to generate code for.
enum Halide::Target::Arch arch
bool has_feature(Feature f) const
int bits
The bit-width of the target machine.
enum Halide::Target::OS os
std::string to_string() const
Convert the Target into a string form that can be reconstituted by merge_string(),...
Target without_feature(Feature f) const
Return a copy of the target with the given feature cleared.
Feature
Optional features a target can have.
Target with_feature(Feature f) const
Return a copy of the target with the given feature set.
Types in the halide type system.
Expr max() const
Return an expression which is the maximum value of this type.
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...