Reusing common function for constant argument parsing in Multiply operation.

PiperOrigin-RevId: 329539781
Change-Id: I145dce25607f3a3725d930f2f8ed9d9a9193f76d
This commit is contained in:
Raman Sarokin 2020-09-01 10:57:14 -07:00 committed by TensorFlower Gardener
parent 7a9ca9b344
commit 2c5a33590d

View File

@ -1106,7 +1106,6 @@ class MulOperationParser : public TFLiteOperationParser {
absl::Status Parse(const TfLiteNode* tflite_node, absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final { GraphFloat32* graph, ObjectReader* reader) final {
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0); const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) { if (!input0) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
@ -1127,9 +1126,9 @@ class MulOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode(); Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MUL); node->operation.type = ToString(OperationType::MUL);
RETURN_IF_ERROR(reader->AddOutputs(node));
// The "larger" input tensor must be bound to 1st input and the "smaller" // Determine runtime/constant tensors.
// input tensor ("mask") must be bound to 2nd input.
if (runtime_tensor0 && runtime_tensor1) { if (runtime_tensor0 && runtime_tensor1) {
if (input0 == input1) { if (input0 == input1) {
// replace MUL(A, A) with POW(A, 2.0) // replace MUL(A, A) with POW(A, 2.0)
@ -1138,10 +1137,11 @@ class MulOperationParser : public TFLiteOperationParser {
ElementwiseAttributes attr; ElementwiseAttributes attr;
attr.param = 2.0f; attr.param = 2.0f;
node->operation.attributes = std::move(attr); node->operation.attributes = std::move(attr);
RETURN_IF_ERROR(reader->AddInput(node, 0)); return reader->AddInput(node, 0);
return reader->AddOutputs(node);
} }
// The "larger" input tensor must be bound to 1st input and the "smaller"
// input tensor must be bound to 2nd input.
BHWC shape0; BHWC shape0;
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0)); RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
BHWC shape1; BHWC shape1;
@ -1153,58 +1153,18 @@ class MulOperationParser : public TFLiteOperationParser {
input_tensor0 = 1; input_tensor0 = 1;
input_tensor1 = 0; input_tensor1 = 0;
} }
RETURN_IF_ERROR( RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader)); RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
} else { } else {
// The runtime input tensor must be bound to 1st input and the constant ElementwiseAttributes attr;
// input tensor must be bound to 2nd input. RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
int runtime_tensor = 0; node->operation.attributes = std::move(attr);
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
RETURN_IF_ERROR(ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
constant_dims, graph, reader));
} }
const TfLiteMulParams* tf_options; const TfLiteMulParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return MaybeFuseActivation(tf_options->activation, graph, node); return MaybeFuseActivation(tf_options->activation, graph, node);
} }
private:
absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
GraphFloat32* graph, ObjectReader* reader) {
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
return reader->AddOutputs(node);
}
absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor,
int constant_tensor,
const TfLiteIntArray* constant_dims,
GraphFloat32* graph, ObjectReader* reader) {
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
ElementwiseAttributes attr;
if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else if (constant_dims->size == 3) {
Tensor<HWC, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
}
node->operation.attributes = std::move(attr);
return reader->AddOutputs(node);
}
}; };
class PReLUOperationParser : public TFLiteOperationParser { class PReLUOperationParser : public TFLiteOperationParser {