Reusing common function for constant argument parsing in Multiply operation.
PiperOrigin-RevId: 329539781 Change-Id: I145dce25607f3a3725d930f2f8ed9d9a9193f76d
This commit is contained in:
parent
7a9ca9b344
commit
2c5a33590d
@ -1106,7 +1106,6 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration,
|
||||
GraphFloat32* graph, ObjectReader* reader) final {
|
||||
// Determine runtime/constant tensors.
|
||||
const TfLiteTensor* input0 = reader->GetInputTensor(0);
|
||||
if (!input0) {
|
||||
return absl::InvalidArgumentError(
|
||||
@ -1127,9 +1126,9 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::MUL);
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
// The "larger" input tensor must be bound to 1st input and the "smaller"
|
||||
// input tensor ("mask") must be bound to 2nd input.
|
||||
// Determine runtime/constant tensors.
|
||||
if (runtime_tensor0 && runtime_tensor1) {
|
||||
if (input0 == input1) {
|
||||
// replace MUL(A, A) with POW(A, 2.0)
|
||||
@ -1138,10 +1137,11 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
ElementwiseAttributes attr;
|
||||
attr.param = 2.0f;
|
||||
node->operation.attributes = std::move(attr);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
return reader->AddOutputs(node);
|
||||
return reader->AddInput(node, 0);
|
||||
}
|
||||
|
||||
// The "larger" input tensor must be bound to 1st input and the "smaller"
|
||||
// input tensor must be bound to 2nd input.
|
||||
BHWC shape0;
|
||||
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
||||
BHWC shape1;
|
||||
@ -1153,58 +1153,18 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
input_tensor0 = 1;
|
||||
input_tensor1 = 0;
|
||||
}
|
||||
RETURN_IF_ERROR(
|
||||
ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
||||
} else {
|
||||
// The runtime input tensor must be bound to 1st input and the constant
|
||||
// input tensor must be bound to 2nd input.
|
||||
int runtime_tensor = 0;
|
||||
int constant_tensor = 1;
|
||||
TfLiteIntArray* constant_dims = input1->dims;
|
||||
if (constant_tensor0 && runtime_tensor1) {
|
||||
runtime_tensor = 1;
|
||||
constant_tensor = 0;
|
||||
constant_dims = input0->dims;
|
||||
}
|
||||
RETURN_IF_ERROR(ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
|
||||
constant_dims, graph, reader));
|
||||
ElementwiseAttributes attr;
|
||||
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
|
||||
node->operation.attributes = std::move(attr);
|
||||
}
|
||||
|
||||
const TfLiteMulParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
return MaybeFuseActivation(tf_options->activation, graph, node);
|
||||
}
|
||||
|
||||
private:
|
||||
absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
|
||||
GraphFloat32* graph, ObjectReader* reader) {
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
||||
return reader->AddOutputs(node);
|
||||
}
|
||||
|
||||
absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor,
|
||||
int constant_tensor,
|
||||
const TfLiteIntArray* constant_dims,
|
||||
GraphFloat32* graph, ObjectReader* reader) {
|
||||
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
||||
ElementwiseAttributes attr;
|
||||
if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
|
||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = tensor.data[0];
|
||||
} else if (constant_dims->size == 3) {
|
||||
Tensor<HWC, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = std::move(tensor);
|
||||
} else {
|
||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = std::move(tensor);
|
||||
}
|
||||
node->operation.attributes = std::move(attr);
|
||||
return reader->AddOutputs(node);
|
||||
}
|
||||
};
|
||||
|
||||
class PReLUOperationParser : public TFLiteOperationParser {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user