Extract the KernelTestDelegateProviders into a separate library from test_util, and apply it in lite/testing/tflite_driver.
PiperOrigin-RevId: 323208799 Change-Id: I047f9aa54e32263c0b21aa673bea8cc7de751ba7
This commit is contained in:
parent
53ae4101be
commit
03300ba696
@ -171,7 +171,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":acceleration_test_util",
|
":acceleration_test_util",
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
":test_util_delegate_providers",
|
":test_delegate_providers_lib",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:schema_fbs_version",
|
"//tensorflow/lite:schema_fbs_version",
|
||||||
@ -189,7 +189,6 @@ cc_library(
|
|||||||
"//tensorflow/lite/tools:command_line_flags",
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
"//tensorflow/lite/tools:logging",
|
"//tensorflow/lite/tools:logging",
|
||||||
"//tensorflow/lite/tools:tool_params",
|
"//tensorflow/lite/tools:tool_params",
|
||||||
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
|
||||||
"//tensorflow/lite/tools/optimize:quantization_utils",
|
"//tensorflow/lite/tools/optimize:quantization_utils",
|
||||||
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
||||||
"//tensorflow/lite/tools/versioning",
|
"//tensorflow/lite/tools/versioning",
|
||||||
@ -198,7 +197,8 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# A convenient library for tflite delegate execution providers.
|
# A convenient library of tflite delegate execution providers for kernel tests
|
||||||
|
# based on SingleOpModel or its derivatives defined in test_util.h/cc.
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "test_util_delegate_providers",
|
name = "test_util_delegate_providers",
|
||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
@ -220,13 +220,28 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "test_delegate_providers_lib",
|
||||||
|
srcs = ["test_delegate_providers.cc"],
|
||||||
|
hdrs = ["test_delegate_providers.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
|
"//tensorflow/lite/tools:logging",
|
||||||
|
"//tensorflow/lite/tools:tool_params",
|
||||||
|
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion.
|
# TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion.
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "test_main",
|
name = "test_main",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
srcs = ["test_main.cc"],
|
srcs = ["test_main.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":test_delegate_providers_lib",
|
||||||
":test_util",
|
":test_util",
|
||||||
|
":test_util_delegate_providers",
|
||||||
"//tensorflow/lite/testing:util",
|
"//tensorflow/lite/testing:util",
|
||||||
"//tensorflow/lite/tools:command_line_flags",
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
@ -456,6 +471,17 @@ cc_test(
|
|||||||
name = "test_util_test",
|
name = "test_util_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["test_util_test.cc"],
|
srcs = ["test_util_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":test_util",
|
||||||
|
"//tensorflow/lite/testing:util",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "test_delegate_providers_lib_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["test_delegate_providers_test.cc"],
|
||||||
# See details in https://github.com/bazelbuild/bazel/issues/11552 to avoid
|
# See details in https://github.com/bazelbuild/bazel/issues/11552 to avoid
|
||||||
# lazy symbol binding failure on macOS.
|
# lazy symbol binding failure on macOS.
|
||||||
linkstatic = select({
|
linkstatic = select({
|
||||||
@ -463,9 +489,11 @@ cc_test(
|
|||||||
"//conditions:default": False,
|
"//conditions:default": False,
|
||||||
}),
|
}),
|
||||||
deps = [
|
deps = [
|
||||||
":test_util",
|
":test_delegate_providers_lib",
|
||||||
"//tensorflow/lite/testing:util",
|
"//tensorflow/lite/tools/delegates:default_execution_provider",
|
||||||
"@com_google_googletest//:gtest",
|
"//tensorflow/lite/tools/delegates:nnapi_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:xnnpack_delegate_provider",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
57
tensorflow/lite/kernels/test_delegate_providers.cc
Normal file
57
tensorflow/lite/kernels/test_delegate_providers.cc
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||||
|
#include "tensorflow/lite/tools/logging.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() {
|
||||||
|
static KernelTestDelegateProviders* const providers =
|
||||||
|
new KernelTestDelegateProviders();
|
||||||
|
return providers;
|
||||||
|
}
|
||||||
|
|
||||||
|
KernelTestDelegateProviders::KernelTestDelegateProviders() {
|
||||||
|
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||||
|
params_.Merge(one->DefaultParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc,
|
||||||
|
const char** argv) {
|
||||||
|
std::vector<tflite::Flag> flags;
|
||||||
|
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||||
|
auto one_flags = one->CreateFlags(¶ms_);
|
||||||
|
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||||
|
}
|
||||||
|
return tflite::Flags::Parse(argc, argv, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<tools::TfLiteDelegatePtr>
|
||||||
|
KernelTestDelegateProviders::CreateAllDelegates(
|
||||||
|
const tools::ToolParams& params) const {
|
||||||
|
std::vector<tools::TfLiteDelegatePtr> delegates;
|
||||||
|
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||||
|
auto ptr = one->CreateTfLiteDelegate(params);
|
||||||
|
// It's possible that a delegate of certain type won't be created as
|
||||||
|
// user-specified benchmark params tells not to.
|
||||||
|
if (ptr == nullptr) continue;
|
||||||
|
delegates.emplace_back(std::move(ptr));
|
||||||
|
TFLITE_LOG(INFO) << one->GetName() << " delegate is created.";
|
||||||
|
}
|
||||||
|
return delegates;
|
||||||
|
}
|
||||||
|
} // namespace tflite
|
71
tensorflow/lite/kernels/test_delegate_providers.h
Normal file
71
tensorflow/lite/kernels/test_delegate_providers.h
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_KERNELS_TEST_DELEGATE_PROVIDERS_H_
|
||||||
|
#define TENSORFLOW_LITE_KERNELS_TEST_DELEGATE_PROVIDERS_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
|
||||||
|
#include "tensorflow/lite/tools/tool_params.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
// A utility class to provide TfLite delegate creations for kernel tests. The
|
||||||
|
// options of a particular delegate could be specified from commandline flags by
|
||||||
|
// using the delegate provider registrar as implemented in lite/tools/delegates
|
||||||
|
// directory.
|
||||||
|
class KernelTestDelegateProviders {
|
||||||
|
public:
|
||||||
|
// Returns a global KernelTestDelegateProviders instance.
|
||||||
|
static KernelTestDelegateProviders* Get();
|
||||||
|
|
||||||
|
KernelTestDelegateProviders();
|
||||||
|
|
||||||
|
// Initialize delegate-related parameters from commandline arguments and
|
||||||
|
// returns true if successful.
|
||||||
|
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
||||||
|
|
||||||
|
// This provides a way to overwrite parameter values programmatically before
|
||||||
|
// creating TfLite delegates. Note, changes to the returned ToolParams will
|
||||||
|
// have a global impact on creating TfLite delegates.
|
||||||
|
// If a local-only change is preferred, recommend using the following workflow
|
||||||
|
// create TfLite delegates via delegate providers:
|
||||||
|
// tools::ToolParams local_params;
|
||||||
|
// local_params.Merge(KernelTestDelegateProviders::Get()->ConstParams());
|
||||||
|
// Overwrite params in local_params by calling local_params.Set<...>(...);
|
||||||
|
// Get TfLite delegates via
|
||||||
|
// KernelTestDelegateProviders::Get()->CreateAllDelegates(local_params);
|
||||||
|
tools::ToolParams* MutableParams() { return ¶ms_; }
|
||||||
|
const tools::ToolParams& ConstParams() const { return params_; }
|
||||||
|
|
||||||
|
// Create a list of TfLite delegates based on the provided parameters
|
||||||
|
// `params`.
|
||||||
|
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates(
|
||||||
|
const tools::ToolParams& params) const;
|
||||||
|
|
||||||
|
// Similar to the above, but creating a list of TfLite delegates based on what
|
||||||
|
// have been initialized (i.e. 'params_').
|
||||||
|
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates() const {
|
||||||
|
return CreateAllDelegates(params_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Contain delegate-related parameters that are initialized from command-line
|
||||||
|
// flags.
|
||||||
|
tools::ToolParams params_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_KERNELS_TEST_DELEGATE_PROVIDERS_H_
|
50
tensorflow/lite/kernels/test_delegate_providers_test.cc
Normal file
50
tensorflow/lite/kernels/test_delegate_providers_test.cc
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) {
|
||||||
|
KernelTestDelegateProviders providers;
|
||||||
|
const auto& params = providers.ConstParams();
|
||||||
|
EXPECT_TRUE(params.HasParam("use_xnnpack"));
|
||||||
|
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
||||||
|
|
||||||
|
int argc = 3;
|
||||||
|
const char* argv[] = {"program_name", "--use_nnapi=true",
|
||||||
|
"--other_undefined_flag=1"};
|
||||||
|
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
||||||
|
EXPECT_TRUE(params.Get<bool>("use_nnapi"));
|
||||||
|
EXPECT_EQ(2, argc);
|
||||||
|
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) {
|
||||||
|
#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK)
|
||||||
|
KernelTestDelegateProviders providers;
|
||||||
|
providers.MutableParams()->Set<bool>("use_xnnpack", true);
|
||||||
|
EXPECT_GE(providers.CreateAllDelegates().size(), 1);
|
||||||
|
|
||||||
|
tools::ToolParams local_params;
|
||||||
|
local_params.Merge(providers.ConstParams());
|
||||||
|
local_params.Set<bool>("use_xnnpack", false);
|
||||||
|
EXPECT_TRUE(providers.CreateAllDelegates(local_params).empty());
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||||
#include "tensorflow/lite/kernels/test_util.h"
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
#include "tensorflow/lite/testing/util.h"
|
#include "tensorflow/lite/testing/util.h"
|
||||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||||
|
@ -40,12 +40,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/string_type.h"
|
#include "tensorflow/lite/string_type.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
|
||||||
#include "tensorflow/lite/tools/logging.h"
|
#include "tensorflow/lite/tools/logging.h"
|
||||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||||
#include "tensorflow/lite/version.h"
|
#include "tensorflow/lite/version.h"
|
||||||
@ -234,8 +234,12 @@ void SingleOpModel::BuildInterpreter(
|
|||||||
|
|
||||||
// static
|
// static
|
||||||
bool SingleOpModel::GetForceUseNnapi() {
|
bool SingleOpModel::GetForceUseNnapi() {
|
||||||
return tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
|
const auto& delegate_params =
|
||||||
"use_nnapi");
|
tflite::KernelTestDelegateProviders::Get()->ConstParams();
|
||||||
|
// It's possible this library isn't linked with the nnapi delegate provider
|
||||||
|
// lib.
|
||||||
|
return delegate_params.HasParam("use_nnapi") &&
|
||||||
|
delegate_params.Get<bool>("use_nnapi");
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t SingleOpModel::GetTensorSize(int index) const {
|
int32_t SingleOpModel::GetTensorSize(int index) const {
|
||||||
@ -374,41 +378,4 @@ void MultiOpModel::AddCustomOp(
|
|||||||
builder_.CreateVector<uint8_t>(custom_option),
|
builder_.CreateVector<uint8_t>(custom_option),
|
||||||
CustomOptionsFormat_FLEXBUFFERS));
|
CustomOptionsFormat_FLEXBUFFERS));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() {
|
|
||||||
static KernelTestDelegateProviders* const providers =
|
|
||||||
new KernelTestDelegateProviders();
|
|
||||||
return providers;
|
|
||||||
}
|
|
||||||
|
|
||||||
KernelTestDelegateProviders::KernelTestDelegateProviders() {
|
|
||||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
|
||||||
params_.Merge(one->DefaultParams());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc,
|
|
||||||
const char** argv) {
|
|
||||||
std::vector<tflite::Flag> flags;
|
|
||||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
|
||||||
auto one_flags = one->CreateFlags(¶ms_);
|
|
||||||
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
|
||||||
}
|
|
||||||
return tflite::Flags::Parse(argc, argv, flags);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<tools::TfLiteDelegatePtr>
|
|
||||||
KernelTestDelegateProviders::CreateAllDelegates(
|
|
||||||
const tools::ToolParams& params) const {
|
|
||||||
std::vector<tools::TfLiteDelegatePtr> delegates;
|
|
||||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
|
||||||
auto ptr = one->CreateTfLiteDelegate(params);
|
|
||||||
// It's possible that a delegate of certain type won't be created as
|
|
||||||
// user-specified benchmark params tells not to.
|
|
||||||
if (ptr == nullptr) continue;
|
|
||||||
delegates.emplace_back(std::move(ptr));
|
|
||||||
TFLITE_LOG(INFO) << one->GetName() << " delegate is created.";
|
|
||||||
}
|
|
||||||
return delegates;
|
|
||||||
}
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -45,10 +45,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/string_type.h"
|
#include "tensorflow/lite/string_type.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
|
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
|
||||||
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
|
|
||||||
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
||||||
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
||||||
#include "tensorflow/lite/tools/tool_params.h"
|
|
||||||
#include "tensorflow/lite/type_to_tflitetype.h"
|
#include "tensorflow/lite/type_to_tflitetype.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -899,52 +897,6 @@ class MultiOpModel : public SingleOpModel {
|
|||||||
return AddTensor<T>(t, {}, false);
|
return AddTensor<T>(t, {}, false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// A utility class to provide TfLite delegate creations for kernel tests. The
|
|
||||||
// options of a particular delegate could be specified from commandline flags by
|
|
||||||
// using the delegate provider registrar as implemented in lite/tools/delegates
|
|
||||||
// directory.
|
|
||||||
class KernelTestDelegateProviders {
|
|
||||||
public:
|
|
||||||
// Returns a global KernelTestDelegateProviders instance.
|
|
||||||
static KernelTestDelegateProviders* Get();
|
|
||||||
|
|
||||||
KernelTestDelegateProviders();
|
|
||||||
|
|
||||||
// Initialize delegate-related parameters from commandline arguments and
|
|
||||||
// returns true if successful.
|
|
||||||
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
|
||||||
|
|
||||||
// This provides a way to overwrite parameter values programmatically before
|
|
||||||
// creating TfLite delegates. Note, changes to the returned ToolParams will
|
|
||||||
// have a global impact on creating TfLite delegates.
|
|
||||||
// If a local-only change is preferred, recommend using the following workflow
|
|
||||||
// create TfLite delegates via delegate providers:
|
|
||||||
// tools::ToolParams local_params;
|
|
||||||
// local_params.Merge(KernelTestDelegateProviders::Get()->ConstParams());
|
|
||||||
// Overwrite params in local_params by calling local_params.Set<...>(...);
|
|
||||||
// Get TfLite delegates via
|
|
||||||
// KernelTestDelegateProviders::Get()->CreateAllDelegates(local_params);
|
|
||||||
tools::ToolParams* MutableParams() { return ¶ms_; }
|
|
||||||
const tools::ToolParams& ConstParams() const { return params_; }
|
|
||||||
|
|
||||||
// Create a list of TfLite delegates based on the provided parameters
|
|
||||||
// `params`.
|
|
||||||
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates(
|
|
||||||
const tools::ToolParams& params) const;
|
|
||||||
|
|
||||||
// Similar to the above, but creating a list of TfLite delegates based on what
|
|
||||||
// have been initialized (i.e. 'params_').
|
|
||||||
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates() const {
|
|
||||||
return CreateAllDelegates(params_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Contain delegate-related parameters that are initialized from command-line
|
|
||||||
// flags.
|
|
||||||
tools::ToolParams params_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
|
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
|
||||||
|
@ -47,33 +47,6 @@ TEST(TestUtilTest, QuantizeVectorScalingUp) {
|
|||||||
EXPECT_THAT(q_data, ElementsAreArray(expected));
|
EXPECT_THAT(q_data, ElementsAreArray(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) {
|
|
||||||
KernelTestDelegateProviders providers;
|
|
||||||
const auto& params = providers.ConstParams();
|
|
||||||
EXPECT_TRUE(params.HasParam("use_xnnpack"));
|
|
||||||
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
|
||||||
|
|
||||||
int argc = 3;
|
|
||||||
const char* argv[] = {"program_name", "--use_nnapi=true",
|
|
||||||
"--other_undefined_flag=1"};
|
|
||||||
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
|
||||||
EXPECT_TRUE(params.Get<bool>("use_nnapi"));
|
|
||||||
EXPECT_EQ(2, argc);
|
|
||||||
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) {
|
|
||||||
#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK)
|
|
||||||
KernelTestDelegateProviders providers;
|
|
||||||
providers.MutableParams()->Set<bool>("use_xnnpack", true);
|
|
||||||
EXPECT_GE(providers.CreateAllDelegates().size(), 1);
|
|
||||||
|
|
||||||
tools::ToolParams local_params;
|
|
||||||
local_params.Merge(providers.ConstParams());
|
|
||||||
local_params.Set<bool>("use_xnnpack", false);
|
|
||||||
EXPECT_TRUE(providers.CreateAllDelegates(local_params).empty());
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -59,12 +59,14 @@ exports_files([
|
|||||||
deps = [
|
deps = [
|
||||||
":parse_testdata_lib",
|
":parse_testdata_lib",
|
||||||
":tflite_driver",
|
":tflite_driver",
|
||||||
|
":tflite_driver_delegate_providers",
|
||||||
":util",
|
":util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
"@com_googlesource_code_re2//:re2",
|
"@com_googlesource_code_re2//:re2",
|
||||||
"//tensorflow/lite:builtin_op_data",
|
"//tensorflow/lite:builtin_op_data",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"//tensorflow/lite/kernels:test_delegate_providers_lib",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
@ -229,8 +231,9 @@ cc_library(
|
|||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
"//tensorflow/lite/kernels:builtin_ops",
|
||||||
"//tensorflow/lite/kernels:custom_ops",
|
"//tensorflow/lite/kernels:custom_ops",
|
||||||
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
|
|
||||||
"//tensorflow/lite/kernels:reference_ops",
|
"//tensorflow/lite/kernels:reference_ops",
|
||||||
|
"//tensorflow/lite/kernels:test_delegate_providers_lib",
|
||||||
|
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
|
||||||
"//tensorflow/lite/tools/evaluation:utils",
|
"//tensorflow/lite/tools/evaluation:utils",
|
||||||
] + select({
|
] + select({
|
||||||
"//tensorflow:ios": [],
|
"//tensorflow:ios": [],
|
||||||
@ -238,6 +241,22 @@ cc_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A convenient library of tflite delegate execution providers for tests based
|
||||||
|
# on the `tflite_driver` library.
|
||||||
|
cc_library(
|
||||||
|
name = "tflite_driver_delegate_providers",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/tools/delegates:coreml_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:default_execution_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:external_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:gpu_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:hexagon_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:nnapi_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:xnnpack_delegate_provider",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "tflite_driver_test",
|
name = "tflite_driver_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/subprocess.h"
|
#include "tensorflow/core/platform/subprocess.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||||
#include "tensorflow/lite/testing/parse_testdata.h"
|
#include "tensorflow/lite/testing/parse_testdata.h"
|
||||||
#include "tensorflow/lite/testing/tflite_driver.h"
|
#include "tensorflow/lite/testing/tflite_driver.h"
|
||||||
#include "tensorflow/lite/testing/util.h"
|
#include "tensorflow/lite/testing/util.h"
|
||||||
@ -47,7 +48,6 @@ string* FLAGS_tar_binary_path = new string("/bin/tar");
|
|||||||
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
|
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
|
||||||
string* FLAGS_tar_binary_path = new string("/system/bin/tar");
|
string* FLAGS_tar_binary_path = new string("/system/bin/tar");
|
||||||
#endif
|
#endif
|
||||||
bool FLAGS_use_nnapi = false;
|
|
||||||
bool FLAGS_ignore_unsupported_nnapi = false;
|
bool FLAGS_ignore_unsupported_nnapi = false;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -298,9 +298,10 @@ TEST_P(OpsTest, RunZipTests) {
|
|||||||
|
|
||||||
std::ifstream tflite_stream(tflite_test_case);
|
std::ifstream tflite_stream(tflite_test_case);
|
||||||
ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
|
ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
|
||||||
tflite::testing::TfLiteDriver test_driver(
|
tflite::testing::TfLiteDriver test_driver;
|
||||||
FLAGS_use_nnapi ? TfLiteDriver::DelegateType::kNnapi
|
const bool use_nnapi =
|
||||||
: TfLiteDriver::DelegateType::kNone);
|
tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
|
||||||
|
"use_nnapi");
|
||||||
|
|
||||||
auto quantized_tests_error = GetQuantizeTestsError();
|
auto quantized_tests_error = GetQuantizeTestsError();
|
||||||
bool fully_quantize = false;
|
bool fully_quantize = false;
|
||||||
@ -317,7 +318,7 @@ TEST_P(OpsTest, RunZipTests) {
|
|||||||
test_driver.SetModelBaseDir(tflite_dir);
|
test_driver.SetModelBaseDir(tflite_dir);
|
||||||
|
|
||||||
auto broken_tests = GetKnownBrokenTests();
|
auto broken_tests = GetKnownBrokenTests();
|
||||||
if (FLAGS_use_nnapi) {
|
if (use_nnapi) {
|
||||||
auto kBrokenNnapiTests = GetKnownBrokenNnapiTests();
|
auto kBrokenNnapiTests = GetKnownBrokenNnapiTests();
|
||||||
broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end());
|
broken_tests.insert(kBrokenNnapiTests.begin(), kBrokenNnapiTests.end());
|
||||||
}
|
}
|
||||||
@ -334,7 +335,7 @@ TEST_P(OpsTest, RunZipTests) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (bug_number.empty()) {
|
if (bug_number.empty()) {
|
||||||
if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
|
if (use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
|
||||||
EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
|
EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
|
||||||
} else {
|
} else {
|
||||||
EXPECT_TRUE(result) << message;
|
EXPECT_TRUE(result) << message;
|
||||||
@ -408,8 +409,6 @@ int main(int argc, char** argv) {
|
|||||||
tensorflow::Flag("tar_binary_path",
|
tensorflow::Flag("tar_binary_path",
|
||||||
tflite::testing::FLAGS_tar_binary_path,
|
tflite::testing::FLAGS_tar_binary_path,
|
||||||
"Location of a suitable tar binary."),
|
"Location of a suitable tar binary."),
|
||||||
tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
|
|
||||||
"Whether to enable the NNAPI delegate"),
|
|
||||||
tensorflow::Flag("ignore_unsupported_nnapi",
|
tensorflow::Flag("ignore_unsupported_nnapi",
|
||||||
&tflite::testing::FLAGS_ignore_unsupported_nnapi,
|
&tflite::testing::FLAGS_ignore_unsupported_nnapi,
|
||||||
"Don't fail tests just because delegation to NNAPI "
|
"Don't fail tests just because delegation to NNAPI "
|
||||||
@ -417,7 +416,12 @@ int main(int argc, char** argv) {
|
|||||||
bool success = tensorflow::Flags::Parse(&argc, argv, flags);
|
bool success = tensorflow::Flags::Parse(&argc, argv, flags);
|
||||||
if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
|
if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
|
||||||
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
|
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
|
||||||
return 1;
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tflite::testing::TfLiteDriver::InitTestDelegateProviders(
|
||||||
|
&argc, const_cast<const char**>(argv))) {
|
||||||
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
|
|
||||||
::tflite::LogToStderr();
|
::tflite::LogToStderr();
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
|
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/kernels/register_ref.h"
|
#include "tensorflow/lite/kernels/register_ref.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "tensorflow/lite/testing/join.h"
|
#include "tensorflow/lite/testing/join.h"
|
||||||
#include "tensorflow/lite/testing/split.h"
|
#include "tensorflow/lite/testing/split.h"
|
||||||
@ -346,6 +347,12 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
|
||||||
|
return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
|
||||||
|
argv);
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
||||||
: delegate_(nullptr, nullptr),
|
: delegate_(nullptr, nullptr),
|
||||||
relative_threshold_(kRelativeThreshold),
|
relative_threshold_(kRelativeThreshold),
|
||||||
@ -414,6 +421,16 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
|
|||||||
Invalidate("Unable to the build graph using the delegate");
|
Invalidate("Unable to the build graph using the delegate");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
|
||||||
|
for (auto& one : delegate_providers->CreateAllDelegates()) {
|
||||||
|
if (interpreter_->ModifyGraphWithDelegate(std::move(one)) != kTfLiteOk) {
|
||||||
|
Invalidate(
|
||||||
|
"Unable to the build graph using the delegate initialized from "
|
||||||
|
"tflite::KernelTestDelegateProviders");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
must_allocate_tensors_ = true;
|
must_allocate_tensors_ = true;
|
||||||
|
@ -40,10 +40,15 @@ class TfLiteDriver : public TestRunner {
|
|||||||
kFlex,
|
kFlex,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Initialize the global test delegate providers from commandline arguments
|
||||||
|
// and returns true if successful.
|
||||||
|
static bool InitTestDelegateProviders(int* argc, const char** argv);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new TfLiteDriver
|
* Creates a new TfLiteDriver
|
||||||
* @param delegate The (optional) delegate to use.
|
* @param delegate The (optional) delegate to use.
|
||||||
* @param reference_kernel Whether to use the builtin reference kernel ops.
|
* @param reference_kernel Whether to use the builtin reference kernel
|
||||||
|
* ops.
|
||||||
*/
|
*/
|
||||||
explicit TfLiteDriver(DelegateType delegate_type = DelegateType::kNone,
|
explicit TfLiteDriver(DelegateType delegate_type = DelegateType::kNone,
|
||||||
bool reference_kernel = false);
|
bool reference_kernel = false);
|
||||||
|
Loading…
Reference in New Issue
Block a user