Simplify how the delegate to use is supplied to SingleOpModel-based tests.

PiperOrigin-RevId: 314761693
Change-Id: I19a160aad575e832c92a72e442e5fc49e8e2f45a
parent 171d688aaa
commit 45463fbd19
				| @ -39,45 +39,24 @@ limitations under the License. | ||||
| namespace tflite { | ||||
| namespace { | ||||
| 
 | ||||
| class SingleOpModelWithNNAPI : public SingleOpModel { | ||||
|  public: | ||||
|   SingleOpModelWithNNAPI() = default; | ||||
|   void Init(const NnApi* nnapi, | ||||
|             tflite::StatefulNnApiDelegate::Options options) { | ||||
|     stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi, options)); | ||||
|     auto* delegate = stateful_delegate_.get(); | ||||
|     this->SetApplyDelegate([delegate, this](Interpreter* interpreter) { | ||||
|       compilation_status_ = interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|   } | ||||
| 
 | ||||
|   StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); } | ||||
| 
 | ||||
|   void SetBufferHandle(int index, TfLiteBufferHandle handle) { | ||||
|     interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get()); | ||||
|   } | ||||
|   TfLiteStatus GetCompilationStatus() { return compilation_status_; } | ||||
| 
 | ||||
|  private: | ||||
|   std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_; | ||||
|   TfLiteStatus compilation_status_; | ||||
| }; | ||||
| 
 | ||||
| class FloatAddOpModel : public SingleOpModelWithNNAPI { | ||||
| class FloatAddOpModel : public SingleOpModel { | ||||
|  public: | ||||
|   FloatAddOpModel() = default; | ||||
|   void Init(const NnApi* nnapi, tflite::StatefulNnApiDelegate::Options options, | ||||
|             const TensorData& input1, const TensorData& input2, | ||||
|             const TensorData& output, ActivationFunctionType activation_type, | ||||
|             bool allow_fp32_relax_to_fp16 = false) { | ||||
|     SingleOpModelWithNNAPI::Init(nnapi, options); | ||||
|     stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi, options)); | ||||
|     SetDelegate(stateful_delegate_.get()); | ||||
| 
 | ||||
|     input1_ = AddInput(input1); | ||||
|     input2_ = AddInput(input2); | ||||
|     output_ = AddOutput(output); | ||||
|     SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, | ||||
|                  CreateAddOptions(builder_, activation_type).Union()); | ||||
|     BuildInterpreter({GetShape(input1_), GetShape(input2_)}, /*num_threads=*/-1, | ||||
|                      allow_fp32_relax_to_fp16, /*apply_delegate=*/true); | ||||
|                      allow_fp32_relax_to_fp16, /*apply_delegate=*/false); | ||||
|     compilation_status_ = ApplyDelegate(); | ||||
|   } | ||||
| 
 | ||||
|   int input1() { return input1_; } | ||||
| @ -85,12 +64,16 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI { | ||||
| 
 | ||||
|   std::vector<float> GetOutput() { return ExtractVector<float>(output_); } | ||||
| 
 | ||||
|   TfLiteStatus GetCompilationStatus() { return compilation_status_; } | ||||
| 
 | ||||
|  protected: | ||||
|   int input1_; | ||||
|   int input2_; | ||||
|   int output_; | ||||
| 
 | ||||
|  private: | ||||
|   std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_; | ||||
|   TfLiteStatus compilation_status_; | ||||
| }; | ||||
| 
 | ||||
| struct NnApiDeviceSelectionTest | ||||
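Illustrative usage of the reworked FloatAddOpModel (a sketch, not a test from this file): nnapi and options stand in for whatever NnApi instance and delegate options the surrounding fixture provides, and the tensor shapes are arbitrary.

  FloatAddOpModel m;
  m.Init(nnapi, options, {TensorType_FLOAT32, {1, 2, 2, 1}},
         {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}},
         ActivationFunctionType_NONE);
  // The status captured from ApplyDelegate() in Init() can be queried directly.
  EXPECT_EQ(m.GetCompilationStatus(), kTfLiteOk);  // assumes delegation succeeds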
| @ -281,10 +264,7 @@ class ArgMaxOpModel : public SingleOpModel, public AcceleratedModel { | ||||
| 
 | ||||
|   void Init(std::initializer_list<int> input_shape, TensorType input_type, | ||||
|             int axis_value, TensorType output_type) { | ||||
|     auto* delegate = GetDelegate(); | ||||
|     this->SetApplyDelegate([delegate](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|     SetDelegate(GetDelegate()); | ||||
|     input_ = AddInput(input_type); | ||||
|     axis_ = AddConstInput(TensorType_INT32, {axis_value}, {1}); | ||||
|     output_ = AddOutput(output_type); | ||||
| @ -395,10 +375,7 @@ class AddSubOpsAcceleratedModel : public MultiOpModel, public AcceleratedModel { | ||||
|                             const std::string& accelerator_name, | ||||
|                             bool allow_fp32_relax_to_fp16 = false) | ||||
|       : MultiOpModel(), AcceleratedModel(nnapi, accelerator_name) { | ||||
|     auto* delegate = GetDelegate(); | ||||
|     this->SetApplyDelegate([delegate](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|     SetDelegate(GetDelegate()); | ||||
|     Init(input1, input2, input3, output, activation_type, | ||||
|          allow_fp32_relax_to_fp16); | ||||
|   } | ||||
| @ -585,10 +562,7 @@ class HardSwishAddOpsAcceleratedModel : public MultiOpModel, | ||||
|                                   const std::string& accelerator_name, | ||||
|                                   bool allow_fp32_relax_to_fp16 = false) | ||||
|       : MultiOpModel(), AcceleratedModel(nnapi, accelerator_name) { | ||||
|     auto* delegate = GetDelegate(); | ||||
|     this->SetApplyDelegate([delegate](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|     SetDelegate(GetDelegate()); | ||||
|     Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16); | ||||
|   } | ||||
| 
 | ||||
| @ -724,10 +698,7 @@ class QuantizedWeightsConvolutionOpModel : public SingleOpModel, | ||||
|       int dilation_width_factor = 1, int dilation_height_factor = 1, | ||||
|       int num_threads = -1, std::initializer_list<uint8_t> filter_data = {}) | ||||
|       : SingleOpModel(), AcceleratedModel(nnapi, accelerator_name) { | ||||
|     auto* delegate = GetDelegate(); | ||||
|     this->SetApplyDelegate([delegate](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|     SetDelegate(GetDelegate()); | ||||
| 
 | ||||
|     input_ = AddInput(input); | ||||
| 
 | ||||
| @ -861,10 +832,7 @@ class LongIdentityModel : public MultiOpModel, public AcceleratedModel { | ||||
|  private: | ||||
|   void Init(const std::vector<int>& input_shape, int graph_size, | ||||
|             const std::unordered_set<int>& custom_nodes_indexes) { | ||||
|     auto* delegate = GetDelegate(); | ||||
|     this->SetApplyDelegate([delegate](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|     SetDelegate(GetDelegate()); | ||||
| 
 | ||||
|     const TensorData tensor_data{TensorType_FLOAT32, input_shape}; | ||||
| 
 | ||||
|  | ||||
| @ -31,10 +31,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel { | ||||
|  public: | ||||
|   explicit SingleOpModelWithNNAPI(const NnApi* nnapi) { | ||||
|     stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi)); | ||||
|     auto* delegate = stateful_delegate_.get(); | ||||
|     this->SetApplyDelegate([delegate](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|     this->SetDelegate(stateful_delegate_.get()); | ||||
|   } | ||||
| 
 | ||||
|   StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); } | ||||
|  | ||||
| @ -46,19 +46,12 @@ MATCHER(QuantizedNear, "") { | ||||
| 
 | ||||
| class SingleOpModelWithNNAPI : public SingleOpModel { | ||||
|  public: | ||||
|   SingleOpModelWithNNAPI() { | ||||
|     this->SetApplyDelegate([](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(NnApiDelegate()); | ||||
|     }); | ||||
|   } | ||||
|   SingleOpModelWithNNAPI() { SetDelegate(NnApiDelegate()); } | ||||
| 
 | ||||
|   explicit SingleOpModelWithNNAPI( | ||||
|       const StatefulNnApiDelegate::Options& options) { | ||||
|     stateful_delegate_.reset(new StatefulNnApiDelegate(options)); | ||||
|     auto* delegate = stateful_delegate_.get(); | ||||
|     this->SetApplyDelegate([delegate](Interpreter* interpreter) { | ||||
|       interpreter->ModifyGraphWithDelegate(delegate); | ||||
|     }); | ||||
|     SetDelegate(stateful_delegate_.get()); | ||||
|   } | ||||
| 
 | ||||
|   TfLiteStatus ResizeInputTensor(int tensor_index, | ||||
|  | ||||
| @ -199,15 +199,16 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes, | ||||
|   if (apply_delegate) ApplyDelegate(); | ||||
| } | ||||
| 
 | ||||
| void SingleOpModel::ApplyDelegate() { | ||||
| TfLiteStatus SingleOpModel::ApplyDelegate() { | ||||
|   if (force_use_nnapi) { | ||||
|     interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate()); | ||||
|     delegate_ = TestNnApiDelegate(); | ||||
|   } | ||||
| 
 | ||||
|   // Modify delegate with function.
 | ||||
|   if (apply_delegate_fn_) { | ||||
|     apply_delegate_fn_(interpreter_.get()); | ||||
|   if (delegate_) { | ||||
|     return interpreter_->ModifyGraphWithDelegate(delegate_); | ||||
|   } | ||||
| 
 | ||||
|   return kTfLiteOk; | ||||
| } | ||||
| 
 | ||||
| void SingleOpModel::Invoke() { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); } | ||||
|  | ||||
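A sketch (not from the patch) of what the new return value buys a caller, assuming a hypothetical SingleOpModel subclass MyOpModel that builds its graph with /*apply_delegate=*/false and registers a delegate via SetDelegate():

TEST(ApplyDelegateSketch, SurfacesDelegationStatus) {
  MyOpModel model;  // hypothetical test model
  // ModifyGraphWithDelegate()'s status is returned directly, so the test no
  // longer needs a callback just to capture it.
  TfLiteStatus status = model.ApplyDelegate();
  EXPECT_EQ(status, kTfLiteOk);  // assumption: delegation succeeds here
}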
| @ -160,14 +160,11 @@ class SingleOpModel { | ||||
|   SingleOpModel() {} | ||||
|   ~SingleOpModel(); | ||||
| 
 | ||||
|   // Set a function callback that is run right after graph is prepared
 | ||||
|   // that allows applying external delegates. This is useful for testing
 | ||||
|   // other runtimes like NN API or GPU.
 | ||||
|   void SetApplyDelegate(std::function<void(Interpreter*)> apply_delegate_fn) { | ||||
|     apply_delegate_fn_ = apply_delegate_fn; | ||||
|   } | ||||
|   // Set a delegate that is applied right after graph is prepared. This is
 | ||||
|   // useful for testing other runtimes like NN API or GPU.
 | ||||
|   void SetDelegate(TfLiteDelegate* delegate) { delegate_ = delegate; } | ||||
| 
 | ||||
|   void ApplyDelegate(); | ||||
|   TfLiteStatus ApplyDelegate(); | ||||
| 
 | ||||
|   // Copying or assignment is disallowed to simplify ownership semantics.
 | ||||
|   SingleOpModel(const SingleOpModel&) = delete; | ||||
| @ -755,9 +752,7 @@ class SingleOpModel { | ||||
|   std::vector<int32_t> outputs_; | ||||
|   std::vector<flatbuffers::Offset<Tensor>> tensors_; | ||||
|   std::vector<flatbuffers::Offset<Buffer>> buffers_; | ||||
|   // A function pointer that gets called after the interpreter is created but
 | ||||
|   // before evaluation happens. This is useful for applying a delegate.
 | ||||
|   std::function<void(Interpreter*)> apply_delegate_fn_; | ||||
|   TfLiteDelegate* delegate_ = nullptr; | ||||
| }; | ||||
| 
 | ||||
| // Populate string tensors.
 | ||||
|  | ||||
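One implication of the raw TfLiteDelegate* member (my reading of the header change, not stated in the patch): SingleOpModel does not take ownership of the delegate, so fixtures must keep it alive for the interpreter's lifetime, typically in a smart-pointer member as the NNAPI models above do. A minimal sketch, with a hypothetical class name:

class MyDelegatedModel : public SingleOpModel {
 public:
  explicit MyDelegatedModel(const StatefulNnApiDelegate::Options& options) {
    // Ownership stays in the fixture; SetDelegate() stores only a raw pointer.
    owned_delegate_.reset(new StatefulNnApiDelegate(options));
    SetDelegate(owned_delegate_.get());
  }

 private:
  std::unique_ptr<StatefulNnApiDelegate> owned_delegate_;
};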