Refactor TfLiteDriver delegate support
Unify delegate/NNAPI arguments, and disable use of NNAPI by default in the tflite_diff tool. Also add support for testing the GPU delegate on Android.

PiperOrigin-RevId: 260576456
parent 366ddc8948
commit d5e4b9b00e
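A minimal usage sketch of the unified constructor described above, assuming only the TfLiteDriver::DelegateType enum, default arguments, and include path introduced in the diff below; the main() wrapper is purely illustrative.

#include "tensorflow/lite/testing/tflite_driver.h"

int main() {
  // Default construction: no delegate is applied (NNAPI is no longer implied).
  tflite::testing::TfLiteDriver cpu_driver;

  // Explicitly request a delegate through the new DelegateType argument.
  tflite::testing::TfLiteDriver gpu_driver(
      tflite::testing::TfLiteDriver::DelegateType::kGpu);

  // Reference kernels, still with no delegate.
  tflite::testing::TfLiteDriver reference_driver(
      tflite::testing::TfLiteDriver::DelegateType::kNone,
      /*reference_kernel=*/true);
  return 0;
}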
@@ -108,7 +108,7 @@ TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank1Test) {
       "speech_hotword_model_out_rank1.csv", /*input_tensor=*/"0",
       /*output_tensor=*/"18", /*persistent_tensors=*/"4",
       /*sequence_size=*/40, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -120,7 +120,7 @@ TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) {
       "speech_hotword_model_out_rank2.csv", /*input_tensor=*/"17",
       /*output_tensor=*/"18", /*persistent_tensors=*/"1",
       /*sequence_size=*/40, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -133,7 +133,7 @@ TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) {
       /*output_tensor=*/"63",
       /*persistent_tensors=*/"18,19,38,39,58,59",
       /*sequence_size=*/80, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -146,7 +146,7 @@ TEST_P(SpeechTest, AsrAmTest) {
       /*output_tensor=*/"104",
       /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
       /*sequence_size=*/320, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -159,7 +159,7 @@ TEST_P(SpeechTest, AsrAmQuantizedTest) {
       /*output_tensor=*/"104",
       /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
       /*sequence_size=*/320, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -170,7 +170,7 @@ TEST_P(SpeechTest, AsrAmQuantizedTest) {
 // results.
 TEST_P(SpeechTest, DISABLED_AsrLmTest) {
   std::ifstream in_file;
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
   ASSERT_TRUE(
       testing::ParseAndRunTests(&in_file, &test_driver, GetMaxInvocations()))
@@ -185,7 +185,7 @@ TEST_P(SpeechTest, DISABLED_EndpointerTest) {
       /*output_tensor=*/"56",
       /*persistent_tensors=*/"27,28,47,48",
       /*sequence_size=*/320, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -198,7 +198,7 @@ TEST_P(SpeechTest, DISABLED_TtsTest) {
       /*output_tensor=*/"71",
       /*persistent_tensors=*/"24,25,44,45,64,65,70",
       /*sequence_size=*/334, &os));
-  testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+  testing::TfLiteDriver test_driver;
   ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
       << test_driver.GetErrorMessage();
 }
@@ -190,6 +190,7 @@ cc_library(
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/kernels:custom_ops",
         "//tensorflow/lite/kernels:reference_ops",
+        "//tensorflow/lite/tools/evaluation:utils",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -368,6 +369,7 @@ cc_library(
     deps = [
         ":split",
         ":tflite_diff_util",
+        ":tflite_driver",
     ] + select({
         "//conditions:default": [
             "//tensorflow/core:framework_internal",
@@ -262,7 +262,9 @@ TEST_P(OpsTest, RunZipTests) {

   std::ifstream tflite_stream(tflite_test_case);
   ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
-  tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi);
+  tflite::testing::TfLiteDriver test_driver(
+      FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi
+                      : TfLiteDriver::DelegateType::kNone);

   if (test_path.find("fully_quantize=True") != std::string::npos) {
     // TODO(b/134594898): Tighten this constraint.
@@ -19,10 +19,13 @@ int main(int argc, char** argv) {
   tflite::testing::kernel_test::TestOptions options =
       tflite::testing::kernel_test::ParseTfliteKernelTestFlags(&argc, argv);
   const bool run_reference_kernel = options.kernel_type == "REFERENCE";
-  const bool use_nnapi = options.kernel_type == "NNAPI";
+  const tflite::testing::TfLiteDriver::DelegateType delegate_type =
+      options.kernel_type == "NNAPI"
+          ? tflite::testing::TfLiteDriver::DelegateType::kNnapi
+          : tflite::testing::TfLiteDriver::DelegateType::kNone;

   auto runner = absl::make_unique<tflite::testing::TfLiteDriver>(
-      use_nnapi, "", run_reference_kernel);
+      delegate_type, run_reference_kernel);
   if (tflite::testing::kernel_test::RunKernelTest(options, runner.get()) ==
       kTfLiteOk) {
     return 0;
@@ -34,7 +34,8 @@ TEST(UtilTest, SimpleE2ETest) {
       "tensorflow/lite/testdata/test_input.csv";
   options.dump_output_to_file = FLAGS_test_tmpdir + "/test_out.csv";
   options.kernel_type = "REFERENCE";
-  std::unique_ptr<TestRunner> runner(new TfLiteDriver(false, "", true));
+  std::unique_ptr<TestRunner> runner(new TfLiteDriver(
+      TfLiteDriver::DelegateType::kNone, /*reference_kernel=*/true));
   RunKernelTest(options, runner.get());
   std::string expected = "3";
   for (int i = 0; i < 1 * 8 * 8 * 3 - 1; i++) {
@@ -42,7 +42,9 @@ bool Interpret(const char* examples_filename, bool use_nnapi) {
   }

   printf("Use nnapi is set to: %d\n", use_nnapi);
-  tflite::testing::TfLiteDriver test_driver(use_nnapi);
+  tflite::testing::TfLiteDriver test_driver(
+      use_nnapi ? tflite::testing::TfLiteDriver::DelegateType::kNnapi
+                : tflite::testing::TfLiteDriver::DelegateType::kNone);

   test_driver.SetModelBaseDir(dirname(examples_filename));
   if (!tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver)) {
@@ -17,9 +17,10 @@ limitations under the License.

 #include <cstring>

+#include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/testing/split.h"
 #include "tensorflow/lite/testing/tflite_diff_util.h"
-#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/lite/testing/tflite_driver.h"

 namespace tflite {
 namespace testing {
@@ -33,9 +34,10 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
     string input_layer_shape;
     string output_layer;
     int32_t num_runs_per_pass = 100;
-    string delegate;
+    string delegate_name;
   } values;

+  std::string delegate_name;
   std::vector<tensorflow::Flag> flags = {
       tensorflow::Flag("tensorflow_model", &values.tensorflow_model,
                        "Path of tensorflow model."),
@@ -55,9 +57,9 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
                        "output_1,output_2."),
       tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
                        "[optional] Number of full runs in each pass."),
-      tensorflow::Flag("delegate", &values.delegate,
+      tensorflow::Flag("delegate", &values.delegate_name,
                        "[optional] Delegate to use for executing ops. Must be "
-                       "`{\"\", FLEX}`"),
+                       "`{\"\", NNAPI, GPU, FLEX}`"),
   };

   bool no_inputs = *argc == 1;
@@ -70,9 +72,20 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
       values.input_layer_shape.empty() || values.output_layer.empty()) {
     fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
     return {};
-  } else if (!(values.delegate == "" || values.delegate == "FLEX")) {
-    fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
-    return {};
+  }
+
+  TfLiteDriver::DelegateType delegate = TfLiteDriver::DelegateType::kNone;
+  if (!values.delegate_name.empty()) {
+    if (delegate_name == "NNAPI") {
+      delegate = TfLiteDriver::DelegateType::kNnapi;
+    } else if (values.delegate_name == "GPU") {
+      delegate = TfLiteDriver::DelegateType::kGpu;
+    } else if (values.delegate_name == "FLEX") {
+      delegate = TfLiteDriver::DelegateType::kFlex;
+    } else {
+      fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+      return {};
+    }
+  }
   }

   return {values.tensorflow_model,
@@ -82,7 +95,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
           Split<string>(values.input_layer_shape, ":"),
           Split<string>(values.output_layer, ","),
           values.num_runs_per_pass,
-          values.delegate};
+          delegate};
 }

 }  // namespace testing
@@ -33,7 +33,7 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
                          options.input_layer_shape, options.output_layer)) {
     return false;
   }
-  TfLiteDriver tflite_driver(/*use_nnapi=*/true, options.delegate);
+  TfLiteDriver tflite_driver(options.delegate);
   tflite_driver.LoadModel(options.tflite_model);
   return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
 }
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>

 #include "tensorflow/lite/string.h"
+#include "tensorflow/lite/testing/tflite_driver.h"

 namespace tflite {
 namespace testing {
@@ -44,9 +45,8 @@ struct DiffOptions {
   // each of the passes. The first pass has a single inference, while the
   // second pass does multiple inferences back to back.
   int num_runs_per_pass;
-  // Path to the delegate library to be loaded in order to execute ops. Must be
-  // `{"", FLEX}`.
-  string delegate;
+  // The type of delegate to apply during inference.
+  TfLiteDriver::DelegateType delegate;
 };

 // Run a single TensorFLow Lite diff test with a given options.
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/string_util.h"
 #include "tensorflow/lite/testing/join.h"
 #include "tensorflow/lite/testing/split.h"
+#include "tensorflow/lite/tools/evaluation/utils.h"

 namespace tflite {
 namespace testing {
@@ -259,9 +260,8 @@ bool TfLiteDriver::Expectation::Check(bool verbose,
   }
 }

-TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name,
-                           bool reference_kernel)
-    : use_nnapi_(use_nnapi),
+TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
+    : delegate_(nullptr, nullptr),
       relative_threshold_(kRelativeThreshold),
       absolute_threshold_(kAbsoluteThreshold) {
   if (reference_kernel) {
@@ -274,8 +274,21 @@ TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name,
                          tflite::ops::custom::Register_RFFT2D());
   }

-  if (delegate_name == "FLEX") {
-    delegate_ = FlexDelegate::Create();
+  switch (delegate_type) {
+    case DelegateType::kNone:
+      break;
+    case DelegateType::kNnapi:
+      delegate_ = evaluation::CreateNNAPIDelegate();
+      break;
+    case DelegateType::kGpu:
+      delegate_ = evaluation::CreateGPUDelegate(/*model=*/nullptr);
+      break;
+    case DelegateType::kFlex:
+      delegate_ = Interpreter::TfLiteDelegatePtr(
+          FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
+            delete static_cast<tflite::FlexDelegate*>(delegate);
+          });
+      break;
   }
 }

@@ -310,8 +323,6 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
     Invalidate("Failed build interpreter");
     return;
   }
-  interpreter_->UseNNAPI(use_nnapi_);
-
   if (delegate_) {
     if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
       Invalidate("Unable to the build graph using the delegate");
@@ -31,7 +31,19 @@ namespace testing {
 // A test runner that feeds inputs into TF Lite and verifies its outputs.
 class TfLiteDriver : public TestRunner {
  public:
-  explicit TfLiteDriver(bool use_nnapi, const string& delegate = "",
+  enum class DelegateType {
+    kNone,
+    kNnapi,
+    kGpu,
+    kFlex,
+  };
+
+  /**
+   * Creates a new TfLiteDriver
+   * @param delegate The (optional) delegate to use.
+   * @param reference_kernel Whether to use the builtin reference kernel ops.
+   */
+  explicit TfLiteDriver(DelegateType delegate_type = DelegateType::kNone,
                         bool reference_kernel = false);
   ~TfLiteDriver() override;

@@ -71,8 +83,7 @@ class TfLiteDriver : public TestRunner {
   class Expectation;

   std::unique_ptr<OpResolver> resolver_;
-  std::unique_ptr<FlexDelegate> delegate_;
-  bool use_nnapi_ = false;
+  Interpreter::TfLiteDelegatePtr delegate_;
   std::unique_ptr<FlatBufferModel> model_;
   std::unique_ptr<Interpreter> interpreter_;
   std::map<int, std::unique_ptr<Expectation>> expected_output_;
@@ -24,7 +24,7 @@ namespace {
 using ::testing::ElementsAre;

 TEST(TfliteDriverTest, SimpleTest) {
-  std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+  std::unique_ptr<TestRunner> runner(new TfLiteDriver());

   runner->SetModelBaseDir("tensorflow/lite");
   runner->LoadModel("testdata/multi_add.bin");
@@ -60,7 +60,8 @@ TEST(TfliteDriverTest, SimpleTest) {

 TEST(TfliteDriverTest, SingleAddOpTest) {
   std::unique_ptr<TestRunner> runner(new TfLiteDriver(
-      /*use_nnapi*/ false, /*delegate*/ "", /*reference_kernel*/ true));
+      /*delegate_type=*/TfLiteDriver::DelegateType::kNone,
+      /*reference_kernel=*/true));

   runner->SetModelBaseDir("tensorflow/lite");
   runner->LoadModel("testdata/multi_add.bin");
@@ -95,7 +96,7 @@ TEST(TfliteDriverTest, SingleAddOpTest) {
 }

 TEST(TfliteDriverTest, AddQuantizedInt8Test) {
-  std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+  std::unique_ptr<TestRunner> runner(new TfLiteDriver());

   runner->SetModelBaseDir("tensorflow/lite");
   runner->LoadModel("testdata/add_quantized_int8.bin");
@@ -73,6 +73,7 @@ TfLiteStatus GetSortedFileNames(const std::string& directory,
   return kTfLiteOk;
 }

+// TODO(b/138448769): Migrate delegate helper APIs to lite/testing.
 Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
 #if defined(__ANDROID__)
   return Interpreter::TfLiteDelegatePtr(
@@ -108,7 +109,8 @@ Interpreter::TfLiteDelegatePtr CreateGPUDelegate(
     tflite::FlatBufferModel* model) {
 #if defined(__ANDROID__)
   TfLiteGpuDelegateOptions options;
-  options.metadata = TfLiteGpuDelegateGetModelMetadata(model->GetModel());
+  options.metadata =
+      model ? TfLiteGpuDelegateGetModelMetadata(model->GetModel()) : nullptr;
   options.compile_options.precision_loss_allowed = 1;
   options.compile_options.preferred_gl_object_type =
       TFLITE_GL_OBJECT_TYPE_FASTEST;