Simplify how used delegate is supplied to SingleOpModel-based tests.

PiperOrigin-RevId: 314761693
Change-Id: I19a160aad575e832c92a72e442e5fc49e8e2f45a
This commit is contained in:
Robert David 2020-06-04 10:57:16 -07:00 committed by TensorFlower Gardener
parent 171d688aaa
commit 45463fbd19
5 changed files with 29 additions and 75 deletions

View File

@ -39,45 +39,24 @@ limitations under the License.
namespace tflite {
namespace {
class SingleOpModelWithNNAPI : public SingleOpModel {
public:
SingleOpModelWithNNAPI() = default;
void Init(const NnApi* nnapi,
tflite::StatefulNnApiDelegate::Options options) {
stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi, options));
auto* delegate = stateful_delegate_.get();
this->SetApplyDelegate([delegate, this](Interpreter* interpreter) {
compilation_status_ = interpreter->ModifyGraphWithDelegate(delegate);
});
}
StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }
void SetBufferHandle(int index, TfLiteBufferHandle handle) {
interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
}
TfLiteStatus GetCompilationStatus() { return compilation_status_; }
private:
std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
TfLiteStatus compilation_status_;
};
class FloatAddOpModel : public SingleOpModelWithNNAPI {
class FloatAddOpModel : public SingleOpModel {
public:
FloatAddOpModel() = default;
void Init(const NnApi* nnapi, tflite::StatefulNnApiDelegate::Options options,
const TensorData& input1, const TensorData& input2,
const TensorData& output, ActivationFunctionType activation_type,
bool allow_fp32_relax_to_fp16 = false) {
SingleOpModelWithNNAPI::Init(nnapi, options);
stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi, options));
SetDelegate(stateful_delegate_.get());
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
CreateAddOptions(builder_, activation_type).Union());
BuildInterpreter({GetShape(input1_), GetShape(input2_)}, /*num_threads=*/-1,
allow_fp32_relax_to_fp16, /*apply_delegate=*/true);
allow_fp32_relax_to_fp16, /*apply_delegate=*/false);
compilation_status_ = ApplyDelegate();
}
int input1() { return input1_; }
@ -85,12 +64,16 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
TfLiteStatus GetCompilationStatus() { return compilation_status_; }
protected:
int input1_;
int input2_;
int output_;
private:
std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
TfLiteStatus compilation_status_;
};
struct NnApiDeviceSelectionTest
@ -281,10 +264,7 @@ class ArgMaxOpModel : public SingleOpModel, public AcceleratedModel {
void Init(std::initializer_list<int> input_shape, TensorType input_type,
int axis_value, TensorType output_type) {
auto* delegate = GetDelegate();
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(delegate);
});
SetDelegate(GetDelegate());
input_ = AddInput(input_type);
axis_ = AddConstInput(TensorType_INT32, {axis_value}, {1});
output_ = AddOutput(output_type);
@ -395,10 +375,7 @@ class AddSubOpsAcceleratedModel : public MultiOpModel, public AcceleratedModel {
const std::string& accelerator_name,
bool allow_fp32_relax_to_fp16 = false)
: MultiOpModel(), AcceleratedModel(nnapi, accelerator_name) {
auto* delegate = GetDelegate();
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(delegate);
});
SetDelegate(GetDelegate());
Init(input1, input2, input3, output, activation_type,
allow_fp32_relax_to_fp16);
}
@ -585,10 +562,7 @@ class HardSwishAddOpsAcceleratedModel : public MultiOpModel,
const std::string& accelerator_name,
bool allow_fp32_relax_to_fp16 = false)
: MultiOpModel(), AcceleratedModel(nnapi, accelerator_name) {
auto* delegate = GetDelegate();
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(delegate);
});
SetDelegate(GetDelegate());
Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16);
}
@ -724,10 +698,7 @@ class QuantizedWeightsConvolutionOpModel : public SingleOpModel,
int dilation_width_factor = 1, int dilation_height_factor = 1,
int num_threads = -1, std::initializer_list<uint8_t> filter_data = {})
: SingleOpModel(), AcceleratedModel(nnapi, accelerator_name) {
auto* delegate = GetDelegate();
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(delegate);
});
SetDelegate(GetDelegate());
input_ = AddInput(input);
@ -861,10 +832,7 @@ class LongIdentityModel : public MultiOpModel, public AcceleratedModel {
private:
void Init(const std::vector<int>& input_shape, int graph_size,
const std::unordered_set<int>& custom_nodes_indexes) {
auto* delegate = GetDelegate();
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(delegate);
});
SetDelegate(GetDelegate());
const TensorData tensor_data{TensorType_FLOAT32, input_shape};

View File

@ -31,10 +31,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel {
public:
explicit SingleOpModelWithNNAPI(const NnApi* nnapi) {
stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi));
auto* delegate = stateful_delegate_.get();
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(delegate);
});
this->SetDelegate(stateful_delegate_.get());
}
StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }

View File

@ -46,19 +46,12 @@ MATCHER(QuantizedNear, "") {
class SingleOpModelWithNNAPI : public SingleOpModel {
public:
SingleOpModelWithNNAPI() {
this->SetApplyDelegate([](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(NnApiDelegate());
});
}
SingleOpModelWithNNAPI() { SetDelegate(NnApiDelegate()); }
explicit SingleOpModelWithNNAPI(
const StatefulNnApiDelegate::Options& options) {
stateful_delegate_.reset(new StatefulNnApiDelegate(options));
auto* delegate = stateful_delegate_.get();
this->SetApplyDelegate([delegate](Interpreter* interpreter) {
interpreter->ModifyGraphWithDelegate(delegate);
});
SetDelegate(stateful_delegate_.get());
}
TfLiteStatus ResizeInputTensor(int tensor_index,

View File

@ -199,15 +199,16 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
if (apply_delegate) ApplyDelegate();
}
void SingleOpModel::ApplyDelegate() {
TfLiteStatus SingleOpModel::ApplyDelegate() {
if (force_use_nnapi) {
interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate());
delegate_ = TestNnApiDelegate();
}
// Modify delegate with function.
if (apply_delegate_fn_) {
apply_delegate_fn_(interpreter_.get());
if (delegate_) {
return interpreter_->ModifyGraphWithDelegate(delegate_);
}
return kTfLiteOk;
}
void SingleOpModel::Invoke() { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); }

View File

@ -160,14 +160,11 @@ class SingleOpModel {
SingleOpModel() {}
~SingleOpModel();
// Set a function callback that is run right after graph is prepared
// that allows applying external delegates. This is useful for testing
// other runtimes like NN API or GPU.
void SetApplyDelegate(std::function<void(Interpreter*)> apply_delegate_fn) {
apply_delegate_fn_ = apply_delegate_fn;
}
// Set a delegate that is applied right after graph is prepared. This is
// useful for testing other runtimes like NN API or GPU.
void SetDelegate(TfLiteDelegate* delegate) { delegate_ = delegate; }
void ApplyDelegate();
TfLiteStatus ApplyDelegate();
// Copying or assignment is disallowed to simplify ownership semantics.
SingleOpModel(const SingleOpModel&) = delete;
@ -755,9 +752,7 @@ class SingleOpModel {
std::vector<int32_t> outputs_;
std::vector<flatbuffers::Offset<Tensor>> tensors_;
std::vector<flatbuffers::Offset<Buffer>> buffers_;
// A function pointer that gets called after the interpreter is created but
// before evaluation happens. This is useful for applying a delegate.
std::function<void(Interpreter*)> apply_delegate_fn_;
TfLiteDelegate* delegate_ = nullptr;
};
// Populate string tensors.