Allows DEQUANTIZE with constant input in the GPU delegate.
PiperOrigin-RevId: 361183783 Change-Id: If6bb70a60e30d101a910998e905e6e5d298b7901
This commit is contained in:
		
							parent
							
								
									caaea7ca1a
								
							
						
					
					
						commit
						c50f3f1bf8
					
				| @ -539,8 +539,13 @@ class DequantizeOperationParser : public TFLiteOperationParser { | ||||
|                            const TfLiteNode* tflite_node, | ||||
|                            const TfLiteRegistration* registration) final { | ||||
|     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); | ||||
|     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, | ||||
|                                        /*runtime_inputs=*/1, /*outputs=*/1)); | ||||
|     const int num_inputs = NumInputs(tflite_node); | ||||
|     const int num_outputs = NumOutputs(tflite_node); | ||||
|     if (num_inputs != 1 || num_outputs != 1) { | ||||
|       return absl::InternalError(absl::StrCat( | ||||
|           "Expected 1 input & output each from Dequantize, got: %d, %d", | ||||
|           num_inputs, num_outputs)); | ||||
|     } | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
| 
 | ||||
| @ -551,7 +556,24 @@ class DequantizeOperationParser : public TFLiteOperationParser { | ||||
|     // with floating-point versions of the original tensors.
 | ||||
|     Node* node = graph->NewNode(); | ||||
|     node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE); | ||||
|     RETURN_IF_ERROR(reader->AddInput(node, 0)); | ||||
|     const int runtime_inputs = reader->GetNumberOfRuntimeInputs(); | ||||
|     if (runtime_inputs == 1) { | ||||
|       // Non-constant dequantization.
 | ||||
|       RETURN_IF_ERROR(reader->AddInput(node, 0)); | ||||
|     } else { | ||||
|       // TODO(b/181274192): Optimize out this constant dequantization from the
 | ||||
|       // graph later.
 | ||||
|       TensorFloat32 tensor; | ||||
|       RETURN_IF_ERROR(reader->ReadTensor(0, &tensor)); | ||||
|       Value* value; | ||||
|       RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value)); | ||||
|       // Need to retain the quant params from the original constant input.
 | ||||
|       const TfLiteTensor* tflite_input = reader->GetInputTensor(0); | ||||
|       value->quant_params.emplace(); | ||||
|       RETURN_IF_ERROR( | ||||
|           PopulateQuantParams(*tflite_input, &value->quant_params.value())); | ||||
|       RETURN_IF_ERROR(graph->AddConsumer(node->id, value->id)); | ||||
|     } | ||||
|     RETURN_IF_ERROR(reader->AddOutputs(node)); | ||||
| 
 | ||||
|     // Quantization attributes should already be present in the input tensor.
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user