From 1486708ff67ff98e2bfe49ccddfdd65e284e70f7 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Wed, 13 Jan 2021 13:34:43 -0800 Subject: [PATCH] Added new compiler defines to Metal for hiding metal specific elements in kernels. PiperOrigin-RevId: 351654818 Change-Id: If1222b61738677311c7739d3a75a9a1106fa3eef --- .../lite/delegates/gpu/metal/compute_task.cc | 7 +++++++ .../gpu/metal/compute_task_descriptor.cc | 9 ++++----- .../delegates/gpu/metal/metal_arguments.cc | 18 +++++++++++------- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc index 15336108a9b..59aa7f96f46 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.cc +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc @@ -94,6 +94,13 @@ absl::Status ComputeTask::Compile(CalculationsPrecision precision, @"TO_ACCUM3_TYPE" : toAccumulatorType3, @"TO_ACCUM4_TYPE" : toAccumulatorType4, @"SIMDGROUP_BARRIER" : barrier, + @"MAIN_FUNCTION" : @"\"kernel void ComputeFunction\"", + @"GLOBAL_ID_0" : @"static_cast(reserved_gid.x)", + @"GLOBAL_ID_1" : @"static_cast(reserved_gid.y)", + @"GLOBAL_ID_2" : @"static_cast(reserved_gid.z)", + @"INIT_FLT(value)" : [NSString stringWithFormat:@"%@(value)", storageType], + @"INIT_FLT4(value)" : + [NSString stringWithFormat:@"%@4(value)", storageType], }; NSString* code = diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc index 16f8478f066..c7c98894a33 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc @@ -33,11 +33,10 @@ namespace metal { namespace { std::string GetElementWiseCode(const OperationDef& op_def) { return R"( -kernel void ComputeFunction($0 - uint3 gid[[thread_position_in_grid]]) { - int X = static_cast(gid.x); - int Y = static_cast(gid.y); - int Z = static_cast(gid.z); +MAIN_FUNCTION($0) { + int X = GLOBAL_ID_0; + int Y = GLOBAL_ID_1; + int Z = GLOBAL_ID_2; if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) { return; } diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc index 1c6afb36f52..68d73f70c02 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc +++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc @@ -155,13 +155,20 @@ absl::Status MetalArguments::Init( std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code); RETURN_IF_ERROR(SetObjectsResources(*args)); ResolveArgsPass(code); - *code = R"( + std::string header = R"( #include using namespace metal; -)" + struct_desc + - "\n" + *code; - *code = absl::Substitute(*code, GetListOfArgs(/*buffer_offset*/ 0)); +)"; + header += struct_desc + "\n"; + *code = header + *code; + std::string arguments = GetListOfArgs(/*buffer_offset*/ 0); + if (code->find("GLOBAL_ID_") != std::string::npos) { + AppendArgument("uint3 reserved_gid[[thread_position_in_grid]]", &arguments); + } else if (!arguments.empty()) { + arguments += ",\n"; + } + *code = absl::Substitute(*code, arguments); return absl::OkStatus(); } @@ -380,9 +387,6 @@ std::string MetalArguments::GetListOfArgs(int buffer_offset) { &result); buffer_offset++; } - if (!result.empty()) { - result += ",\n"; - } return result; }