Added new compiler defines to Metal for hiding Metal-specific elements in kernels.

PiperOrigin-RevId: 351654818
Change-Id: If1222b61738677311c7739d3a75a9a1106fa3eef
This commit is contained in:
Raman Sarokin 2021-01-13 13:34:43 -08:00 committed by TensorFlower Gardener
parent c4aab50762
commit 1486708ff6
3 changed files with 22 additions and 12 deletions

View File

@ -94,6 +94,13 @@ absl::Status ComputeTask::Compile(CalculationsPrecision precision,
@"TO_ACCUM3_TYPE" : toAccumulatorType3,
@"TO_ACCUM4_TYPE" : toAccumulatorType4,
@"SIMDGROUP_BARRIER" : barrier,
@"MAIN_FUNCTION" : @"\"kernel void ComputeFunction\"",
@"GLOBAL_ID_0" : @"static_cast<int>(reserved_gid.x)",
@"GLOBAL_ID_1" : @"static_cast<int>(reserved_gid.y)",
@"GLOBAL_ID_2" : @"static_cast<int>(reserved_gid.z)",
@"INIT_FLT(value)" : [NSString stringWithFormat:@"%@(value)", storageType],
@"INIT_FLT4(value)" :
[NSString stringWithFormat:@"%@4(value)", storageType],
};
NSString* code =

View File

@ -33,11 +33,10 @@ namespace metal {
namespace {
std::string GetElementWiseCode(const OperationDef& op_def) {
return R"(
kernel void ComputeFunction($0
uint3 gid[[thread_position_in_grid]]) {
int X = static_cast<int>(gid.x);
int Y = static_cast<int>(gid.y);
int Z = static_cast<int>(gid.z);
MAIN_FUNCTION($0) {
int X = GLOBAL_ID_0;
int Y = GLOBAL_ID_1;
int Z = GLOBAL_ID_2;
if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) {
return;
}

View File

@ -155,13 +155,20 @@ absl::Status MetalArguments::Init(
std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
RETURN_IF_ERROR(SetObjectsResources(*args));
ResolveArgsPass(code);
*code = R"(
std::string header = R"(
#include <metal_stdlib>
using namespace metal;
)" + struct_desc +
"\n" + *code;
*code = absl::Substitute(*code, GetListOfArgs(/*buffer_offset*/ 0));
)";
header += struct_desc + "\n";
*code = header + *code;
std::string arguments = GetListOfArgs(/*buffer_offset*/ 0);
if (code->find("GLOBAL_ID_") != std::string::npos) {
AppendArgument("uint3 reserved_gid[[thread_position_in_grid]]", &arguments);
} else if (!arguments.empty()) {
arguments += ",\n";
}
*code = absl::Substitute(*code, arguments);
return absl::OkStatus();
}
@ -380,9 +387,6 @@ std::string MetalArguments::GetListOfArgs(int buffer_offset) {
&result);
buffer_offset++;
}
if (!result.empty()) {
result += ",\n";
}
return result;
}