From c62635d6633db1d1e633ecb4bb0daf352b775689 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Tue, 12 Jan 2021 10:25:04 -0800 Subject: [PATCH] Using common names(as in common/tasks) in Metal elementwise ops. PiperOrigin-RevId: 351397942 Change-Id: Id292f0512eeba335c34b8740f794f33b4b562ae9 --- .../lite/delegates/gpu/metal/kernels/add.cc | 4 +-- .../gpu/metal/kernels/elementwise.cc | 25 +++++++++++-------- .../lite/delegates/gpu/metal/kernels/prelu.cc | 8 +++--- .../metal/kernels/quantize_and_dequantize.cc | 6 ++--- .../lite/delegates/gpu/metal/kernels/relu.cc | 9 ++++--- .../delegates/gpu/metal/metal_arguments.cc | 8 +++--- 6 files changed, 32 insertions(+), 28 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add.cc b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc index e6e36eb362b..2664cbca430 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc @@ -39,8 +39,8 @@ ComputeTaskDescriptor Add(const OperationDef& definition) { for (int i = 1; i < definition.src_tensors.size(); ++i) { const std::string tensor_name = "src_tensor_" + std::to_string(i); desc.AddSrcTensor(tensor_name, definition.src_tensors[i]); - desc.shader_source += - " value += args." + tensor_name + ".Read(gid.x, gid.y, gid.z);\n"; + desc.shader_source += " in_out_value += args." + tensor_name + + ".Read(X_COORD, Y_COORD, S_COORD);\n"; } return desc; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc index 1e6076e7489..cd385fce6d2 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc @@ -88,9 +88,9 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition, OperationType op_type) { ComputeTaskDescriptor desc(definition); desc.is_linkable = true; - const std::string x_coord = second_shape.w == 1 ? "0" : "gid.x"; - const std::string y_coord = second_shape.h == 1 ? "0" : "gid.y"; - const std::string s_coord = second_shape.c == 1 ? "0" : "gid.z"; + const std::string x_coord = second_shape.w == 1 ? "0" : "X_COORD"; + const std::string y_coord = second_shape.h == 1 ? "0" : "Y_COORD"; + const std::string s_coord = second_shape.c == 1 ? "0" : "S_COORD"; std::string code; code = " FLT4 src_1 = args.second_tensor.Read(" + x_coord + ", " + y_coord + ", " + s_coord + ");\n"; @@ -99,7 +99,9 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition, code += " src_1.z = src_1.x;\n"; code += " src_1.w = src_1.x;\n"; } - code += " value = " + TwoInputFunctor(op_type, "value", "src_1") + ";\n"; + code += + " in_out_value = " + TwoInputFunctor(op_type, "in_out_value", "src_1") + + ";\n"; desc.shader_source = code; @@ -112,7 +114,7 @@ ComputeTaskDescriptor ElementwiseWithOneInput(const OperationDef& definition, ComputeTaskDescriptor desc(definition); desc.is_linkable = true; desc.shader_source = - " value = " + OneInputFunctor(op_type, "value") + ";\n"; + " in_out_value = " + OneInputFunctor(op_type, "in_out_value") + ";\n"; return desc; } @@ -138,7 +140,7 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent( linear_desc.size = linear_desc.data.size(); desc.args.AddObject( "linear", absl::make_unique(std::move(linear_desc))); - desc.shader_source += " FLT4 second_arg = args.linear.Read(gid.z);\n"; + desc.shader_source += " FLT4 second_arg = args.linear.Read(S_COORD);\n"; } else if (hwc_buf) { TensorDescriptor hwc_desc{definition.GetDataType(), TensorStorageType::BUFFER, Layout::HWC}; @@ -146,9 +148,9 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent( desc.args.AddObject( "hwc", absl::make_unique(std::move(hwc_desc))); - const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "gid.x"; - const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "gid.y"; - const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "gid.z"; + const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "X_COORD"; + const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "Y_COORD"; + const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "S_COORD"; desc.shader_source += " FLT4 second_arg = args.hwc.Read(" + x_coord + ", " + y_coord + ", " + s_coord + ");\n"; if (hwc_buf->shape.c == 1) { @@ -157,8 +159,9 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent( desc.shader_source += " second_arg.w = second_arg.x;\n"; } } - desc.shader_source += - " value = " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n"; + desc.shader_source += " in_out_value = " + + TwoInputFunctor(op_type, "in_out_value", "second_arg") + + ";\n"; return desc; } diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/metal/kernels/prelu.cc index 119acda3b61..bd15c0390ec 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/prelu.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/prelu.cc @@ -46,12 +46,12 @@ ComputeTaskDescriptor PReLU(const OperationDef& definition, desc.args.AddFloat("clip", attr.clip); desc.shader_source = R"( - value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value)); + in_out_value = FLT4(clamp(in_out_value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(S_COORD) * min(FLT4(0.0f), in_out_value)); )"; } else { desc.shader_source = R"( - value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value)); + in_out_value = FLT4(max(FLT4(0.0f), in_out_value) + args.alpha.Read(S_COORD) * min(FLT4(0.0f), in_out_value)); )"; } auto data_type = DeduceDataTypeFromPrecision(definition.precision); @@ -79,12 +79,12 @@ ComputeTaskDescriptor PReLUFull(const OperationDef& definition, desc.args.AddFloat("clip", attr.clip); desc.shader_source = R"( - value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value)); + in_out_value = FLT4(clamp(in_out_value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(X_COORD, Y_COORD, S_COORD) * min(FLT4(0.0f), in_out_value)); )"; } else { desc.shader_source = R"( - value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value)); + in_out_value = FLT4(max(FLT4(0.0f), in_out_value) + args.alpha.Read(X_COORD, Y_COORD, S_COORD) * min(FLT4(0.0f), in_out_value)); )"; } TensorDescriptor alpha_desc{definition.GetDataType(), diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize.cc b/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize.cc index b07e9d3c56b..dd16f96edfc 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/quantize_and_dequantize.cc @@ -29,9 +29,9 @@ ComputeTaskDescriptor QuantizeAndDequantize( ComputeTaskDescriptor desc(definition); desc.is_linkable = true; desc.shader_source = R"( - value = clamp(value, FLT4(args.qmin), FLT4(args.qmax)); - value = (value - FLT4(args.qmin)) / FLT4(args.qscale); - value = round(value) * FLT4(args.qscale) + FLT4(args.qmin);)"; + in_out_value = clamp(in_out_value, FLT4(args.qmin), FLT4(args.qmax)); + in_out_value = (in_out_value - FLT4(args.qmin)) / FLT4(args.qscale); + in_out_value = round(in_out_value) * FLT4(args.qscale) + FLT4(args.qmin);)"; desc.args.AddFloat("qmax", attr.max); desc.args.AddFloat("qmin", attr.min); desc.args.AddFloat("qscale", attr.scale); diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/relu.cc b/tensorflow/lite/delegates/gpu/metal/kernels/relu.cc index 5da72f5a984..5dafea4b866 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/relu.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/relu.cc @@ -35,12 +35,13 @@ ComputeTaskDescriptor ReLU(const OperationDef& definition, ComputeTaskDescriptor desc(definition); desc.is_linkable = true; const std::string min_func = - attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * args.alpha, 0.0f)"; + attr.alpha == 0 ? "FLT4(0.0f)" : "min(in_out_value * args.alpha, 0.0f)"; if (attr.clip != 0.0) { - desc.shader_source = - "value = FLT4(clamp(value, " + min_func + ", FLT4(args.clip)));"; + desc.shader_source = "in_out_value = FLT4(clamp(in_out_value, " + min_func + + ", FLT4(args.clip)));"; } else { - desc.shader_source = "value = FLT4(max(value, " + min_func + "));"; + desc.shader_source = + "in_out_value = FLT4(max(in_out_value, " + min_func + "));"; } desc.args.AddFloat("alpha", attr.alpha); desc.args.AddFloat("clip", attr.clip); diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc index 05c226064bf..5afac90a569 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc +++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc @@ -506,10 +506,10 @@ absl::Status MetalArguments::ResolveSelector( // x_coord can have batch size property of link_object ResolveObjectNames(object_name, names, &x_coord); *result = it->second; - ReplaceAllWords("value", value_name, result); - ReplaceAllWords("gid.x", x_coord, result); - ReplaceAllWords("gid.y", y_coord, result); - ReplaceAllWords("gid.z", s_coord, result); + ReplaceAllWords("in_out_value", value_name, result); + ReplaceAllWords("X_COORD", x_coord, result); + ReplaceAllWords("Y_COORD", y_coord, result); + ReplaceAllWords("S_COORD", s_coord, result); RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, result)); if (selector == "Linking") { return absl::OkStatus();