From ec5c9bebf8526ea96fba9c3d1459594ab6727ab7 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Mon, 23 Mar 2020 16:35:47 -0700 Subject: [PATCH] Reorganized structure of Elementwise operations. Support of linking for two input elementwise. Added broadcast parameters. Removed ApplyMask(mul + broadcast). PiperOrigin-RevId: 302545362 Change-Id: Icb9cb94aaad448a205dc7160f4b44820081d69ca --- tensorflow/lite/delegates/gpu/metal/api.cc | 65 ++++- .../delegates/gpu/metal/compiled_model.cc | 14 +- .../gpu/metal/compute_task_descriptor.h | 4 + .../lite/delegates/gpu/metal/kernels/add.cc | 1 + .../gpu/metal/kernels/elementwise.cc | 236 +++++++++--------- .../delegates/gpu/metal/kernels/elementwise.h | 15 +- .../gpu/metal/kernels/elementwise_test.mm | 32 ++- .../lite/delegates/gpu/metal/kernels/mul.cc | 96 ------- .../lite/delegates/gpu/metal/kernels/mul.h | 5 - .../delegates/gpu/metal/kernels/mul_test.mm | 51 ---- 10 files changed, 240 insertions(+), 279 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index b2887e523a5..dedb2aa8df1 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -51,6 +51,25 @@ namespace tflite { namespace gpu { namespace metal { namespace { +bool IsWidthBroadcastedForSecondInput( + const std::vector>*>& inputs) { + return inputs.size() == 2 && + inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w && + inputs[1]->tensor.shape.w == 1; +} +bool IsHeightBroadcastedForSecondInput( + const std::vector>*>& inputs) { + return inputs.size() == 2 && + inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h && + inputs[1]->tensor.shape.h == 1; +} +bool IsChannelsBroadcastedForSecondInput( + const std::vector>*>& inputs) { + return inputs.size() == 2 && + inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c && + inputs[1]->tensor.shape.c == 1; +} + std::vector SelectDepthWiseConv( int id, ValueId input_id, ValueId output_id, const DepthwiseConvolution2DAttributes& attr, @@ -134,11 +153,22 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, int node_id = static_cast(node->id); auto op_type = OperationTypeFromString(node->operation.type); switch (op_type) { - case OperationType::ADD: - *tasks = Add(node_id, inputs, outputs[0], - absl::any_cast(node->operation.attributes), - options); + case OperationType::ADD: { + const auto srcs = graph.FindInputs(node_id); + ElementwiseBroadcastSettings broadcast; + broadcast.width = IsWidthBroadcastedForSecondInput(srcs); + broadcast.height = IsHeightBroadcastedForSecondInput(srcs); + broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs); + if (broadcast.width || broadcast.height || broadcast.channels) { + *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, + broadcast); + } else { + *tasks = Add(node_id, inputs, outputs[0], + absl::any_cast(node->operation.attributes), + options); + } break; + } case OperationType::CONCAT: { std::vector input_shapes; for (auto& input : graph.FindInputs(node->id)) { @@ -194,7 +224,18 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, absl::any_cast(node->operation.attributes), options); } else { - *tasks = ApplyMask(node_id, inputs[0], inputs[1], outputs[0], options); + if (inputs.size() == 2) { + const auto srcs = graph.FindInputs(node_id); + ElementwiseBroadcastSettings broadcast; + broadcast.width = IsWidthBroadcastedForSecondInput(srcs); + broadcast.height = IsHeightBroadcastedForSecondInput(srcs); + broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs); + *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], + op_type, broadcast); + } else { + return absl::UnimplementedError( + "No support of multiply with more than 2 inputs"); + } } break; case OperationType::PAD: { @@ -269,8 +310,18 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, case OperationType::SUB: { const ElementwiseAttributes* attr = absl::any_cast(&node->operation.attributes); - *tasks = - ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, attr); + if (attr) { + *tasks = ElementwiseWithOneInputAndConstantArguent( + node_id, inputs[0], outputs[0], options, op_type, *attr); + } else { + const auto srcs = graph.FindInputs(node_id); + ElementwiseBroadcastSettings broadcast; + broadcast.width = IsWidthBroadcastedForSecondInput(srcs); + broadcast.height = IsHeightBroadcastedForSecondInput(srcs); + broadcast.channels = IsChannelsBroadcastedForSecondInput(srcs); + *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type, + broadcast); + } } break; case OperationType::BATCH_NORMALIZATION: case OperationType::BATCH_TO_SPACE: diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc index 711ed9fed88..06cc10a0520 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc @@ -180,10 +180,16 @@ void BuildFusableChains(const std::vector& input_ids, bool fused = false; for (auto& chain : *chains) { // We can fuse only single output for now. - if (Contains(task_descriptor->input_buffers, - chain.back()->output_buffer.id) && - CanFuseOperations(chain.back(), task_descriptor, output_ids, - *descriptors, chains)) { + bool can_link = false; + if (task_descriptor->is_associative_op) { + can_link = Contains(task_descriptor->input_buffers, + chain.back()->output_buffer.id); + } else { + can_link = task_descriptor->input_buffers[0].id == + chain.back()->output_buffer.id; + } + if (can_link && CanFuseOperations(chain.back(), task_descriptor, + output_ids, *descriptors, chains)) { chain.push_back(task_descriptor); fused = true; break; diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h index 35bad273c50..923f4dcc245 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h @@ -99,6 +99,10 @@ struct ComputeTaskDescriptor { // $2 // output_buffer[linear_index] = value; // } + + // when operation associative, we can rearrange input tensors + // for example add is associative + bool is_associative_op = false; std::string shader_source; std::vector input_buffers; // A single per-operation output is supported now. diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add.cc b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc index c857a092a53..b4a8e781c72 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc @@ -86,6 +86,7 @@ std::vector Add(int id, } desc->is_linkable = true; + desc->is_associative_op = true; desc->shader_source = GetAddTableCodeFused(input_ids.size() - 1); for (int i = 0; i < input_ids.size(); ++i) { diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc index 7fdfd3257ea..9d9e054f40a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" @@ -29,115 +30,93 @@ namespace metal { namespace { -std::string GetElementwiseWithTwoInputsCode(int src_count, - OperationType op_type, - const float* scalar) { - std::string code = R"( - #include - using namespace metal; +std::string OneInputFunctor(OperationType op_type, const std::string& value) { + const std::unordered_map functors{ + {OperationType::ABS, "abs($0)"}, + {OperationType::SIN, "sin($0)"}, + {OperationType::HARD_SWISH, + "$0 * clamp($0 / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f))"}, + {OperationType::COS, "cos($0)"}, + {OperationType::EXP, "exp($0)"}, + {OperationType::LOG, "log($0)"}, + {OperationType::SQRT, "sqrt($0)"}, + {OperationType::RSQRT, "1.0 / sqrt($0)"}, + {OperationType::SQUARE, "$0 * $0"}, + {OperationType::SIGMOID, "1.0 / (1.0 + exp(-1.0 * $0))"}, + {OperationType::TANH, "tanh($0)"}, + }; - struct uniforms { - int4 src_size; - }; - - $0 - kernel void ComputeFunction( - $1 - uint3 gid[[thread_position_in_grid]]) { - if (static_cast(gid.x) >= params.src_size.x || - static_cast(gid.y) >= params.src_size.y) { - return; - } - - int linear_index = (int(gid.z) * params.src_size.y + int(gid.y)) * - params.src_size.x + int(gid.x); - FLT4 src_0 = src_buffer0[linear_index]; - )"; - - if (scalar == nullptr) { - code += " FLT4 src_1 = src_buffer1[linear_index];"; - } else { - code += " FLT4 src_1 = FLT4(" + std::to_string(*scalar) + ");"; + if (functors.find(op_type) == functors.end()) { + return "Error, unknown op"; } - switch (op_type) { - case OperationType::DIV: { - code += " FLT4 value = src_0 / src_1;"; - break; - } - case OperationType::MAXIMUM: { - code += " FLT4 value = max(src_0, src_1);"; - break; - } - case OperationType::MINIMUM: { - code += " FLT4 value = min(src_0, src_1);"; - break; - } - case OperationType::POW: { - code += " FLT4 value = pow(src_0, src_1);"; - break; - } - case OperationType::SQUARED_DIFF: { - code += " FLT4 value = (src_0 - src_1) * (src_0 - src_1);"; - break; - } - case OperationType::SUB: { - code += " FLT4 value = src_0 - src_1;"; - break; - } - default: { - return ""; - } - } - code += R"( - $2 - dst_buffer[linear_index] = value; - })"; - return code; + + return absl::Substitute(functors.at(op_type), value); } + +std::string TwoInputFunctor(OperationType op_type, const std::string& value0, + const std::string& value1) { + const std::unordered_map functors{ + {OperationType::ADD, "$0 + $1"}, + {OperationType::DIV, "$0 / $1"}, + {OperationType::MAXIMUM, "max($0, $1)"}, + {OperationType::MINIMUM, "min($0, $1)"}, + {OperationType::MUL, "$0 * $1"}, + {OperationType::POW, "pow($0, $1)"}, + {OperationType::SQUARED_DIFF, "($0 - $1) * ($0 - $1)"}, + {OperationType::SUB, "$0 - $1"}, + }; + + if (functors.find(op_type) == functors.end()) { + return "Error, unknown op"; + } + + return absl::Substitute(functors.at(op_type), value0, value1); +} + } // namespace std::vector ElementwiseWithTwoInputs( int id, std::vector input_ids, ValueId output_id, - OperationType op_type, const ElementwiseAttributes* attr) { - const float* scalar = nullptr; - if (attr) { - scalar = absl::get_if(&attr->param); - } + OperationType op_type, const ElementwiseBroadcastSettings& settings) { auto desc = std::make_shared(); desc->id = id; - desc->is_linkable = false; - desc->shader_source = - GetElementwiseWithTwoInputsCode(input_ids.size(), op_type, scalar); - - for (int i = 0; i < input_ids.size(); ++i) { - const std::string buffer_name = - "device FLT4* const src_buffer" + std::to_string(i); - desc->input_buffers.push_back({input_ids[i], buffer_name}); + desc->is_linkable = true; + const std::string x_coord = settings.width ? "0" : "int(gid.x)"; + const std::string y_coord = settings.height ? "0" : "int(gid.y)"; + const std::string s_coord = settings.channels ? "0" : "int(gid.z)"; + std::string code = + "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* " + "const second_tensor, int2 second_size) {\n"; + code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord + + ") * second_size.x + " + x_coord + ";\n"; + code += " FLT4 src_1 = second_tensor[second_index];\n"; + if (settings.channels) { + code += " src_1.y = src_1.x;\n"; + code += " src_1.z = src_1.x;\n"; + code += " src_1.w = src_1.x;\n"; } + code += " return " + TwoInputFunctor(op_type, "value", "src_1") + ";\n"; + code += "}\n"; - desc->output_buffer = {output_id, "device FLT4* dst_buffer", - [input_ids](const std::map& buffers) { - return buffers.find(input_ids[0])->second; - }}; + desc->shader_source = code; + + desc->input_buffers = { + {input_ids[0], "device FLT4* const"}, + {input_ids[1], "device FLT4* const"}, + }; + desc->output_buffer = {output_id}; desc->uniform_buffers = { - {"constant uniforms& params", - [input_ids](const std::map& buffers) { - const auto& dimension = buffers.find(input_ids[0])->second; - std::vector uniform_params = {dimension.w, dimension.h, 0, 0}; + {"constant int2&", + [input_ids, output_id](const std::map& buffers) { + const auto& input_dim_1 = buffers.find(input_ids[1])->second; + std::vector uniform_params{ + input_dim_1.w, + input_dim_1.h, + }; return GetByteBuffer(uniform_params); }}, }; - - desc->resize_function = [input_ids](const std::map& buffers) { - const auto& src_dim = buffers.find(input_ids[0])->second; - const uint3 groups_size{16, 16, 1}; - int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x); - int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y); - const int dst_layers = IntegralDivideRoundUp(src_dim.c, 4); - int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z); - return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); - }; return {desc}; } @@ -146,29 +125,10 @@ std::vector ElementwiseWithOneInput( auto desc = std::make_shared(); desc->id = id; desc->is_linkable = true; - - const std::unordered_map functors{ - {OperationType::ABS, "abs(value)"}, - {OperationType::SIN, "sin(value)"}, - {OperationType::HARD_SWISH, - "value * clamp(value / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f))"}, - {OperationType::COS, "cos(value)"}, - {OperationType::EXP, "exp(value)"}, - {OperationType::LOG, "log(value)"}, - {OperationType::SQRT, "sqrt(value)"}, - {OperationType::RSQRT, "1.0 / sqrt(value)"}, - {OperationType::SQUARE, "value * value"}, - {OperationType::SIGMOID, "1.0 / (1.0 + exp(-1.0 * value))"}, - {OperationType::TANH, "tanh(value)"}, - }; - - if (functors.count(op_type) == 0) { - return {}; - } - desc->shader_source = "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {\n"; - desc->shader_source += " return " + functors.at(op_type) + ";\n"; + desc->shader_source += + " return " + OneInputFunctor(op_type, "value") + ";\n"; desc->shader_source += " }"; desc->input_buffers = {{input_id}}; @@ -176,6 +136,54 @@ std::vector ElementwiseWithOneInput( return {desc}; } +std::vector ElementwiseWithOneInputAndConstantArguent( + int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, + OperationType op_type, const ElementwiseAttributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = true; + auto scalar = absl::get_if(&attr.param); + auto linear_buf = + absl::get_if>(&attr.param); + std::string param_desc; + if (scalar) { + param_desc += ", float scalar_val"; + } + if (linear_buf) { + param_desc += ", device FLT4* const linear_buf"; + } + desc->shader_source = + "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc + + ") {\n"; + if (scalar) { + desc->shader_source += " FLT4 second_arg = FLT4(scalar_val);\n"; + } else if (linear_buf) { + desc->shader_source += " FLT4 second_arg = linear_buf[gid.z];\n"; + } + desc->shader_source += + " return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n"; + desc->shader_source += " }"; + + desc->input_buffers = {{input_id}}; + desc->output_buffer = {output_id}; + if (scalar) { + std::vector scalar_bits = + GetByteBuffer(std::vector{*scalar}); + desc->uniform_buffers = { + {"constant float&", + [scalar_bits](const std::map& buffers) { + return scalar_bits; + }}, + }; + } else if (linear_buf) { + desc->immutable_buffers = { + {"device FLT4* const", + GetByteBufferConverted(linear_buf->data, options.storage_precision)}, + }; + } + return {desc}; +} + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h index af70e433e79..2520c2f2df4 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h @@ -25,13 +25,26 @@ namespace tflite { namespace gpu { namespace metal { +struct ElementwiseBroadcastSettings { + bool width = false; + bool height = false; + bool channels = false; +}; + +// Two inputs are two runtime tensors std::vector ElementwiseWithTwoInputs( int id, std::vector input_ids, ValueId output_id, - OperationType op_type, const ElementwiseAttributes* attr); + OperationType op_type, const ElementwiseBroadcastSettings& settings); +// One input is one runtime tensor std::vector ElementwiseWithOneInput( int id, ValueId input_id, ValueId output_id, OperationType op_type); +// First input is one runtime tensor and second input is constant argument +std::vector ElementwiseWithOneInputAndConstantArguent( + int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options, + OperationType op_type, const ElementwiseAttributes& attr); + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm index d8521ba76b1..6b30bc5c703 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm @@ -94,7 +94,7 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { - (void)testExp { OperationType op_type = OperationType::EXP; - const BHWC shape(1, 1, 1, 5); + const BHWC shape(1, 1, 1, 7); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, /*inputs=*/{GetTensorRef(0, shape)}, /*outputs=*/{GetTensorRef(1, shape)}); @@ -312,4 +312,34 @@ TensorRef GetTensorRef(int ref, const BHWC& shape) { XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } +- (void)testMulBroadcastChannels { + OperationType op_type = OperationType::MUL; + const BHWC shape(1, 1, 2, 2); + const BHWC shape_2(1, 1, 2, 1); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape_2)}, + /*outputs=*/{GetTensorRef(2, shape)}); + XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); + XCTAssertTrue(model.PopulateTensor(1, {2.0, 3.0})); + auto status = model.Invoke(); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + status = CompareVectors({2.0, 4.0, 9.0, 12.0}, model.GetOutput(0), 1e-6f); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + +- (void)testMulBroadcastWidthAndHeight { + OperationType op_type = OperationType::MUL; + const BHWC shape(1, 1, 2, 2); + const BHWC shape_2(1, 1, 1, 2); + SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}}, + /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape_2)}, + /*outputs=*/{GetTensorRef(2, shape)}); + XCTAssertTrue(model.PopulateTensor(0, {1.0, 2.0, 3.0, 4.0})); + XCTAssertTrue(model.PopulateTensor(1, {2.0, 3.0})); + auto status = model.Invoke(); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); + status = CompareVectors({2.0, 6.0, 6.0, 12.0}, model.GetOutput(0), 1e-6f); + XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); +} + @end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc index 21a04f2fc35..e90ab6b4f12 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc @@ -35,102 +35,6 @@ limitations under the License. namespace tflite { namespace gpu { namespace metal { -namespace { - -std::string GetApplyMaskCode() { - std::string shader_source = R"( - #include - using namespace metal; - struct uniforms { - int4 src_0_size; - int4 src_1_size; - int4 dst_size; - }; - - $0 - kernel void ComputeFunction( - $1 - uint3 gid[[thread_position_in_grid]]) { - int X = static_cast(gid.x); - int Y = static_cast(gid.y); - if (X >= params.dst_size.x || Y >= params.dst_size.y) { - return; - } - int src_0_index = (gid.z * params.src_0_size.y + static_cast(gid.y)) * - params.src_0_size.x + static_cast(gid.x); - int src_1_index = 0; - if (params.dst_size.z == 1) { - // [H, W, C] x [H, W, 0][0] - src_1_index = static_cast(gid.y) * params.src_1_size.x + - static_cast(gid.x); - } else if (params.src_0_size.y == params.src_1_size.y && - params.src_0_size.x == params.src_1_size.x) { - // [H, W, C] x [H, W, C] - src_1_index = src_0_index; - } else { - // [H, W, C] x [0, 0, C] - src_1_index = gid.z * params.src_1_size.y * params.src_1_size.x ; - } - FLT4 value = src_buffer_0[src_0_index] * src_buffer_1[src_1_index]; - int linear_index = (gid.z * params.dst_size.y + static_cast(gid.y)) * - params.dst_size.x + static_cast(gid.x); - $2 - dst_buffer[linear_index] = value; - } - )"; - return shader_source; -} -} // namespace - -std::vector ApplyMask(int id, ValueId input_id_0, - ValueId input_id_1, - ValueId output_id, - const RuntimeOptions& options) { - auto desc = std::make_shared(); - desc->id = id; - desc->is_linkable = false; - desc->shader_source = GetApplyMaskCode(); - - desc->input_buffers = { - {input_id_0, "device FLT4* const src_buffer_0"}, // data - {input_id_1, "device FLT4* const src_buffer_1"}, // mask - }; - - desc->output_buffer = { - output_id, "device FLT4* dst_buffer", - [input_id_0, input_id_1](const std::map& buffers) { - return buffers.find(input_id_0)->second; - }}; - - desc->uniform_buffers = { - {"constant uniforms& params", - [input_id_0, input_id_1, - output_id](const std::map& buffers) { - const auto& input_dim_0 = buffers.find(input_id_0)->second; - const auto& input_dim_1 = buffers.find(input_id_1)->second; - const auto& output_dim = buffers.find(output_id)->second; - std::vector uniform_params{ - input_dim_0.w, input_dim_0.h, input_dim_0.c, 0, - input_dim_1.w, input_dim_1.h, input_dim_1.c, 0, - output_dim.w, output_dim.h, output_dim.c, 0, - }; - return GetByteBuffer(uniform_params); - }}, - }; - - desc->resize_function = [input_id_0, - input_id_1](const std::map& buffers) { - const auto& src_shape = buffers.find(input_id_0)->second; - const uint3 groups_size{16, 16, 1}; - int groups_x = IntegralDivideRoundUp(src_shape.w, groups_size.x); - int groups_y = IntegralDivideRoundUp(src_shape.h, groups_size.y); - int groups_z = IntegralDivideRoundUp(src_shape.c, 4); - return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); - }; - - return {desc}; -} - std::vector Multiply(int id, ValueId input_id, ValueId output_id, const MultiplyAttributes& attr, diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul.h b/tensorflow/lite/delegates/gpu/metal/kernels/mul.h index bc83b149e78..b5ff37cf560 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul.h +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul.h @@ -30,11 +30,6 @@ std::vector Multiply(int id, ValueId input_id, ValueId output_id, const MultiplyAttributes& attr, const RuntimeOptions& options); - -std::vector ApplyMask(int id, ValueId input_id_0, - ValueId input_id_1, - ValueId output_id, - const RuntimeOptions& options); } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm index f69598bad5b..d881950c831 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul_test.mm @@ -95,55 +95,4 @@ using ::tflite::gpu::metal::SingleOpModel; XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); } - -- (void)testApplyMaskChannel1 { - TensorRef input; - input.type = DataType::FLOAT32; - input.ref = 0; - input.shape = BHWC(1, 1, 2, 2); - - TensorRef mask; - mask.type = DataType::FLOAT32; - mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 1); - - TensorRef output; - output.type = DataType::FLOAT32; - output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); - - SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); - XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); - XCTAssertTrue(model.PopulateTensor(1, {2, 3})); - auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); - status = CompareVectors({2, 4, 9, 12}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); -} - -- (void)testApplyMaskEqualsToInputChannel { - TensorRef input; - input.type = DataType::FLOAT32; - input.ref = 0; - input.shape = BHWC(1, 1, 2, 2); - - TensorRef mask; - mask.type = DataType::FLOAT32; - mask.ref = 1; - mask.shape = BHWC(1, 1, 2, 2); - - TensorRef output; - output.type = DataType::FLOAT32; - output.ref = 2; - output.shape = BHWC(1, 1, 2, 2); - - SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output}); - XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); - XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4})); - auto status = model.Invoke(); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); - status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f); - XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str()); -} - @end