Allows DEQUANTIZE with constant input in the GPU delegate.
PiperOrigin-RevId: 361183783 Change-Id: If6bb70a60e30d101a910998e905e6e5d298b7901
This commit is contained in:
		
							parent
							
								
									caaea7ca1a
								
							
						
					
					
						commit
						c50f3f1bf8
					
				@ -539,8 +539,13 @@ class DequantizeOperationParser : public TFLiteOperationParser {
 | 
			
		||||
                           const TfLiteNode* tflite_node,
 | 
			
		||||
                           const TfLiteRegistration* registration) final {
 | 
			
		||||
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
 | 
			
		||||
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
 | 
			
		||||
                                       /*runtime_inputs=*/1, /*outputs=*/1));
 | 
			
		||||
    const int num_inputs = NumInputs(tflite_node);
 | 
			
		||||
    const int num_outputs = NumOutputs(tflite_node);
 | 
			
		||||
    if (num_inputs != 1 || num_outputs != 1) {
 | 
			
		||||
      return absl::InternalError(absl::StrCat(
 | 
			
		||||
          "Expected 1 input & output each from Dequantize, got: %d, %d",
 | 
			
		||||
          num_inputs, num_outputs));
 | 
			
		||||
    }
 | 
			
		||||
    return absl::OkStatus();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -551,7 +556,24 @@ class DequantizeOperationParser : public TFLiteOperationParser {
 | 
			
		||||
    // with floating-point versions of the original tensors.
 | 
			
		||||
    Node* node = graph->NewNode();
 | 
			
		||||
    node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
 | 
			
		||||
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
 | 
			
		||||
    if (runtime_inputs == 1) {
 | 
			
		||||
      // Non-constant dequantization.
 | 
			
		||||
      RETURN_IF_ERROR(reader->AddInput(node, 0));
 | 
			
		||||
    } else {
 | 
			
		||||
      // TODO(b/181274192): Optimize out this constant dequantization from the
 | 
			
		||||
      // graph later.
 | 
			
		||||
      TensorFloat32 tensor;
 | 
			
		||||
      RETURN_IF_ERROR(reader->ReadTensor(0, &tensor));
 | 
			
		||||
      Value* value;
 | 
			
		||||
      RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
 | 
			
		||||
      // Need to retain the quant params from the original constant input.
 | 
			
		||||
      const TfLiteTensor* tflite_input = reader->GetInputTensor(0);
 | 
			
		||||
      value->quant_params.emplace();
 | 
			
		||||
      RETURN_IF_ERROR(
 | 
			
		||||
          PopulateQuantParams(*tflite_input, &value->quant_params.value()));
 | 
			
		||||
      RETURN_IF_ERROR(graph->AddConsumer(node->id, value->id));
 | 
			
		||||
    }
 | 
			
		||||
    RETURN_IF_ERROR(reader->AddOutputs(node));
 | 
			
		||||
 | 
			
		||||
    // Quantization attributes should already be present in the input tensor.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user