Extract the KernelTestDelegateProviders into a separate library from test_util, and apply it in lite/testing/tflite_driver.
PiperOrigin-RevId: 323208799 Change-Id: I047f9aa54e32263c0b21aa673bea8cc7de751ba7
This commit is contained in:
parent
53ae4101be
commit
03300ba696
@ -171,7 +171,7 @@ cc_library(
|
||||
deps = [
|
||||
":acceleration_test_util",
|
||||
":builtin_ops",
|
||||
":test_util_delegate_providers",
|
||||
":test_delegate_providers_lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
@ -189,7 +189,6 @@ cc_library(
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"//tensorflow/lite/tools:tool_params",
|
||||
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
||||
"//tensorflow/lite/tools/optimize:quantization_utils",
|
||||
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
||||
"//tensorflow/lite/tools/versioning",
|
||||
@ -198,7 +197,8 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# A convenient library for tflite delegate execution providers.
|
||||
# A convenient library of tflite delegate execution providers for kernel tests
|
||||
# based on SingleOpModel or its derivatives defined in test_util.h/cc.
|
||||
cc_library(
|
||||
name = "test_util_delegate_providers",
|
||||
copts = tflite_copts(),
|
||||
@ -220,13 +220,28 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_delegate_providers_lib",
|
||||
srcs = ["test_delegate_providers.cc"],
|
||||
hdrs = ["test_delegate_providers.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"//tensorflow/lite/tools:tool_params",
|
||||
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion.
|
||||
cc_library(
|
||||
name = "test_main",
|
||||
testonly = 1,
|
||||
srcs = ["test_main.cc"],
|
||||
deps = [
|
||||
":test_delegate_providers_lib",
|
||||
":test_util",
|
||||
":test_util_delegate_providers",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"@com_google_googletest//:gtest",
|
||||
@ -456,6 +471,17 @@ cc_test(
|
||||
name = "test_util_test",
|
||||
size = "small",
|
||||
srcs = ["test_util_test.cc"],
|
||||
deps = [
|
||||
":test_util",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "test_delegate_providers_lib_test",
|
||||
size = "small",
|
||||
srcs = ["test_delegate_providers_test.cc"],
|
||||
# See details in https://github.com/bazelbuild/bazel/issues/11552 to avoid
|
||||
# lazy symbol binding failure on macOS.
|
||||
linkstatic = select({
|
||||
@ -463,9 +489,11 @@ cc_test(
|
||||
"//conditions:default": False,
|
||||
}),
|
||||
deps = [
|
||||
":test_util",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
":test_delegate_providers_lib",
|
||||
"//tensorflow/lite/tools/delegates:default_execution_provider",
|
||||
"//tensorflow/lite/tools/delegates:nnapi_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:xnnpack_delegate_provider",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
|
57
tensorflow/lite/kernels/test_delegate_providers.cc
Normal file
57
tensorflow/lite/kernels/test_delegate_providers.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() {
|
||||
static KernelTestDelegateProviders* const providers =
|
||||
new KernelTestDelegateProviders();
|
||||
return providers;
|
||||
}
|
||||
|
||||
KernelTestDelegateProviders::KernelTestDelegateProviders() {
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
params_.Merge(one->DefaultParams());
|
||||
}
|
||||
}
|
||||
|
||||
bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc,
|
||||
const char** argv) {
|
||||
std::vector<tflite::Flag> flags;
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
auto one_flags = one->CreateFlags(¶ms_);
|
||||
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||
}
|
||||
return tflite::Flags::Parse(argc, argv, flags);
|
||||
}
|
||||
|
||||
std::vector<tools::TfLiteDelegatePtr>
|
||||
KernelTestDelegateProviders::CreateAllDelegates(
|
||||
const tools::ToolParams& params) const {
|
||||
std::vector<tools::TfLiteDelegatePtr> delegates;
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
auto ptr = one->CreateTfLiteDelegate(params);
|
||||
// It's possible that a delegate of certain type won't be created as
|
||||
// user-specified benchmark params tells not to.
|
||||
if (ptr == nullptr) continue;
|
||||
delegates.emplace_back(std::move(ptr));
|
||||
TFLITE_LOG(INFO) << one->GetName() << " delegate is created.";
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
} // namespace tflite
|
71
tensorflow/lite/kernels/test_delegate_providers.h
Normal file
71
tensorflow/lite/kernels/test_delegate_providers.h
Normal file
@ -0,0 +1,71 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_TEST_DELEGATE_PROVIDERS_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_TEST_DELEGATE_PROVIDERS_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/tool_params.h"
|
||||
|
||||
namespace tflite {
|
||||
// A utility class to provide TfLite delegate creations for kernel tests. The
|
||||
// options of a particular delegate could be specified from commandline flags by
|
||||
// using the delegate provider registrar as implemented in lite/tools/delegates
|
||||
// directory.
|
||||
class KernelTestDelegateProviders {
|
||||
public:
|
||||
// Returns a global KernelTestDelegateProviders instance.
|
||||
static KernelTestDelegateProviders* Get();
|
||||
|
||||
KernelTestDelegateProviders();
|
||||
|
||||
// Initialize delegate-related parameters from commandline arguments and
|
||||
// returns true if successful.
|
||||
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
||||
|
||||
// This provides a way to overwrite parameter values programmatically before
|
||||
// creating TfLite delegates. Note, changes to the returned ToolParams will
|
||||
// have a global impact on creating TfLite delegates.
|
||||
// If a local-only change is preferred, recommend using the following workflow
|
||||
// create TfLite delegates via delegate providers:
|
||||
// tools::ToolParams local_params;
|
||||
// local_params.Merge(KernelTestDelegateProviders::Get()->ConstParams());
|
||||
// Overwrite params in local_params by calling local_params.Set<...>(...);
|
||||
// Get TfLite delegates via
|
||||
// KernelTestDelegateProviders::Get()->CreateAllDelegates(local_params);
|
||||
tools::ToolParams* MutableParams() { return ¶ms_; }
|
||||
const tools::ToolParams& ConstParams() const { return params_; }
|
||||
|
||||
// Create a list of TfLite delegates based on the provided parameters
|
||||
// `params`.
|
||||
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates(
|
||||
const tools::ToolParams& params) const;
|
||||
|
||||
// Similar to the above, but creating a list of TfLite delegates based on what
|
||||
// have been initialized (i.e. 'params_').
|
||||
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates() const {
|
||||
return CreateAllDelegates(params_);
|
||||
}
|
||||
|
||||
private:
|
||||
// Contain delegate-related parameters that are initialized from command-line
|
||||
// flags.
|
||||
tools::ToolParams params_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_TEST_DELEGATE_PROVIDERS_H_
|
50
tensorflow/lite/kernels/test_delegate_providers_test.cc
Normal file
50
tensorflow/lite/kernels/test_delegate_providers_test.cc
Normal file
@ -0,0 +1,50 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) {
|
||||
KernelTestDelegateProviders providers;
|
||||
const auto& params = providers.ConstParams();
|
||||
EXPECT_TRUE(params.HasParam("use_xnnpack"));
|
||||
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
||||
|
||||
int argc = 3;
|
||||
const char* argv[] = {"program_name", "--use_nnapi=true",
|
||||
"--other_undefined_flag=1"};
|
||||
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
||||
EXPECT_TRUE(params.Get<bool>("use_nnapi"));
|
||||
EXPECT_EQ(2, argc);
|
||||
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
||||
}
|
||||
|
||||
TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) {
|
||||
#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK)
|
||||
KernelTestDelegateProviders providers;
|
||||
providers.MutableParams()->Set<bool>("use_xnnpack", true);
|
||||
EXPECT_GE(providers.CreateAllDelegates().size(), 1);
|
||||
|
||||
tools::ToolParams local_params;
|
||||
local_params.Merge(providers.ConstParams());
|
||||
local_params.Set<bool>("use_xnnpack", false);
|
||||
EXPECT_TRUE(providers.CreateAllDelegates(local_params).empty());
|
||||
#endif
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tflite
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
@ -40,12 +40,12 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/string_type.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
@ -234,8 +234,12 @@ void SingleOpModel::BuildInterpreter(
|
||||
|
||||
// static
|
||||
bool SingleOpModel::GetForceUseNnapi() {
|
||||
return tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
|
||||
"use_nnapi");
|
||||
const auto& delegate_params =
|
||||
tflite::KernelTestDelegateProviders::Get()->ConstParams();
|
||||
// It's possible this library isn't linked with the nnapi delegate provider
|
||||
// lib.
|
||||
return delegate_params.HasParam("use_nnapi") &&
|
||||
delegate_params.Get<bool>("use_nnapi");
|
||||
}
|
||||
|
||||
int32_t SingleOpModel::GetTensorSize(int index) const {
|
||||
@ -374,41 +378,4 @@ void MultiOpModel::AddCustomOp(
|
||||
builder_.CreateVector<uint8_t>(custom_option),
|
||||
CustomOptionsFormat_FLEXBUFFERS));
|
||||
}
|
||||
|
||||
/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() {
|
||||
static KernelTestDelegateProviders* const providers =
|
||||
new KernelTestDelegateProviders();
|
||||
return providers;
|
||||
}
|
||||
|
||||
KernelTestDelegateProviders::KernelTestDelegateProviders() {
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
params_.Merge(one->DefaultParams());
|
||||
}
|
||||
}
|
||||
|
||||
bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc,
|
||||
const char** argv) {
|
||||
std::vector<tflite::Flag> flags;
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
auto one_flags = one->CreateFlags(¶ms_);
|
||||
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||
}
|
||||
return tflite::Flags::Parse(argc, argv, flags);
|
||||
}
|
||||
|
||||
std::vector<tools::TfLiteDelegatePtr>
|
||||
KernelTestDelegateProviders::CreateAllDelegates(
|
||||
const tools::ToolParams& params) const {
|
||||
std::vector<tools::TfLiteDelegatePtr> delegates;
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
auto ptr = one->CreateTfLiteDelegate(params);
|
||||
// It's possible that a delegate of certain type won't be created as
|
||||
// user-specified benchmark params tells not to.
|
||||
if (ptr == nullptr) continue;
|
||||
delegates.emplace_back(std::move(ptr));
|
||||
TFLITE_LOG(INFO) << one->GetName() << " delegate is created.";
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
} // namespace tflite
|
||||
|
@ -45,10 +45,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/string_type.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
|
||||
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
||||
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
||||
#include "tensorflow/lite/tools/tool_params.h"
|
||||
#include "tensorflow/lite/type_to_tflitetype.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -899,52 +897,6 @@ class MultiOpModel : public SingleOpModel {
|
||||
return AddTensor<T>(t, {}, false);
|
||||
}
|
||||
};
|
||||
|
||||
// A utility class to provide TfLite delegate creations for kernel tests. The
|
||||
// options of a particular delegate could be specified from commandline flags by
|
||||
// using the delegate provider registrar as implemented in lite/tools/delegates
|
||||
// directory.
|
||||
class KernelTestDelegateProviders {
|
||||
public:
|
||||
// Returns a global KernelTestDelegateProviders instance.
|
||||
static KernelTestDelegateProviders* Get();
|
||||
|
||||
KernelTestDelegateProviders();
|
||||
|
||||
// Initialize delegate-related parameters from commandline arguments and
|
||||
// returns true if successful.
|
||||
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
||||
|
||||
// This provides a way to overwrite parameter values programmatically before
|
||||
// creating TfLite delegates. Note, changes to the returned ToolParams will
|
||||
// have a global impact on creating TfLite delegates.
|
||||
// If a local-only change is preferred, recommend using the following workflow
|
||||
// create TfLite delegates via delegate providers:
|
||||
// tools::ToolParams local_params;
|
||||
// local_params.Merge(KernelTestDelegateProviders::Get()->ConstParams());
|
||||
// Overwrite params in local_params by calling local_params.Set<...>(...);
|
||||
// Get TfLite delegates via
|
||||
// KernelTestDelegateProviders::Get()->CreateAllDelegates(local_params);
|
||||
tools::ToolParams* MutableParams() { return ¶ms_; }
|
||||
const tools::ToolParams& ConstParams() const { return params_; }
|
||||
|
||||
// Create a list of TfLite delegates based on the provided parameters
|
||||
// `params`.
|
||||
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates(
|
||||
const tools::ToolParams& params) const;
|
||||
|
||||
// Similar to the above, but creating a list of TfLite delegates based on what
|
||||
// have been initialized (i.e. 'params_').
|
||||
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates() const {
|
||||
return CreateAllDelegates(params_);
|
||||
}
|
||||
|
||||
private:
|
||||
// Contain delegate-related parameters that are initialized from command-line
|
||||
// flags.
|
||||
tools::ToolParams params_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
|
||||
|
@ -47,33 +47,6 @@ TEST(TestUtilTest, QuantizeVectorScalingUp) {
|
||||
EXPECT_THAT(q_data, ElementsAreArray(expected));
|
||||
}
|
||||
|
||||
TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) {
|
||||
KernelTestDelegateProviders providers;
|
||||
const auto& params = providers.ConstParams();
|
||||
EXPECT_TRUE(params.HasParam("use_xnnpack"));
|
||||
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
||||
|
||||
int argc = 3;
|
||||
const char* argv[] = {"program_name", "--use_nnapi=true",
|
||||
"--other_undefined_flag=1"};
|
||||
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
||||
EXPECT_TRUE(params.Get<bool>("use_nnapi"));
|
||||
EXPECT_EQ(2, argc);
|
||||
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
||||
}
|
||||
|
||||
TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) {
|
||||
#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK)
|
||||
KernelTestDelegateProviders providers;
|
||||
providers.MutableParams()->Set<bool>("use_xnnpack", true);
|
||||
EXPECT_GE(providers.CreateAllDelegates().size(), 1);
|
||||
|
||||
tools::ToolParams local_params;
|
||||
local_params.Merge(providers.ConstParams());
|
||||
local_params.Set<bool>("use_xnnpack", false);
|
||||
EXPECT_TRUE(providers.CreateAllDelegates(local_params).empty());
|
||||
#endif
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -59,12 +59,14 @@ exports_files([
|
||||
deps = [
|
||||
":parse_testdata_lib",
|
||||
":tflite_driver",
|
||||
":tflite_driver_delegate_providers",
|
||||
":util",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
"//tensorflow/lite:builtin_op_data",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/kernels:test_delegate_providers_lib",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -229,8 +231,9 @@ cc_library(
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/kernels:custom_ops",
|
||||
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
|
||||
"//tensorflow/lite/kernels:reference_ops",
|
||||
"//tensorflow/lite/kernels:test_delegate_providers_lib",
|
||||
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
] + select({
|
||||
"//tensorflow:ios": [],
|
||||
@ -238,6 +241,22 @@ cc_library(
|
||||
}),
|
||||
)
|
||||
|
||||
# A convenient library of tflite delegate execution providers for tests based
|
||||
# on the `tflite_driver` library.
|
||||
cc_library(
|
||||
name = "tflite_driver_delegate_providers",
|
||||
deps = [
|
||||
"//tensorflow/lite/tools/delegates:coreml_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:default_execution_provider",
|
||||
"//tensorflow/lite/tools/delegates:external_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:gpu_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:hexagon_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:nnapi_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:xnnpack_delegate_provider",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tflite_driver_test",
|
||||
size = "small",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/subprocess.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
#include "tensorflow/lite/testing/parse_testdata.h"
|
||||
#include "tensorflow/lite/testing/tflite_driver.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
@ -47,7 +48,6 @@ string* FLAGS_tar_binary_path = new string("/bin/tar");
|
||||
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
|
||||
string* FLAGS_tar_binary_path = new string("/system/bin/tar");
|
||||
#endif
|
||||
bool FLAGS_use_nnapi = false;
|
||||
bool FLAGS_ignore_unsupported_nnapi = false;
|
||||
} // namespace
|
||||
|
||||
@ -298,9 +298,10 @@ TEST_P(OpsTest, RunZipTests) {
|
||||
|
||||
std::ifstream tflite_stream(tflite_test_case);
|
||||
ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
|
||||
tflite::testing::TfLiteDriver test_driver(
|
||||
FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi
|
||||
: TfLiteDriver::DelegateType::kNone);
|
||||
tflite::testing::TfLiteDriver test_driver;
|
||||
const bool use_nnapi =
|
||||
tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
|
||||
"use_nnapi");
|
||||
|
||||
auto quantized_tests_error = GetQuantizeTestsError();
|
||||
bool fully_quantize = false;
|
||||
@ -317,7 +318,7 @@ TEST_P(OpsTest, RunZipTests) {
|
||||
test_driver.SetModelBaseDir(tflite_dir);
|
||||
|
||||
auto broken_tests = GetKnownBrokenTests();
|
||||
if (FLAGS_use_nnapi) {
|
||||
if (use_nnapi) {
|
||||
auto kBrokenNnapiTests = GetKnownBrokenNnapiTests();
|
||||
broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end());
|
||||
}
|
||||
@ -334,7 +335,7 @@ TEST_P(OpsTest, RunZipTests) {
|
||||
}
|
||||
}
|
||||
if (bug_number.empty()) {
|
||||
if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
|
||||
if (use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
|
||||
EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
|
||||
} else {
|
||||
EXPECT_TRUE(result) << message;
|
||||
@ -408,8 +409,6 @@ int main(int argc, char** argv) {
|
||||
tensorflow::Flag("tar_binary_path",
|
||||
tflite::testing::FLAGS_tar_binary_path,
|
||||
"Location of a suitable tar binary."),
|
||||
tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
|
||||
"Whether to enable the NNAPI delegate"),
|
||||
tensorflow::Flag("ignore_unsupported_nnapi",
|
||||
&tflite::testing::FLAGS_ignore_unsupported_nnapi,
|
||||
"Don't fail tests just because delegation to NNAPI "
|
||||
@ -417,7 +416,12 @@ int main(int argc, char** argv) {
|
||||
bool success = tensorflow::Flags::Parse(&argc, argv, flags);
|
||||
if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
|
||||
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
|
||||
return 1;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
if (!tflite::testing::TfLiteDriver::InitTestDelegateProviders(
|
||||
&argc, const_cast<const char**>(argv))) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
::tflite::LogToStderr();
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/register_ref.h"
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/testing/join.h"
|
||||
#include "tensorflow/lite/testing/split.h"
|
||||
@ -346,6 +347,12 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
|
||||
}
|
||||
}
|
||||
|
||||
/* static */
|
||||
bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
|
||||
return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
|
||||
argv);
|
||||
}
|
||||
|
||||
TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
||||
: delegate_(nullptr, nullptr),
|
||||
relative_threshold_(kRelativeThreshold),
|
||||
@ -414,6 +421,16 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
|
||||
Invalidate("Unable to the build graph using the delegate");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
|
||||
for (auto& one : delegate_providers->CreateAllDelegates()) {
|
||||
if (interpreter_->ModifyGraphWithDelegate(std::move(one)) != kTfLiteOk) {
|
||||
Invalidate(
|
||||
"Unable to the build graph using the delegate initialized from "
|
||||
"tflite::KernelTestDelegateProviders");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
must_allocate_tensors_ = true;
|
||||
|
@ -40,10 +40,15 @@ class TfLiteDriver : public TestRunner {
|
||||
kFlex,
|
||||
};
|
||||
|
||||
// Initialize the global test delegate providers from commandline arguments
|
||||
// and returns true if successful.
|
||||
static bool InitTestDelegateProviders(int* argc, const char** argv);
|
||||
|
||||
/**
|
||||
* Creates a new TfLiteDriver
|
||||
* @param delegate The (optional) delegate to use.
|
||||
* @param reference_kernel Whether to use the builtin reference kernel ops.
|
||||
* @param reference_kernel Whether to use the builtin reference kernel
|
||||
* ops.
|
||||
*/
|
||||
explicit TfLiteDriver(DelegateType delegate_type = DelegateType::kNone,
|
||||
bool reference_kernel = false);
|
||||
|
Loading…
Reference in New Issue
Block a user