From 59145d293a442cefb47f92a564e472ee748429ab Mon Sep 17 00:00:00 2001
From: Raman Sarokin
Date: Mon, 3 Feb 2020 09:05:30 -0800
Subject: [PATCH] Support of broadcasting for ElementwiseTwoInput.

MUL and ADD added to ElementwiseTwoInput (for some cases).
Removed ApplyMask (done as MUL in ElementwiseTwoInput).

PiperOrigin-RevId: 292931827
Change-Id: I32ef9b5ece3f47011d6da3f361a5a6b25a8b9127
---
 .../delegates/gpu/cl/kernels/elementwise.cc  |  51 ++++++--
 .../delegates/gpu/cl/kernels/elementwise.h   |  18 ++-
 .../gpu/cl/kernels/elementwise_test.cc       | 122 ++++++++++++++++++
 .../delegates/gpu/cl/kernels/multiply_add.cc |  70 ----------
 .../delegates/gpu/cl/kernels/multiply_add.h  |  30 -----
 .../gpu/cl/kernels/multiply_add_test.cc      |  90 -------------
 .../gpu/cl/selectors/operation_selector.cc   |  64 +++++++--
 .../gpu/cl/selectors/simple_selectors.cc     |   7 -
 .../gpu/cl/selectors/simple_selectors.h      |   4 -
 9 files changed, 234 insertions(+), 222 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
index d6d4acc8858..b6c6b1409f8 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
@@ -105,13 +105,15 @@ ElementwiseOneInput CreateElementwiseOneInput(const OperationDef& definition,
 ElementwiseTwoInput::ElementwiseTwoInput(ElementwiseTwoInput&& operation)
     : ElementwiseOperation(std::move(operation)),
       link_index_(operation.link_index_),
-      op_type_(operation.op_type_) {}
+      op_type_(operation.op_type_),
+      broadcast_(operation.broadcast_) {}
 
 ElementwiseTwoInput& ElementwiseTwoInput::operator=(
     ElementwiseTwoInput&& operation) {
   if (this != &operation) {
     link_index_ = operation.link_index_;
     op_type_ = operation.op_type_;
+    broadcast_ = operation.broadcast_;
     ElementwiseOperation::operator=(std::move(operation));
   }
   return *this;
@@ -121,31 +123,46 @@ void ElementwiseTwoInput::SetLinkIndex(int index) { link_index_ = index; }
 
 std::string ElementwiseTwoInput::GetCoreCode(
     const LinkingContext& context) const {
+  const std::string size_name = "src_size_" + std::to_string(link_index_);
   TensorCodeGenerator src_tensor(
       absl::StrCat("src_data_", link_index_),
-      WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
+      WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
       definition_.src_tensors[1]);
-  std::string result;
+  const std::string x_coord = broadcast_.width ? "0" : context.x_coord;
+  const std::string y_coord = broadcast_.height ? "0" : context.y_coord;
+  const std::string s_coord = broadcast_.channels ? "0" : context.s_coord;
+  const std::string second_var = "second_var_" + std::to_string(link_index_);
+  std::string result = "  FLT4 " + second_var + " = " +
+                       src_tensor.ReadWHS(x_coord, y_coord, s_coord) + ";\n";
+  if (broadcast_.channels) {
+    result += "  " + second_var + ".y = " + second_var + ".x;\n";
+    result += "  " + second_var + ".z = " + second_var + ".x;\n";
+    result += "  " + second_var + ".w = " + second_var + ".x;\n";
+  }
   switch (op_type_) {
+    case OperationType::ADD:
+      result += "$0 += $1;\n";
+      break;
     case OperationType::DIV:
-      result = "$0 /= $1;\n";
+      result += "$0 /= $1;\n";
+      break;
+    case OperationType::MUL:
+      result += "$0 *= $1;\n";
       break;
     case OperationType::POW:
-      result = "$0 = pow($0, $1);\n";
+      result += "$0 = pow($0, $1);\n";
       break;
     case OperationType::SQUARED_DIFF:
-      result = "$0 -= $1;\n";
+      result += "$0 -= $1;\n";
       result += "$0 *= $0;\n";
       break;
     case OperationType::SUB:
-      result = "$0 -= $1;\n";
+      result += "$0 -= $1;\n";
       break;
     default:
       return "Unknown operation type;\n";
   }
-  return absl::Substitute(
-      result, context.var_name,
-      src_tensor.ReadWHS(context.x_coord, context.y_coord, context.s_coord));
+  return absl::Substitute(result, context.var_name, second_var);
 }
 
 std::string ElementwiseTwoInput::GetArgsDeclaration() const {
@@ -164,9 +181,21 @@ Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) {
   return OkStatus();
 }
 
+ElementwiseTwoInput CreateElementwiseTwoInput(
+    const OperationDef& definition, const OperationType& op_type,
+    const BroadcastSettings& broadcast) {
+  ElementwiseTwoInput operation(definition, op_type, broadcast);
+  operation.SetLinkIndex(0);
+  return operation;
+}
+
 ElementwiseTwoInput CreateElementwiseTwoInput(const OperationDef& definition,
                                               const OperationType& op_type) {
-  ElementwiseTwoInput operation(definition, op_type);
+  BroadcastSettings broadcast;
+  broadcast.width = false;
+  broadcast.height = false;
+  broadcast.channels = false;
+  ElementwiseTwoInput operation(definition, op_type, broadcast);
   operation.SetLinkIndex(0);
   return operation;
 }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h
index 7ace3ee131c..a09ddd1b7db 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h
@@ -48,13 +48,22 @@ class ElementwiseOneInput : public ElementwiseOperation {
 ElementwiseOneInput CreateElementwiseOneInput(const OperationDef& definition,
                                               const OperationType& op_type);
 
+struct BroadcastSettings {
+  bool width;
+  bool height;
+  bool channels;
+};
+
 // Class for simple two input operations without any parameters, for example
 // sub, div and etc.
 class ElementwiseTwoInput : public ElementwiseOperation {
  public:
   explicit ElementwiseTwoInput(const OperationDef& definition,
-                               const OperationType& op_type)
-      : ElementwiseOperation(definition), op_type_(op_type) {}
+                               const OperationType& op_type,
+                               const BroadcastSettings& broadcast)
+      : ElementwiseOperation(definition),
+        op_type_(op_type),
+        broadcast_(broadcast) {}
 
   // Move only
   ElementwiseTwoInput(ElementwiseTwoInput&& operation);
@@ -70,8 +79,13 @@ class ElementwiseTwoInput : public ElementwiseOperation {
  private:
   int link_index_;
   OperationType op_type_;
+  BroadcastSettings broadcast_;
 };
 
+ElementwiseTwoInput CreateElementwiseTwoInput(
+    const OperationDef& definition, const OperationType& op_type,
+    const BroadcastSettings& broadcast);
+
 ElementwiseTwoInput CreateElementwiseTwoInput(const OperationDef& definition,
                                               const OperationType& op_type);
 
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
index 81b29bfab82..24d30eecf25 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc
@@ -397,6 +397,128 @@ TEST_F(OpenCLOperationTest, Pow) {
   }
 }
 
+TEST_F(OpenCLOperationTest, Add) {
+  TensorFloat32 src_tensor_0, src_tensor_1;
+  src_tensor_0.shape = BHWC(1, 2, 1, 2);
+  src_tensor_1.shape = BHWC(1, 2, 1, 2);
+  src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
+  src_tensor_1.data = {0.5f, 1.0f, 3.0f, 1.5f};
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      ElementwiseTwoInput operation =
+          CreateElementwiseTwoInput(op_def, OperationType::ADD);
+      ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
+                                    creation_context_, &operation,
+                                    BHWC(1, 2, 1, 2), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(eps), {1.5f, 3.0f, 6.0f, 6.0f}));
+    }
+  }
+}
+
+TEST_F(OpenCLOperationTest, Mul) {
+  TensorFloat32 src_tensor_0, src_tensor_1;
+  src_tensor_0.shape = BHWC(1, 2, 1, 2);
+  src_tensor_1.shape = BHWC(1, 2, 1, 2);
+  src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
+  src_tensor_1.data = {0.5f, 1.0f, 3.0f, 1.5f};
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      ElementwiseTwoInput operation =
+          CreateElementwiseTwoInput(op_def, OperationType::MUL);
+      ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
+                                    creation_context_, &operation,
+                                    BHWC(1, 2, 1, 2), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(eps), {0.5f, 2.0f, 9.0f, 6.75f}));
+    }
+  }
+}
+
+TEST_F(OpenCLOperationTest, MulBroadcastHW) {
+  TensorFloat32 src_tensor_0, src_tensor_1;
+  src_tensor_0.shape = BHWC(1, 2, 1, 2);
+  src_tensor_1.shape = BHWC(1, 1, 1, 2);
+  src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
+  src_tensor_1.data = {0.5f, 3.0f};
+
+  BroadcastSettings broadcast;
+  broadcast.width = true;
+  broadcast.height = true;
+  broadcast.channels = false;
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      ElementwiseTwoInput operation =
+          CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);
+      ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
+                                    creation_context_, &operation,
+                                    BHWC(1, 2, 1, 2), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(eps), {0.5f, 6.0f, 1.5f, 13.5f}));
+    }
+  }
+}
+
+TEST_F(OpenCLOperationTest, MulBroadcastChannels) {
+  TensorFloat32 src_tensor_0, src_tensor_1;
+  src_tensor_0.shape = BHWC(1, 2, 1, 2);
+  src_tensor_1.shape = BHWC(1, 2, 1, 1);
+  src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
+  src_tensor_1.data = {0.5f, 3.0f};
+
+  BroadcastSettings broadcast;
+  broadcast.width = false;
+  broadcast.height = false;
+  broadcast.channels = true;
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      ElementwiseTwoInput operation =
+          CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);
+      ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
+                                    creation_context_, &operation,
+                                    BHWC(1, 2, 1, 2), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(eps), {0.5f, 1.0f, 9.0f, 13.5f}));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc
index b8fd56a4752..45f48246078 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.cc
@@ -175,76 +175,6 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
   return OkStatus();
 }
 
-ApplyMask::ApplyMask(ApplyMask&& operation)
-    : ElementwiseOperation(std::move(operation)),
-      mask_type_(operation.mask_type_),
-      link_index_(operation.link_index_) {}
-
-ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
-  if (this != &operation) {
-    mask_type_ = operation.mask_type_;
-    link_index_ = operation.link_index_;
-    ElementwiseOperation::operator=(std::move(operation));
-  }
-  return *this;
-}
-
-void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
-
-std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
-  const std::string size_name = "mask_size_op" + std::to_string(link_index_);
-  const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
-  TensorCodeGenerator mask(
-      tensor_name,
-      WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
-      definition_.src_tensors[1]);
-  switch (mask_type_) {
-    case MaskType::TENSOR:
-      return context.var_name + " *= " +
-             mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
-             ";\n";
-    case MaskType::CHANNELS:
-      return context.var_name +
-             " *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
-    case MaskType::LAYER:
-      return context.var_name +
-             " *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
-             ".x;\n";
-  }
-}
-
-std::string ApplyMask::GetArgsDeclaration() const {
-  std::string args;
-  const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
-  absl::StrAppend(&args, ",\n",
-                  GetTensorDeclaration(AccessType::READ, tensor_name,
-                                       definition_.src_tensors[1]));
-  const std::string size_name = "mask_size_op" + std::to_string(link_index_);
-  absl::StrAppend(&args, ",\n   int4 ", size_name);
-  return args;
-}
-
-Status ApplyMask::BindArguments(CLKernel* kernel) {
-  RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
-  RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
-  return OkStatus();
-}
-
-ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
-                          const BHWC& mask_shape) {
-  ApplyMask::MaskType mask_type;
-  if (mask_shape == src_shape) {
-    mask_type = ApplyMask::MaskType::TENSOR;
-  } else if (mask_shape.c == 1) {
-    mask_type = ApplyMask::MaskType::LAYER;
-  } else {
-    mask_type = ApplyMask::MaskType::CHANNELS;
-  }
-  ApplyMask operation(definition, mask_type);
-  operation.SetLinkIndex(0);
-  return operation;
-}
-
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h
index 650a20ef49b..83bb6e11216 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add.h
@@ -126,36 +126,6 @@ Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
   return OkStatus();
 }
 
-class ApplyMask : public ElementwiseOperation {
- public:
-  // Move only
-  ApplyMask(ApplyMask&& operation);
-  ApplyMask& operator=(ApplyMask&& operation);
-  ApplyMask(const ApplyMask&) = delete;
-  ApplyMask& operator=(const ApplyMask&) = delete;
-
-  void SetLinkIndex(int index) override;
-  std::string GetCoreCode(const LinkingContext& context) const override;
-  std::string GetArgsDeclaration() const override;
-  Status BindArguments(CLKernel* kernel) override;
-
- private:
-  friend ApplyMask CreateApplyMask(const OperationDef& definition,
-                                   const BHWC& src_shape,
-                                   const BHWC& mask_shape);
-
-  enum class MaskType { LAYER, CHANNELS, TENSOR };
-
-  explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
-      : ElementwiseOperation(definition), mask_type_(mask_type) {}
-
-  MaskType mask_type_;
-  int link_index_;
-};
-
-ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
-                          const BHWC& mask_shape);
-
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc
index c3cb97106b1..444a380c2e9 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc
@@ -181,96 +181,6 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
   }
 }
 
-TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
-  TensorFloat32 src_tensor;
-  src_tensor.shape = BHWC(1, 2, 2, 2);
-  src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
-  TensorFloat32 mask_tensor;
-  mask_tensor.shape = BHWC(1, 2, 2, 1);
-  mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
-
-  for (auto storage : env_.GetSupportedStorages()) {
-    for (auto precision : env_.GetSupportedPrecisions()) {
-      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
-      OperationDef op_def;
-      op_def.precision = precision;
-      auto data_type = DeduceDataTypeFromPrecision(precision);
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
-      TensorFloat32 dst_tensor;
-      ApplyMask operation =
-          CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
-      ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
-                                    creation_context_, &operation,
-                                    BHWC(1, 2, 2, 2), &dst_tensor));
-      EXPECT_THAT(dst_tensor.data,
-                  Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
-                                             3.0f, 0.0f, 0.0f}));
-    }
-  }
-}
-
-TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
-  TensorFloat32 src_tensor;
-  src_tensor.shape = BHWC(1, 2, 2, 2);
-  src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
-  TensorFloat32 mask_tensor;
-  mask_tensor.shape = BHWC(1, 2, 2, 2);
-  mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
-
-  for (auto storage : env_.GetSupportedStorages()) {
-    for (auto precision : env_.GetSupportedPrecisions()) {
-      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
-      OperationDef op_def;
-      op_def.precision = precision;
-      auto data_type = DeduceDataTypeFromPrecision(precision);
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
-      TensorFloat32 dst_tensor;
-      ApplyMask operation =
-          CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
-      ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
-                                    creation_context_, &operation,
-                                    BHWC(1, 2, 2, 2), &dst_tensor));
-      EXPECT_THAT(dst_tensor.data,
-                  Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
-                                             1.5f, 4.0f, 0.0f}));
-    }
-  }
-}
-
-TEST_F(OpenCLOperationTest, ApplyMaskVector) {
-  TensorFloat32 src_tensor;
-  src_tensor.shape = BHWC(1, 2, 2, 2);
-  src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
-  TensorFloat32 mask_tensor;
-  mask_tensor.shape = BHWC(1, 1, 1, 2);
-  mask_tensor.data = {2.0f, 0.5f};
-
-  for (auto storage : env_.GetSupportedStorages()) {
-    for (auto precision : env_.GetSupportedPrecisions()) {
-      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
-      OperationDef op_def;
-      op_def.precision = precision;
-      auto data_type = DeduceDataTypeFromPrecision(precision);
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
-      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
-      TensorFloat32 dst_tensor;
-      ApplyMask operation =
-          CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
-      ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
-                                    creation_context_, &operation,
-                                    BHWC(1, 2, 2, 2), &dst_tensor));
-      EXPECT_THAT(dst_tensor.data,
-                  Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
-                                             1.5f, 8.0f, 3.0f}));
-    }
-  }
-}
-
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
index c93c2b49bd1..e45a750b2fd 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
@@ -26,11 +26,32 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
+namespace {
+bool IsWidthBroadcastedForSecondInput(
+    const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
+  return inputs.size() == 2 &&
+         inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
+         inputs[1]->tensor.shape.w == 1;
+}
+bool IsHeightBroadcastedForSecondInput(
+    const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
+  return inputs.size() == 2 &&
+         inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
+         inputs[1]->tensor.shape.h == 1;
+}
+bool IsChannelsBroadcastedForSecondInput(
+    const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
+  return inputs.size() == 2 &&
+         inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
+         inputs[1]->tensor.shape.c == 1;
+}
+}  // namespace
 
 Status GPUOperationFromNode(const CreationContext& creation_context,
                             const OperationDef& op_def, ModelHints hints,
@@ -59,12 +80,23 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
       if (adds || adds_scalar) {
         return SelectBroadcastAdd(attr, creation_context, op_def, gpu_op);
       } else {
-        auto output = outputs[0];
-        std::vector<int> channels(inputs.size());
-        for (int i = 0; i < inputs.size(); ++i) {
-          channels[i] = inputs[i]->tensor.shape.c;
+        BroadcastSettings broadcast;
+        broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
+        broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
+        broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
+        if (broadcast.width || broadcast.height || broadcast.channels) {
+          ElementwiseTwoInput operation =
+              CreateElementwiseTwoInput(op_def, op_type, broadcast);
+          *gpu_op =
+              absl::make_unique<ElementwiseTwoInput>(std::move(operation));
+        } else {
+          auto output = outputs[0];
+          std::vector<int> channels(inputs.size());
+          for (int i = 0; i < inputs.size(); ++i) {
+            channels[i] = inputs[i]->tensor.shape.c;
+          }
+          SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op);
         }
-        SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op);
         return OkStatus();
       }
     }
@@ -121,8 +153,20 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
         return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
       } else {
-        SelectApplyMask(op_def, inputs[0]->tensor.shape,
-                        inputs[1]->tensor.shape, gpu_op);
+        if (inputs.size() == 2) {
+          BroadcastSettings broadcast;
+          broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
+          broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
+          broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
+          ElementwiseTwoInput operation =
+              CreateElementwiseTwoInput(op_def, op_type, broadcast);
+          *gpu_op =
+              absl::make_unique<ElementwiseTwoInput>(std::move(operation));
+          return OkStatus();
+        } else {
+          return UnimplementedError(
+              "No support of multiply with more than 2 inputs");
+        }
         return OkStatus();
       }
     }
@@ -190,8 +234,12 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
     case OperationType::POW:
     case OperationType::SQUARED_DIFF:
     case OperationType::SUB: {
+      BroadcastSettings broadcast;
+      broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
+      broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
+      broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
       ElementwiseTwoInput operation =
-          CreateElementwiseTwoInput(op_def, op_type);
+          CreateElementwiseTwoInput(op_def, op_type, broadcast);
       *gpu_op = absl::make_unique<ElementwiseTwoInput>(std::move(operation));
       return OkStatus();
     }
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
index 98af5951b95..1dde6e514a8 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
@@ -43,13 +43,6 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 
-void SelectApplyMask(const OperationDef& op_def, const BHWC& src_shape,
-                     const BHWC& mask_shape,
-                     std::unique_ptr<GPUOperation>* ptr) {
-  ApplyMask operation = CreateApplyMask(op_def, src_shape, mask_shape);
-  *ptr = absl::make_unique<GPUOperation>(std::move(operation));
-}
-
 void SelectLSTM(const OperationDef& op_def,
                 std::unique_ptr<GPUOperation>* ptr) {
   LSTM operation = CreateLSTM(op_def);
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
index d1b68d1ce28..a9cc7c2fe7b 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
@@ -27,10 +27,6 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 
-void SelectApplyMask(const OperationDef& op_def, const BHWC& src_shape,
-                     const BHWC& mask_shape,
-                     std::unique_ptr<GPUOperation>* ptr);
-
 void SelectLSTM(const OperationDef& op_def, std::unique_ptr<GPUOperation>* ptr);
 
 void SelectReLU(const CreationContext& creation_context,
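
Reviewer note (not part of the patch): a minimal caller-side sketch of the new factory, mirroring the MulBroadcastHW test above; it assumes an OperationDef op_def already populated as in elementwise_test.cc.

  // Sketch only. The second input has W == 1 and H == 1, so it is broadcast
  // over width and height; channel counts match, so no replication of .x.
  BroadcastSettings broadcast;
  broadcast.width = true;
  broadcast.height = true;
  broadcast.channels = false;
  ElementwiseTwoInput operation =
      CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);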