Reusing common function for constant argument parsing in Multiply operation.

PiperOrigin-RevId: 329539781
Change-Id: I145dce25607f3a3725d930f2f8ed9d9a9193f76d
This commit is contained in:
Raman Sarokin 2020-09-01 10:57:14 -07:00 committed by TensorFlower Gardener
parent 7a9ca9b344
commit 2c5a33590d

View File

@ -1106,7 +1106,6 @@ class MulOperationParser : public TFLiteOperationParser {
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return absl::InvalidArgumentError(
@ -1127,9 +1126,9 @@ class MulOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MUL);
RETURN_IF_ERROR(reader->AddOutputs(node));
// The "larger" input tensor must be bound to 1st input and the "smaller"
// input tensor ("mask") must be bound to 2nd input.
// Determine runtime/constant tensors.
if (runtime_tensor0 && runtime_tensor1) {
if (input0 == input1) {
// replace MUL(A, A) with POW(A, 2.0)
@ -1138,10 +1137,11 @@ class MulOperationParser : public TFLiteOperationParser {
ElementwiseAttributes attr;
attr.param = 2.0f;
node->operation.attributes = std::move(attr);
RETURN_IF_ERROR(reader->AddInput(node, 0));
return reader->AddOutputs(node);
return reader->AddInput(node, 0);
}
// The "larger" input tensor must be bound to 1st input and the "smaller"
// input tensor must be bound to 2nd input.
BHWC shape0;
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
BHWC shape1;
@ -1153,58 +1153,18 @@ class MulOperationParser : public TFLiteOperationParser {
input_tensor0 = 1;
input_tensor1 = 0;
}
RETURN_IF_ERROR(
ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
} else {
// The runtime input tensor must be bound to 1st input and the constant
// input tensor must be bound to 2nd input.
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
RETURN_IF_ERROR(ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
constant_dims, graph, reader));
ElementwiseAttributes attr;
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
node->operation.attributes = std::move(attr);
}
const TfLiteMulParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return MaybeFuseActivation(tf_options->activation, graph, node);
}
private:
absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
GraphFloat32* graph, ObjectReader* reader) {
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
return reader->AddOutputs(node);
}
absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor,
int constant_tensor,
const TfLiteIntArray* constant_dims,
GraphFloat32* graph, ObjectReader* reader) {
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
ElementwiseAttributes attr;
if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else if (constant_dims->size == 3) {
Tensor<HWC, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
}
node->operation.attributes = std::move(attr);
return reader->AddOutputs(node);
}
};
class PReLUOperationParser : public TFLiteOperationParser {