6#include "halide_thread_pool.h"
19 return Internal::Call::make(t,
"input", {arg}, Internal::Call::Extern);
22 return input(
Float(16), arg);
25 return input(
BFloat(16), arg);
28 return input(
Float(32), arg);
31 return input(
Float(64), arg);
34 return input(
Int(8), arg);
37 return input(
Int(16), arg);
40 return input(
Int(32), arg);
43 return input(
Int(64), arg);
46 return input(
UInt(8), arg);
49 return input(
UInt(16), arg);
52 return input(
UInt(32), arg);
55 return input(
UInt(64), arg);
76 static constexpr int max_i32 = 0x7fffffff;
113 bool can_run_the_code =
155 can_run_the_code =
false;
158 return can_run_the_code;
162 const std::string &op,
163 const std::string &name,
165 const std::vector<Argument> &arg_types,
166 std::ostringstream &error_msg) {
167 std::string fn_name =
"test_" + name +
"_vecwidth" + std::to_string(vector_width);
171 std::map<OutputFileType, std::string> outputs = {
179 std::ifstream asm_file;
180 asm_file.open(file_name +
".s");
182 bool found_it =
false;
184 std::ostringstream msg;
185 msg << op <<
" did not generate for target=" <<
get_run_target().
to_string() <<
" vector_width=" << vector_width <<
". Instead we got:\n";
188 while (getline(asm_file, line)) {
196 error_msg <<
"Failed: " << msg.str() <<
"\n";
205 while (*p && *str && *p == *str && *p !=
'*') {
212 }
else if (*p ==
'*') {
219 }
else if (*p ==
' ') {
224 }
else if (*str ==
' ') {
250 std::ostringstream error_msg;
254 std::vector<ImageParam> image_params{
268 for (
auto &p : image_params) {
270 p.set_host_alignment(alignment_bytes);
271 const int alignment = alignment_bytes / p.type().bytes();
272 p.dim(0).set_min((p.dim(0).min() / alignment) * alignment);
275 std::vector<Argument> arg_types(image_params.begin(), image_params.end());
281 if (op->
name ==
"input") {
282 for (
auto &p : image_params) {
283 if (p.type() == op->
type) {
284 return p(mutate(op->
args[0]));
293 const std::vector<ImageParam> &image_params;
296 HookUpImageParams(
const std::vector<ImageParam> &image_params)
297 : image_params(image_params) {
299 } hook_up_image_params(image_params);
300 e = hook_up_image_params.mutate(e);
309 inline_reduction = f;
313 IRVisitor::visit(op);
319 } has_inline_reduction;
320 e.
accept(&has_inline_reduction);
323 Halide::Func f(name +
"_vecwidth" + std::to_string(vector_width));
335 if (has_inline_reduction.result) {
340 Func g{has_inline_reduction.inline_reduction};
347 .split(x, xo, xi, vector_width)
349 .vectorize(g.rvars()[0])
358 arg_types.push_back(rows);
361 RDom r_check(0,
W, 0, rows);
363 error() = Halide::cast<double>(
maximum(
absd(f(r_check.
x, r_check.
y), f_scalar(r_check.
x, r_check.
y))));
368 if (can_run_the_code) {
372 std::vector<Runtime::Buffer<>> inputs(image_params.size());
374 std::vector<Argument> args(image_params.size() + 1);
375 for (
size_t i = 0; i < image_params.size(); i++) {
376 args[i] = image_params[i];
387 assert(inputs.size() == 12);
388 (void)callable(inputs[0], inputs[1], inputs[2], inputs[3],
389 inputs[4], inputs[5], inputs[6], inputs[7],
390 inputs[8], inputs[9], inputs[10], inputs[11],
397 for (
size_t i = 0; i < inputs.size(); i++) {
398 if (inputs[i].size_in_bytes()) {
399 inputs[i].allocate();
401 Type t = inputs[i].type();
405 if (t ==
Float(32)) {
406 inputs[i].as<
float>().for_each_value([&](
float &f) { f = (rng() & 0xfff) / 8.0f - 0xff; });
407 }
else if (t ==
Float(64)) {
408 inputs[i].as<
double>().for_each_value([&](
double &f) { f = (rng() & 0xfff) / 8.0 - 0xff; });
409 }
else if (t ==
Float(16)) {
411 }
else if (t ==
BFloat(16)) {
418 const uint32_t mask = (t ==
Int(32)) ? 0x0fffffffU : 0xffffffffU;
420 ptr != (
uint32_t *)inputs[i].data() + inputs[i].size_in_bytes() / 4;
429 (void)callable(inputs[0], inputs[1], inputs[2], inputs[3],
430 inputs[4], inputs[5], inputs[6], inputs[7],
431 inputs[8], inputs[9], inputs[10], inputs[11],
434 double e = output(0);
440 error_msg <<
"The vector and scalar versions of " << name <<
" disagree. Maximum error: " << e <<
"\n";
445 std::ifstream error_file;
446 error_file.open(error_filename);
448 error_msg <<
"Error assembly: \n";
450 while (getline(error_file, line)) {
451 error_msg << line <<
"\n";
458 return {op, error_msg.str()};
463 std::string name =
"op_" + op;
464 for (
size_t i = 0; i < name.size(); i++) {
465 if (!isalnum(name[i])) name[i] =
'_';
468 name +=
"_" + std::to_string(
tasks.size());
475 tasks.emplace_back(
Task{op, name, vector_width, std::move(e)});
486 return Halide::Tools::ThreadPool<void>::num_processors_online();
495 const std::string run_target_str = run_target.
to_string();
500 std::vector<std::future<TestResult>> futures;
502 for (
size_t t = 0; t <
tasks.size(); t++) {
504 const auto &task =
tasks.at(t);
505 futures.push_back(pool.async([&]() {
506 return check_one(task.op, task.name, task.vector_width, task.expr);
510 for (
auto &f : futures) {
511 auto result = f.get();
512 constexpr int tabstop = 32;
513 const int spaces =
std::max(1, tabstop - (
int)result.op.size());
514 std::cout << result.op << std::string(spaces,
' ') <<
"(" << run_target_str <<
")\n";
515 if (!result.error_msg.empty()) {
516 std::cerr << result.error_msg;
527 template<
typename SIMDOpCheckT>
528 static int main(
int argc,
char **argv,
const std::vector<Target> &targets_to_test) {
530 std::cout <<
"host is: " << host <<
"\n";
532 const int seed = argc > 2 ?
atoi(argv[2]) : time(
nullptr);
533 std::cout <<
"simd_op_check test seed: " << seed <<
"\n";
535 for (
const auto &t : targets_to_test) {
536 if (!t.supported()) {
537 std::cout <<
"[SKIP] Unsupported target: " << t <<
"\n";
540 SIMDOpCheckT test(t);
542 if (!t.supported()) {
543 std::cout <<
"Halide was compiled without support for " << t.to_string() <<
". Skipping.\n";
548 test.filter = argv[1];
551 if (
getenv(
"HL_SIMD_OP_CHECK_FILTER")) {
552 test.filter =
getenv(
"HL_SIMD_OP_CHECK_FILTER");
564 test.output_directory = argv[2];
567 bool success = test.test_all();
577 std::cout <<
"Success!\n";
void compile_to_assembly(const std::string &filename, const std::vector< Argument > &, const std::string &fn_name, const Target &target=get_target_from_environment())
Statically compile this function to text assembly equivalent to the object file generated by compile_...
Stage update(int idx=0)
Get a handle on an update step for the purposes of scheduling it.
Func & compute_root()
Compute all of this function once ahead of time.
Callable compile_to_callable(const std::vector< Argument > &args, const Target &target=get_jit_target_from_environment())
Eagerly jit compile the function to machine code and return a callable struct that behaves like a fun...
void compile_to(const std::map< OutputFileType, std::string > &output_files, const std::vector< Argument > &args, const std::string &fn_name, const Target &target=get_target_from_environment())
Compile and generate multiple target files with single call.
Func clone_in(const Func &f)
Similar to Func::in; however, instead of replacing the call to this Func with an identity Func that r...
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Func & bound(const Var &var, Expr min, Expr extent)
Statically declare that the range over which a function should be evaluated is given by the second an...
Func & compute_at(const Func &f, const Var &var)
Compute this function as needed for each unique value of the given var for the given calling function...
An Image parameter to a halide pipeline.
const StageSchedule & schedule() const
Get the default (no-specialization) stage-specific schedule associated with this definition.
A reference-counted handle to Halide's internal representation of a function.
bool has_update_definition() const
Does this function have an update definition?
void mutate(IRMutator *mutator)
Accept a mutator to mutate all of the definitions and arguments of this function.
Definition & update(int idx=0)
Get a mutable handle to this function's update definition at index 'idx'.
A base class for passes over the IR which modify it (e.g.
virtual Expr visit(const IntImm *)
A base class for algorithms that need to recursively walk over the IR.
virtual void visit(const IntImm *)
const std::vector< ReductionVariable > & rvars() const
RVars of reduction domain associated with this schedule if there is any.
bool should_run(size_t task_index) const
A scalar parameter to a halide pipeline.
HALIDE_NO_USER_CODE_INLINE void set(const SOME_TYPE &val)
Set the current value of this parameter.
A multi-dimensional domain over which to iterate.
RVar x
Direct access to the first four dimensions of the reduction domain.
A reduction variable represents a single dimension of a reduction domain (RDom).
A templated Buffer class that wraps halide_buffer_t and adds functionality.
static Buffer< T, Dims, InClassDimStorage > make_scalar()
Make a zero-dimensional Buffer.
static constexpr int max_u8
std::string output_directory
static constexpr int max_i32
virtual void add_tests()=0
virtual int image_param_alignment()
bool wildcard_match(const std::string &p, const std::string &str) const
static constexpr int max_i8
static constexpr int max_u16
bool wildcard_search(const std::string &p, const std::string &str) const
bool wildcard_match(const char *p, const char *str) const
virtual ~SimdOpCheckTest()=default
SimdOpCheckTest(const Target t, int w, int h)
void check(std::string op, int vector_width, Expr e)
Target get_run_target() const
static int main(int argc, char **argv, const std::vector< Target > &targets_to_test)
static constexpr int max_i16
int num_worker_threads() const
TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e)
virtual bool can_run_code() const
std::vector< Task > tasks
virtual void compile_and_check(Func error, const std::string &op, const std::string &name, int vector_width, const std::vector< Argument > &arg_types, std::ostringstream &error_msg)
A Halide variable, to be used when defining functions.
ConstantInterval max(const ConstantInterval &a, const ConstantInterval &b)
std::string get_env_variable(char const *env_var_name)
Get value of an environment variable.
std::map< OutputFileType, const OutputInfo > get_output_info(const Target &target)
std::string get_test_tmp_dir()
Return the path to a directory that can be safely written to when running tests; the contents directo...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Target get_host_target()
Return the target corresponding to the host machine.
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Type UInt(int bits, int lanes=1)
Construct an unsigned integer type.
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Expr maximum(Expr, const std::string &s="maximum")
Type Int(int bits, int lanes=1)
Construct a signed integer type.
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
void compile_standalone_runtime(const std::string &object_filename, const Target &t)
Create an object file containing the Halide runtime for a given target.
unsigned __INT32_TYPE__ uint32_t
char * getenv(const char *)
A fragment of Halide syntax.
@ Halide
A call to a Func.
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
static bool can_jit_target(const Target &target)
If the given target can be executed via the wasm executor, return true.
A struct representing a target machine and os to generate code for.
enum Halide::Target::Arch arch
bool has_feature(Feature f) const
int bits
The bit-width of the target machine.
enum Halide::Target::OS os
std::string to_string() const
Convert the Target into a string form that can be reconstituted by merge_string(),...
Target without_feature(Feature f) const
Return a copy of the target with the given feature cleared.
Feature
Optional features a target can have.
Target with_feature(Feature f) const
Return a copy of the target with the given feature set.
Types in the halide type system.
HALIDE_ALWAYS_INLINE bool is_int_or_uint() const
Is this type an integer type of any sort?
Expr max() const
Return an expression which is the maximum value of this type.
Class that provides a type that implements half precision floating point using the bfloat16 format.
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...