From eddb295677f8bad9b9e28d2b1ff8cd97d2846a3e Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Thu, 9 Jul 2020 19:17:34 -0700 Subject: [PATCH] Support more delegate options in kernel tests by utilizing TFLite tooling's delegate registrar. PiperOrigin-RevId: 320527595 Change-Id: I9caa7b4cf0a961c4aef90fd73840069b8ed1a32e --- tensorflow/lite/kernels/BUILD | 28 +++++++++++- tensorflow/lite/kernels/test_main.cc | 16 +++---- tensorflow/lite/kernels/test_util.cc | 54 ++++++++++++++++++++++- tensorflow/lite/kernels/test_util.h | 32 ++++++++++++++ tensorflow/lite/kernels/test_util_test.cc | 22 +++++++++ 5 files changed, 140 insertions(+), 12 deletions(-) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index cb00d73adac..898d153c527 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -171,9 +171,9 @@ cc_library( deps = [ ":acceleration_test_util", ":builtin_ops", + ":test_util_delegate_providers", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", - "//tensorflow/lite:minimal_logging", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string", "//tensorflow/lite:string_util", @@ -186,6 +186,10 @@ cc_library( "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/testing:util", + "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools:tool_params", + "//tensorflow/lite/tools/delegates:delegate_provider_hdr", "//tensorflow/lite/tools/optimize:quantization_utils", "//tensorflow/lite/tools/optimize/sparsity:format_converter", "//tensorflow/lite/tools/versioning", @@ -194,6 +198,28 @@ cc_library( ], ) +# A convenient library for tflite delegate execution providers. +cc_library( + name = "test_util_delegate_providers", + copts = tflite_copts(), + deps = [ + "//tensorflow/lite/tools/delegates:coreml_delegate_provider", + "//tensorflow/lite/tools/delegates:default_execution_provider", + "//tensorflow/lite/tools/delegates:external_delegate_provider", + "//tensorflow/lite/tools/delegates:hexagon_delegate_provider", + "//tensorflow/lite/tools/delegates:nnapi_delegate_provider", + "//tensorflow/lite/tools/delegates:xnnpack_delegate_provider", + ] + select({ + # Metal GPU delegate for iOS has its own setups for kernel tests, so + # skipping linking w/ the gpu_delegate_provider. + "//tensorflow:ios": [], + "//conditions:default": [ + "//tensorflow/lite/tools/delegates:gpu_delegate_provider", + ], + }), + alwayslink = 1, +) + # TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion. cc_library( name = "test_main", diff --git a/tensorflow/lite/kernels/test_main.cc b/tensorflow/lite/kernels/test_main.cc index d8e12297fea..a99109080fa 100644 --- a/tensorflow/lite/kernels/test_main.cc +++ b/tensorflow/lite/kernels/test_main.cc @@ -22,22 +22,20 @@ limitations under the License. namespace { void InitKernelTest(int* argc, char** argv) { - bool use_nnapi = false; - std::vector flags = { - tflite::Flag::CreateFlag("use_nnapi", &use_nnapi, "Use NNAPI"), - }; - tflite::Flags::Parse(argc, const_cast(argv), flags); + tflite::KernelTestDelegateProviders* const delegate_providers = + tflite::KernelTestDelegateProviders::Get(); + delegate_providers->InitFromCmdlineArgs(argc, const_cast(argv)); - if (use_nnapi) { - tflite::SingleOpModel::SetForceUseNnapi(true); - } + // TODO(b/160764491): remove the special handling of NNAPI delegate test. + tflite::SingleOpModel::SetForceUseNnapi( + delegate_providers->ConstParams().Get("use_nnapi")); } } // namespace int main(int argc, char** argv) { ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); InitKernelTest(&argc, argv); + ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index 24f6e4f11ca..d77d3367afe 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -32,6 +32,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h" @@ -39,12 +40,13 @@ limitations under the License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/acceleration_test_util.h" #include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_type.h" #include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/logging.h" #include "tensorflow/lite/tools/versioning/op_version.h" #include "tensorflow/lite/version.h" @@ -219,14 +221,27 @@ void SingleOpModel::BuildInterpreter(std::vector> input_shapes, } TfLiteStatus SingleOpModel::ApplyDelegate() { + auto* delegate_providers = tflite::KernelTestDelegateProviders::Get(); + if (force_use_nnapi) { delegate_ = TestNnApiDelegate(); + + // As we currently have special handling of nnapi delegate in kernel tests, + // we turn off the nnapi delegate provider to avoid re-applying it later. + // TODO(b/160764491): remove this special handling for NNAPI delegate test. + delegate_providers->MutableParams()->Set("use_nnapi", false); } if (delegate_) { + TFLITE_LOG(WARN) << "Having a manually-set TfLite delegate, and bypassing " + "KernelTestDelegateProviders"; return interpreter_->ModifyGraphWithDelegate(delegate_); } + for (auto& one : delegate_providers->CreateAllDelegates()) { + TF_LITE_ENSURE_STATUS( + interpreter_->ModifyGraphWithDelegate(std::move(one))); + } return kTfLiteOk; } @@ -327,7 +342,7 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) { return; } - TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Validating acceleration"); + TFLITE_LOG(INFO) << "Validating acceleration"; const NnApi* nnapi = NnApiImplementation(); if (nnapi && nnapi->nnapi_exists && nnapi->android_sdk_version >= @@ -379,4 +394,39 @@ void MultiOpModel::AddCustomOp( CustomOptionsFormat_FLEXBUFFERS)); } +/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() { + static KernelTestDelegateProviders* const providers = + new KernelTestDelegateProviders(); + return providers; +} + +KernelTestDelegateProviders::KernelTestDelegateProviders() { + for (const auto& one : tools::GetRegisteredDelegateProviders()) { + params_.Merge(one->DefaultParams()); + } +} + +bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc, + const char** argv) { + std::vector flags; + for (const auto& one : tools::GetRegisteredDelegateProviders()) { + auto one_flags = one->CreateFlags(¶ms_); + flags.insert(flags.end(), one_flags.begin(), one_flags.end()); + } + return tflite::Flags::Parse(argc, argv, flags); +} + +std::vector +KernelTestDelegateProviders::CreateAllDelegates() const { + std::vector delegates; + for (const auto& one : tools::GetRegisteredDelegateProviders()) { + auto ptr = one->CreateTfLiteDelegate(params_); + // It's possible that a delegate of certain type won't be created as + // user-specified benchmark params tells not to. + if (ptr == nullptr) continue; + delegates.emplace_back(std::move(ptr)); + TFLITE_LOG(INFO) << one->GetName() << " delegate is created."; + } + return delegates; +} } // namespace tflite diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index bc93bdae58a..f58867a5120 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -45,8 +45,10 @@ limitations under the License. #include "tensorflow/lite/string_type.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/testing/util.h" // IWYU pragma: keep +#include "tensorflow/lite/tools/delegates/delegate_provider.h" #include "tensorflow/lite/tools/optimize/quantization_utils.h" #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h" +#include "tensorflow/lite/tools/tool_params.h" #include "tensorflow/lite/type_to_tflitetype.h" namespace tflite { @@ -898,6 +900,36 @@ class MultiOpModel : public SingleOpModel { } }; +// A utility class to provide TfLite delegate creations for kernel tests. The +// options of a particular delegate could be specified from commandline flags by +// using the delegate provider registrar as implemented in lite/tools/delegates +// directory. +class KernelTestDelegateProviders { + public: + // Returns a global KernelTestDelegateProviders instance. + static KernelTestDelegateProviders* Get(); + + KernelTestDelegateProviders(); + + // Initialize delegate-related parameters from commandline arguments and + // returns true if successful. + bool InitFromCmdlineArgs(int* argc, const char** argv); + + // This provides a way to overwrite parameter values programmatically before + // creating TfLite delegates. + tools::ToolParams* MutableParams() { return ¶ms_; } + const tools::ToolParams& ConstParams() const { return params_; } + + // Create a list of TfLite delegates based on what have been initialized (i.e. + // 'params_'). + std::vector CreateAllDelegates() const; + + private: + // Contain delegate-related parameters that are initialized from command-line + // flags. + tools::ToolParams params_; +}; + } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/lite/kernels/test_util_test.cc b/tensorflow/lite/kernels/test_util_test.cc index e6f865f6cd6..1ac08631079 100644 --- a/tensorflow/lite/kernels/test_util_test.cc +++ b/tensorflow/lite/kernels/test_util_test.cc @@ -47,6 +47,28 @@ TEST(TestUtilTest, QuantizeVectorScalingUp) { EXPECT_THAT(q_data, ElementsAreArray(expected)); } +TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) { + KernelTestDelegateProviders providers; + const auto& params = providers.ConstParams(); + EXPECT_TRUE(params.HasParam("use_xnnpack")); + EXPECT_TRUE(params.HasParam("use_nnapi")); + + int argc = 3; + const char* argv[] = {"program_name", "--use_nnapi=true", + "--other_undefined_flag=1"}; + EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv)); + EXPECT_TRUE(params.Get("use_nnapi")); + EXPECT_EQ(2, argc); + EXPECT_EQ("--other_undefined_flag=1", argv[1]); +} + +TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) { +#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK) + KernelTestDelegateProviders providers; + providers.MutableParams()->Set("use_xnnpack", true); + EXPECT_GE(providers.CreateAllDelegates().size(), 1); +#endif +} } // namespace } // namespace tflite