Allows DEQUANTIZE with constant input in the GPU delegate.

PiperOrigin-RevId: 361183783
Change-Id: If6bb70a60e30d101a910998e905e6e5d298b7901
This commit is contained in:
Sachin Joglekar 2021-03-05 11:08:53 -08:00 committed by TensorFlower Gardener
parent caaea7ca1a
commit c50f3f1bf8

View File

@ -539,8 +539,13 @@ class DequantizeOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
const int num_inputs = NumInputs(tflite_node);
const int num_outputs = NumOutputs(tflite_node);
if (num_inputs != 1 || num_outputs != 1) {
return absl::InternalError(absl::StrCat(
"Expected 1 input & output each from Dequantize, got: %d, %d",
num_inputs, num_outputs));
}
return absl::OkStatus();
}
@ -551,7 +556,24 @@ class DequantizeOperationParser : public TFLiteOperationParser {
// with floating-point versions of the original tensors.
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
if (runtime_inputs == 1) {
// Non-constant dequantization.
RETURN_IF_ERROR(reader->AddInput(node, 0));
} else {
// TODO(b/181274192): Optimize out this constant dequantization from the
// graph later.
TensorFloat32 tensor;
RETURN_IF_ERROR(reader->ReadTensor(0, &tensor));
Value* value;
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
// Need to retain the quant params from the original constant input.
const TfLiteTensor* tflite_input = reader->GetInputTensor(0);
value->quant_params.emplace();
RETURN_IF_ERROR(
PopulateQuantParams(*tflite_input, &value->quant_params.value()));
RETURN_IF_ERROR(graph->AddConsumer(node->id, value->id));
}
RETURN_IF_ERROR(reader->AddOutputs(node));
// Quantization attributes should already be present in the input tensor.