Added new compiler defines to Metal for hiding Metal-specific elements in kernels.
PiperOrigin-RevId: 351654818 Change-Id: If1222b61738677311c7739d3a75a9a1106fa3eef
This commit is contained in:
parent
c4aab50762
commit
1486708ff6
@ -94,6 +94,13 @@ absl::Status ComputeTask::Compile(CalculationsPrecision precision,
|
||||
@"TO_ACCUM3_TYPE" : toAccumulatorType3,
|
||||
@"TO_ACCUM4_TYPE" : toAccumulatorType4,
|
||||
@"SIMDGROUP_BARRIER" : barrier,
|
||||
@"MAIN_FUNCTION" : @"\"kernel void ComputeFunction\"",
|
||||
@"GLOBAL_ID_0" : @"static_cast<int>(reserved_gid.x)",
|
||||
@"GLOBAL_ID_1" : @"static_cast<int>(reserved_gid.y)",
|
||||
@"GLOBAL_ID_2" : @"static_cast<int>(reserved_gid.z)",
|
||||
@"INIT_FLT(value)" : [NSString stringWithFormat:@"%@(value)", storageType],
|
||||
@"INIT_FLT4(value)" :
|
||||
[NSString stringWithFormat:@"%@4(value)", storageType],
|
||||
};
|
||||
|
||||
NSString* code =
|
||||
|
@ -33,11 +33,10 @@ namespace metal {
|
||||
namespace {
|
||||
std::string GetElementWiseCode(const OperationDef& op_def) {
|
||||
return R"(
|
||||
kernel void ComputeFunction($0
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
int X = static_cast<int>(gid.x);
|
||||
int Y = static_cast<int>(gid.y);
|
||||
int Z = static_cast<int>(gid.z);
|
||||
MAIN_FUNCTION($0) {
|
||||
int X = GLOBAL_ID_0;
|
||||
int Y = GLOBAL_ID_1;
|
||||
int Z = GLOBAL_ID_2;
|
||||
if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) {
|
||||
return;
|
||||
}
|
||||
|
@ -155,13 +155,20 @@ absl::Status MetalArguments::Init(
|
||||
std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
|
||||
RETURN_IF_ERROR(SetObjectsResources(*args));
|
||||
ResolveArgsPass(code);
|
||||
*code = R"(
|
||||
std::string header = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
)" + struct_desc +
|
||||
"\n" + *code;
|
||||
*code = absl::Substitute(*code, GetListOfArgs(/*buffer_offset*/ 0));
|
||||
)";
|
||||
header += struct_desc + "\n";
|
||||
*code = header + *code;
|
||||
std::string arguments = GetListOfArgs(/*buffer_offset*/ 0);
|
||||
if (code->find("GLOBAL_ID_") != std::string::npos) {
|
||||
AppendArgument("uint3 reserved_gid[[thread_position_in_grid]]", &arguments);
|
||||
} else if (!arguments.empty()) {
|
||||
arguments += ",\n";
|
||||
}
|
||||
*code = absl::Substitute(*code, arguments);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -380,9 +387,6 @@ std::string MetalArguments::GetListOfArgs(int buffer_offset) {
|
||||
&result);
|
||||
buffer_offset++;
|
||||
}
|
||||
if (!result.empty()) {
|
||||
result += ",\n";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user