Allows DEQUANTIZE with constant input in the GPU delegate.
PiperOrigin-RevId: 361183783 Change-Id: If6bb70a60e30d101a910998e905e6e5d298b7901
This commit is contained in:
parent
caaea7ca1a
commit
c50f3f1bf8
@ -539,8 +539,13 @@ class DequantizeOperationParser : public TFLiteOperationParser {
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||
const int num_inputs = NumInputs(tflite_node);
|
||||
const int num_outputs = NumOutputs(tflite_node);
|
||||
if (num_inputs != 1 || num_outputs != 1) {
|
||||
return absl::InternalError(absl::StrCat(
|
||||
"Expected 1 input & output each from Dequantize, got: %d, %d",
|
||||
num_inputs, num_outputs));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -551,7 +556,24 @@ class DequantizeOperationParser : public TFLiteOperationParser {
|
||||
// with floating-point versions of the original tensors.
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
|
||||
if (runtime_inputs == 1) {
|
||||
// Non-constant dequantization.
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
} else {
|
||||
// TODO(b/181274192): Optimize out this constant dequantization from the
|
||||
// graph later.
|
||||
TensorFloat32 tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(0, &tensor));
|
||||
Value* value;
|
||||
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
|
||||
// Need to retain the quant params from the original constant input.
|
||||
const TfLiteTensor* tflite_input = reader->GetInputTensor(0);
|
||||
value->quant_params.emplace();
|
||||
RETURN_IF_ERROR(
|
||||
PopulateQuantParams(*tflite_input, &value->quant_params.value()));
|
||||
RETURN_IF_ERROR(graph->AddConsumer(node->id, value->id));
|
||||
}
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
// Quantization attributes should already be present in the input tensor.
|
||||
|
Loading…
Reference in New Issue
Block a user