Support more delegate options in kernel tests by utilizing TFLite tooling's delegate registrar.
PiperOrigin-RevId: 320527595 Change-Id: I9caa7b4cf0a961c4aef90fd73840069b8ed1a32e
This commit is contained in:
parent
6c292dc103
commit
eddb295677
@ -171,9 +171,9 @@ cc_library(
|
||||
deps = [
|
||||
":acceleration_test_util",
|
||||
":builtin_ops",
|
||||
":test_util_delegate_providers",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite:string",
|
||||
"//tensorflow/lite:string_util",
|
||||
@ -186,6 +186,10 @@ cc_library(
|
||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"//tensorflow/lite/tools:tool_params",
|
||||
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
||||
"//tensorflow/lite/tools/optimize:quantization_utils",
|
||||
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
||||
"//tensorflow/lite/tools/versioning",
|
||||
@ -194,6 +198,28 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# A convenient library for tflite delegate execution providers.
|
||||
cc_library(
|
||||
name = "test_util_delegate_providers",
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/tools/delegates:coreml_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:default_execution_provider",
|
||||
"//tensorflow/lite/tools/delegates:external_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:hexagon_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:nnapi_delegate_provider",
|
||||
"//tensorflow/lite/tools/delegates:xnnpack_delegate_provider",
|
||||
] + select({
|
||||
# Metal GPU delegate for iOS has its own setups for kernel tests, so
|
||||
# skipping linking w/ the gpu_delegate_provider.
|
||||
"//tensorflow:ios": [],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/lite/tools/delegates:gpu_delegate_provider",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion.
|
||||
cc_library(
|
||||
name = "test_main",
|
||||
|
@ -22,22 +22,20 @@ limitations under the License.
|
||||
namespace {
|
||||
|
||||
void InitKernelTest(int* argc, char** argv) {
|
||||
bool use_nnapi = false;
|
||||
std::vector<tflite::Flag> flags = {
|
||||
tflite::Flag::CreateFlag("use_nnapi", &use_nnapi, "Use NNAPI"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flags);
|
||||
tflite::KernelTestDelegateProviders* const delegate_providers =
|
||||
tflite::KernelTestDelegateProviders::Get();
|
||||
delegate_providers->InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||
|
||||
if (use_nnapi) {
|
||||
tflite::SingleOpModel::SetForceUseNnapi(true);
|
||||
}
|
||||
// TODO(b/160764491): remove the special handling of NNAPI delegate test.
|
||||
tflite::SingleOpModel::SetForceUseNnapi(
|
||||
delegate_providers->ConstParams().Get<bool>("use_nnapi"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
InitKernelTest(&argc, argv);
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
|
||||
@ -39,12 +40,13 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/string_type.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
@ -219,14 +221,27 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
}
|
||||
|
||||
TfLiteStatus SingleOpModel::ApplyDelegate() {
|
||||
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
|
||||
|
||||
if (force_use_nnapi) {
|
||||
delegate_ = TestNnApiDelegate();
|
||||
|
||||
// As we currently have special handling of nnapi delegate in kernel tests,
|
||||
// we turn off the nnapi delegate provider to avoid re-applying it later.
|
||||
// TODO(b/160764491): remove this special handling for NNAPI delegate test.
|
||||
delegate_providers->MutableParams()->Set<bool>("use_nnapi", false);
|
||||
}
|
||||
|
||||
if (delegate_) {
|
||||
TFLITE_LOG(WARN) << "Having a manually-set TfLite delegate, and bypassing "
|
||||
"KernelTestDelegateProviders";
|
||||
return interpreter_->ModifyGraphWithDelegate(delegate_);
|
||||
}
|
||||
|
||||
for (auto& one : delegate_providers->CreateAllDelegates()) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
interpreter_->ModifyGraphWithDelegate(std::move(one)));
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -327,7 +342,7 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Validating acceleration");
|
||||
TFLITE_LOG(INFO) << "Validating acceleration";
|
||||
const NnApi* nnapi = NnApiImplementation();
|
||||
if (nnapi && nnapi->nnapi_exists &&
|
||||
nnapi->android_sdk_version >=
|
||||
@ -379,4 +394,39 @@ void MultiOpModel::AddCustomOp(
|
||||
CustomOptionsFormat_FLEXBUFFERS));
|
||||
}
|
||||
|
||||
/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() {
|
||||
static KernelTestDelegateProviders* const providers =
|
||||
new KernelTestDelegateProviders();
|
||||
return providers;
|
||||
}
|
||||
|
||||
KernelTestDelegateProviders::KernelTestDelegateProviders() {
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
params_.Merge(one->DefaultParams());
|
||||
}
|
||||
}
|
||||
|
||||
bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc,
|
||||
const char** argv) {
|
||||
std::vector<tflite::Flag> flags;
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
auto one_flags = one->CreateFlags(¶ms_);
|
||||
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||
}
|
||||
return tflite::Flags::Parse(argc, argv, flags);
|
||||
}
|
||||
|
||||
std::vector<tools::TfLiteDelegatePtr>
|
||||
KernelTestDelegateProviders::CreateAllDelegates() const {
|
||||
std::vector<tools::TfLiteDelegatePtr> delegates;
|
||||
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||
auto ptr = one->CreateTfLiteDelegate(params_);
|
||||
// It's possible that a delegate of certain type won't be created as
|
||||
// user-specified benchmark params tells not to.
|
||||
if (ptr == nullptr) continue;
|
||||
delegates.emplace_back(std::move(ptr));
|
||||
TFLITE_LOG(INFO) << one->GetName() << " delegate is created.";
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
} // namespace tflite
|
||||
|
@ -45,8 +45,10 @@ limitations under the License.
|
||||
#include "tensorflow/lite/string_type.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
|
||||
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
||||
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
||||
#include "tensorflow/lite/tools/tool_params.h"
|
||||
#include "tensorflow/lite/type_to_tflitetype.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -898,6 +900,36 @@ class MultiOpModel : public SingleOpModel {
|
||||
}
|
||||
};
|
||||
|
||||
// A utility class to provide TfLite delegate creations for kernel tests. The
|
||||
// options of a particular delegate could be specified from commandline flags by
|
||||
// using the delegate provider registrar as implemented in lite/tools/delegates
|
||||
// directory.
|
||||
class KernelTestDelegateProviders {
|
||||
public:
|
||||
// Returns a global KernelTestDelegateProviders instance.
|
||||
static KernelTestDelegateProviders* Get();
|
||||
|
||||
KernelTestDelegateProviders();
|
||||
|
||||
// Initialize delegate-related parameters from commandline arguments and
|
||||
// returns true if successful.
|
||||
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
||||
|
||||
// This provides a way to overwrite parameter values programmatically before
|
||||
// creating TfLite delegates.
|
||||
tools::ToolParams* MutableParams() { return ¶ms_; }
|
||||
const tools::ToolParams& ConstParams() const { return params_; }
|
||||
|
||||
// Create a list of TfLite delegates based on what have been initialized (i.e.
|
||||
// 'params_').
|
||||
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates() const;
|
||||
|
||||
private:
|
||||
// Contain delegate-related parameters that are initialized from command-line
|
||||
// flags.
|
||||
tools::ToolParams params_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
|
||||
|
@ -47,6 +47,28 @@ TEST(TestUtilTest, QuantizeVectorScalingUp) {
|
||||
EXPECT_THAT(q_data, ElementsAreArray(expected));
|
||||
}
|
||||
|
||||
TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) {
|
||||
KernelTestDelegateProviders providers;
|
||||
const auto& params = providers.ConstParams();
|
||||
EXPECT_TRUE(params.HasParam("use_xnnpack"));
|
||||
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
||||
|
||||
int argc = 3;
|
||||
const char* argv[] = {"program_name", "--use_nnapi=true",
|
||||
"--other_undefined_flag=1"};
|
||||
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
||||
EXPECT_TRUE(params.Get<bool>("use_nnapi"));
|
||||
EXPECT_EQ(2, argc);
|
||||
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
||||
}
|
||||
|
||||
TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) {
|
||||
#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK)
|
||||
KernelTestDelegateProviders providers;
|
||||
providers.MutableParams()->Set<bool>("use_xnnpack", true);
|
||||
EXPECT_GE(providers.CreateAllDelegates().size(), 1);
|
||||
#endif
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user