Reusing common function for constant argument parsing in Multiply operation.
PiperOrigin-RevId: 329539781 Change-Id: I145dce25607f3a3725d930f2f8ed9d9a9193f76d
This commit is contained in:
parent
7a9ca9b344
commit
2c5a33590d
@ -1106,7 +1106,6 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration,
|
const TfLiteRegistration* registration,
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
GraphFloat32* graph, ObjectReader* reader) final {
|
||||||
// Determine runtime/constant tensors.
|
|
||||||
const TfLiteTensor* input0 = reader->GetInputTensor(0);
|
const TfLiteTensor* input0 = reader->GetInputTensor(0);
|
||||||
if (!input0) {
|
if (!input0) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
@ -1127,9 +1126,9 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||||||
|
|
||||||
Node* node = graph->NewNode();
|
Node* node = graph->NewNode();
|
||||||
node->operation.type = ToString(OperationType::MUL);
|
node->operation.type = ToString(OperationType::MUL);
|
||||||
|
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||||
|
|
||||||
// The "larger" input tensor must be bound to 1st input and the "smaller"
|
// Determine runtime/constant tensors.
|
||||||
// input tensor ("mask") must be bound to 2nd input.
|
|
||||||
if (runtime_tensor0 && runtime_tensor1) {
|
if (runtime_tensor0 && runtime_tensor1) {
|
||||||
if (input0 == input1) {
|
if (input0 == input1) {
|
||||||
// replace MUL(A, A) with POW(A, 2.0)
|
// replace MUL(A, A) with POW(A, 2.0)
|
||||||
@ -1138,10 +1137,11 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||||||
ElementwiseAttributes attr;
|
ElementwiseAttributes attr;
|
||||||
attr.param = 2.0f;
|
attr.param = 2.0f;
|
||||||
node->operation.attributes = std::move(attr);
|
node->operation.attributes = std::move(attr);
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
return reader->AddInput(node, 0);
|
||||||
return reader->AddOutputs(node);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The "larger" input tensor must be bound to 1st input and the "smaller"
|
||||||
|
// input tensor must be bound to 2nd input.
|
||||||
BHWC shape0;
|
BHWC shape0;
|
||||||
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
||||||
BHWC shape1;
|
BHWC shape1;
|
||||||
@ -1153,58 +1153,18 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||||||
input_tensor0 = 1;
|
input_tensor0 = 1;
|
||||||
input_tensor1 = 0;
|
input_tensor1 = 0;
|
||||||
}
|
}
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
||||||
ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader));
|
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
||||||
} else {
|
} else {
|
||||||
// The runtime input tensor must be bound to 1st input and the constant
|
ElementwiseAttributes attr;
|
||||||
// input tensor must be bound to 2nd input.
|
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
|
||||||
int runtime_tensor = 0;
|
node->operation.attributes = std::move(attr);
|
||||||
int constant_tensor = 1;
|
|
||||||
TfLiteIntArray* constant_dims = input1->dims;
|
|
||||||
if (constant_tensor0 && runtime_tensor1) {
|
|
||||||
runtime_tensor = 1;
|
|
||||||
constant_tensor = 0;
|
|
||||||
constant_dims = input0->dims;
|
|
||||||
}
|
|
||||||
RETURN_IF_ERROR(ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
|
|
||||||
constant_dims, graph, reader));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteMulParams* tf_options;
|
const TfLiteMulParams* tf_options;
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
return MaybeFuseActivation(tf_options->activation, graph, node);
|
return MaybeFuseActivation(tf_options->activation, graph, node);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
absl::Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) {
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
|
||||||
return reader->AddOutputs(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status ParseMultiplyScalar(Node* node, int runtime_tensor,
|
|
||||||
int constant_tensor,
|
|
||||||
const TfLiteIntArray* constant_dims,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) {
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
|
||||||
ElementwiseAttributes attr;
|
|
||||||
if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
|
|
||||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
|
||||||
attr.param = tensor.data[0];
|
|
||||||
} else if (constant_dims->size == 3) {
|
|
||||||
Tensor<HWC, DataType::FLOAT32> tensor;
|
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
|
||||||
attr.param = std::move(tensor);
|
|
||||||
} else {
|
|
||||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
|
||||||
attr.param = std::move(tensor);
|
|
||||||
}
|
|
||||||
node->operation.attributes = std::move(attr);
|
|
||||||
return reader->AddOutputs(node);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class PReLUOperationParser : public TFLiteOperationParser {
|
class PReLUOperationParser : public TFLiteOperationParser {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user