Simplify how used delegate is supplied to SingleOpModel-based tests.
PiperOrigin-RevId: 314761693 Change-Id: I19a160aad575e832c92a72e442e5fc49e8e2f45a
This commit is contained in:
parent
171d688aaa
commit
45463fbd19
@ -39,45 +39,24 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
class SingleOpModelWithNNAPI : public SingleOpModel {
|
||||
public:
|
||||
SingleOpModelWithNNAPI() = default;
|
||||
void Init(const NnApi* nnapi,
|
||||
tflite::StatefulNnApiDelegate::Options options) {
|
||||
stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi, options));
|
||||
auto* delegate = stateful_delegate_.get();
|
||||
this->SetApplyDelegate([delegate, this](Interpreter* interpreter) {
|
||||
compilation_status_ = interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
}
|
||||
|
||||
StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }
|
||||
|
||||
void SetBufferHandle(int index, TfLiteBufferHandle handle) {
|
||||
interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
|
||||
}
|
||||
TfLiteStatus GetCompilationStatus() { return compilation_status_; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
|
||||
TfLiteStatus compilation_status_;
|
||||
};
|
||||
|
||||
class FloatAddOpModel : public SingleOpModelWithNNAPI {
|
||||
class FloatAddOpModel : public SingleOpModel {
|
||||
public:
|
||||
FloatAddOpModel() = default;
|
||||
void Init(const NnApi* nnapi, tflite::StatefulNnApiDelegate::Options options,
|
||||
const TensorData& input1, const TensorData& input2,
|
||||
const TensorData& output, ActivationFunctionType activation_type,
|
||||
bool allow_fp32_relax_to_fp16 = false) {
|
||||
SingleOpModelWithNNAPI::Init(nnapi, options);
|
||||
stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi, options));
|
||||
SetDelegate(stateful_delegate_.get());
|
||||
|
||||
input1_ = AddInput(input1);
|
||||
input2_ = AddInput(input2);
|
||||
output_ = AddOutput(output);
|
||||
SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
|
||||
CreateAddOptions(builder_, activation_type).Union());
|
||||
BuildInterpreter({GetShape(input1_), GetShape(input2_)}, /*num_threads=*/-1,
|
||||
allow_fp32_relax_to_fp16, /*apply_delegate=*/true);
|
||||
allow_fp32_relax_to_fp16, /*apply_delegate=*/false);
|
||||
compilation_status_ = ApplyDelegate();
|
||||
}
|
||||
|
||||
int input1() { return input1_; }
|
||||
@ -85,12 +64,16 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI {
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
|
||||
TfLiteStatus GetCompilationStatus() { return compilation_status_; }
|
||||
|
||||
protected:
|
||||
int input1_;
|
||||
int input2_;
|
||||
int output_;
|
||||
|
||||
private:
|
||||
std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
|
||||
TfLiteStatus compilation_status_;
|
||||
};
|
||||
|
||||
struct NnApiDeviceSelectionTest
|
||||
@ -281,10 +264,7 @@ class ArgMaxOpModel : public SingleOpModel, public AcceleratedModel {
|
||||
|
||||
void Init(std::initializer_list<int> input_shape, TensorType input_type,
|
||||
int axis_value, TensorType output_type) {
|
||||
auto* delegate = GetDelegate();
|
||||
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
SetDelegate(GetDelegate());
|
||||
input_ = AddInput(input_type);
|
||||
axis_ = AddConstInput(TensorType_INT32, {axis_value}, {1});
|
||||
output_ = AddOutput(output_type);
|
||||
@ -395,10 +375,7 @@ class AddSubOpsAcceleratedModel : public MultiOpModel, public AcceleratedModel {
|
||||
const std::string& accelerator_name,
|
||||
bool allow_fp32_relax_to_fp16 = false)
|
||||
: MultiOpModel(), AcceleratedModel(nnapi, accelerator_name) {
|
||||
auto* delegate = GetDelegate();
|
||||
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
SetDelegate(GetDelegate());
|
||||
Init(input1, input2, input3, output, activation_type,
|
||||
allow_fp32_relax_to_fp16);
|
||||
}
|
||||
@ -585,10 +562,7 @@ class HardSwishAddOpsAcceleratedModel : public MultiOpModel,
|
||||
const std::string& accelerator_name,
|
||||
bool allow_fp32_relax_to_fp16 = false)
|
||||
: MultiOpModel(), AcceleratedModel(nnapi, accelerator_name) {
|
||||
auto* delegate = GetDelegate();
|
||||
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
SetDelegate(GetDelegate());
|
||||
Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16);
|
||||
}
|
||||
|
||||
@ -724,10 +698,7 @@ class QuantizedWeightsConvolutionOpModel : public SingleOpModel,
|
||||
int dilation_width_factor = 1, int dilation_height_factor = 1,
|
||||
int num_threads = -1, std::initializer_list<uint8_t> filter_data = {})
|
||||
: SingleOpModel(), AcceleratedModel(nnapi, accelerator_name) {
|
||||
auto* delegate = GetDelegate();
|
||||
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
SetDelegate(GetDelegate());
|
||||
|
||||
input_ = AddInput(input);
|
||||
|
||||
@ -861,10 +832,7 @@ class LongIdentityModel : public MultiOpModel, public AcceleratedModel {
|
||||
private:
|
||||
void Init(const std::vector<int>& input_shape, int graph_size,
|
||||
const std::unordered_set<int>& custom_nodes_indexes) {
|
||||
auto* delegate = GetDelegate();
|
||||
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
SetDelegate(GetDelegate());
|
||||
|
||||
const TensorData tensor_data{TensorType_FLOAT32, input_shape};
|
||||
|
||||
|
@ -31,10 +31,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel {
|
||||
public:
|
||||
explicit SingleOpModelWithNNAPI(const NnApi* nnapi) {
|
||||
stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi));
|
||||
auto* delegate = stateful_delegate_.get();
|
||||
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
this->SetDelegate(stateful_delegate_.get());
|
||||
}
|
||||
|
||||
StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }
|
||||
|
@ -46,19 +46,12 @@ MATCHER(QuantizedNear, "") {
|
||||
|
||||
class SingleOpModelWithNNAPI : public SingleOpModel {
|
||||
public:
|
||||
SingleOpModelWithNNAPI() {
|
||||
this->SetApplyDelegate([](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(NnApiDelegate());
|
||||
});
|
||||
}
|
||||
SingleOpModelWithNNAPI() { SetDelegate(NnApiDelegate()); }
|
||||
|
||||
explicit SingleOpModelWithNNAPI(
|
||||
const StatefulNnApiDelegate::Options& options) {
|
||||
stateful_delegate_.reset(new StatefulNnApiDelegate(options));
|
||||
auto* delegate = stateful_delegate_.get();
|
||||
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate);
|
||||
});
|
||||
SetDelegate(stateful_delegate_.get());
|
||||
}
|
||||
|
||||
TfLiteStatus ResizeInputTensor(int tensor_index,
|
||||
|
@ -199,15 +199,16 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
if (apply_delegate) ApplyDelegate();
|
||||
}
|
||||
|
||||
void SingleOpModel::ApplyDelegate() {
|
||||
TfLiteStatus SingleOpModel::ApplyDelegate() {
|
||||
if (force_use_nnapi) {
|
||||
interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate());
|
||||
delegate_ = TestNnApiDelegate();
|
||||
}
|
||||
|
||||
// Modify delegate with function.
|
||||
if (apply_delegate_fn_) {
|
||||
apply_delegate_fn_(interpreter_.get());
|
||||
if (delegate_) {
|
||||
return interpreter_->ModifyGraphWithDelegate(delegate_);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void SingleOpModel::Invoke() { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); }
|
||||
|
@ -160,14 +160,11 @@ class SingleOpModel {
|
||||
SingleOpModel() {}
|
||||
~SingleOpModel();
|
||||
|
||||
// Set a function callback that is run right after graph is prepared
|
||||
// that allows applying external delegates. This is useful for testing
|
||||
// other runtimes like NN API or GPU.
|
||||
void SetApplyDelegate(std::function<void(Interpreter*)> apply_delegate_fn) {
|
||||
apply_delegate_fn_ = apply_delegate_fn;
|
||||
}
|
||||
// Set a delegate that is applied right after graph is prepared. This is
|
||||
// useful for testing other runtimes like NN API or GPU.
|
||||
void SetDelegate(TfLiteDelegate* delegate) { delegate_ = delegate; }
|
||||
|
||||
void ApplyDelegate();
|
||||
TfLiteStatus ApplyDelegate();
|
||||
|
||||
// Copying or assignment is disallowed to simplify ownership semantics.
|
||||
SingleOpModel(const SingleOpModel&) = delete;
|
||||
@ -755,9 +752,7 @@ class SingleOpModel {
|
||||
std::vector<int32_t> outputs_;
|
||||
std::vector<flatbuffers::Offset<Tensor>> tensors_;
|
||||
std::vector<flatbuffers::Offset<Buffer>> buffers_;
|
||||
// A function pointer that gets called after the interpreter is created but
|
||||
// before evaluation happens. This is useful for applying a delegate.
|
||||
std::function<void(Interpreter*)> apply_delegate_fn_;
|
||||
TfLiteDelegate* delegate_ = nullptr;
|
||||
};
|
||||
|
||||
// Populate string tensors.
|
||||
|
Loading…
Reference in New Issue
Block a user