From 2c5a33590dd147e101ee9bd51df3325517bffd3a Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Tue, 1 Sep 2020 10:57:14 -0700 Subject: [PATCH] Reusing common function for constant argument parsing in Multiply operation. PiperOrigin-RevId: 329539781 Change-Id: I145dce25607f3a3725d930f2f8ed9d9a9193f76d --- .../delegates/gpu/common/model_builder.cc | 60 ++++--------------- 1 file changed, 10 insertions(+), 50 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 04503c88439..c8f06aad994 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1106,7 +1106,6 @@ class MulOperationParser : public TFLiteOperationParser { absl::Status Parse(const TfLiteNode* tflite_node, const TfLiteRegistration* registration, GraphFloat32* graph, ObjectReader* reader) final { - // Determine runtime/constant tensors. const TfLiteTensor* input0 = reader->GetInputTensor(0); if (!input0) { return absl::InvalidArgumentError( @@ -1127,9 +1126,9 @@ class MulOperationParser : public TFLiteOperationParser { Node* node = graph->NewNode(); node->operation.type = ToString(OperationType::MUL); + RETURN_IF_ERROR(reader->AddOutputs(node)); - // The "larger" input tensor must be bound to 1st input and the "smaller" - // input tensor ("mask") must be bound to 2nd input. + // Determine runtime/constant tensors. if (runtime_tensor0 && runtime_tensor1) { if (input0 == input1) { // replace MUL(A, A) with POW(A, 2.0) @@ -1138,10 +1137,11 @@ class MulOperationParser : public TFLiteOperationParser { ElementwiseAttributes attr; attr.param = 2.0f; node->operation.attributes = std::move(attr); - RETURN_IF_ERROR(reader->AddInput(node, 0)); - return reader->AddOutputs(node); + return reader->AddInput(node, 0); } + // The "larger" input tensor must be bound to 1st input and the "smaller" + // input tensor must be bound to 2nd input. BHWC shape0; RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0)); BHWC shape1; @@ -1153,58 +1153,18 @@ class MulOperationParser : public TFLiteOperationParser { input_tensor0 = 1; input_tensor1 = 0; } - RETURN_IF_ERROR( - ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader)); + RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); + RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); } else { - // The runtime input tensor must be bound to 1st input and the constant - // input tensor must be bound to 2nd input. - int runtime_tensor = 0; - int constant_tensor = 1; - TfLiteIntArray* constant_dims = input1->dims; - if (constant_tensor0 && runtime_tensor1) { - runtime_tensor = 1; - constant_tensor = 0; - constant_dims = input0->dims; - } - RETURN_IF_ERROR(ParseMultiplyScalar(node, runtime_tensor, constant_tensor, - constant_dims, graph, reader)); + ElementwiseAttributes attr; + RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); + node->operation.attributes = std::move(attr); } const TfLiteMulParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); return MaybeFuseActivation(tf_options->activation, graph, node); } - - private: - absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1, - GraphFloat32* graph, ObjectReader* reader) { - RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); - RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); - return reader->AddOutputs(node); - } - - absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor, - int constant_tensor, - const TfLiteIntArray* constant_dims, - GraphFloat32* graph, ObjectReader* reader) { - RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); - ElementwiseAttributes attr; - if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) { - Tensor tensor; - RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); - attr.param = tensor.data[0]; - } else if (constant_dims->size == 3) { - Tensor tensor; - RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); - attr.param = std::move(tensor); - } else { - Tensor tensor; - RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); - attr.param = std::move(tensor); - } - node->operation.attributes = std::move(attr); - return reader->AddOutputs(node); - } }; class PReLUOperationParser : public TFLiteOperationParser {