Support more delegate options in kernel tests by utilizing TFLite tooling's delegate registrar.

PiperOrigin-RevId: 320527595
Change-Id: I9caa7b4cf0a961c4aef90fd73840069b8ed1a32e
This commit is contained in:
Chao Mei 2020-07-09 19:17:34 -07:00 committed by TensorFlower Gardener
parent 6c292dc103
commit eddb295677
5 changed files with 140 additions and 12 deletions

View File

@ -171,9 +171,9 @@ cc_library(
deps = [
":acceleration_test_util",
":builtin_ops",
":test_util_delegate_providers",
"//tensorflow/core/platform:logging",
"//tensorflow/lite:framework",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string",
"//tensorflow/lite:string_util",
@ -186,6 +186,10 @@ cc_library(
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools:logging",
"//tensorflow/lite/tools:tool_params",
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
"//tensorflow/lite/tools/optimize:quantization_utils",
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
"//tensorflow/lite/tools/versioning",
@ -194,6 +198,28 @@ cc_library(
],
)
# A convenient library for tflite delegate execution providers.
cc_library(
name = "test_util_delegate_providers",
copts = tflite_copts(),
deps = [
"//tensorflow/lite/tools/delegates:coreml_delegate_provider",
"//tensorflow/lite/tools/delegates:default_execution_provider",
"//tensorflow/lite/tools/delegates:external_delegate_provider",
"//tensorflow/lite/tools/delegates:hexagon_delegate_provider",
"//tensorflow/lite/tools/delegates:nnapi_delegate_provider",
"//tensorflow/lite/tools/delegates:xnnpack_delegate_provider",
] + select({
# Metal GPU delegate for iOS has its own setups for kernel tests, so
# skipping linking w/ the gpu_delegate_provider.
"//tensorflow:ios": [],
"//conditions:default": [
"//tensorflow/lite/tools/delegates:gpu_delegate_provider",
],
}),
alwayslink = 1,
)
# TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion.
cc_library(
name = "test_main",

View File

@ -22,22 +22,20 @@ limitations under the License.
namespace {
void InitKernelTest(int* argc, char** argv) {
bool use_nnapi = false;
std::vector<tflite::Flag> flags = {
tflite::Flag::CreateFlag("use_nnapi", &use_nnapi, "Use NNAPI"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flags);
tflite::KernelTestDelegateProviders* const delegate_providers =
tflite::KernelTestDelegateProviders::Get();
delegate_providers->InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
if (use_nnapi) {
tflite::SingleOpModel::SetForceUseNnapi(true);
}
// TODO(b/160764491): remove the special handling of NNAPI delegate test.
tflite::SingleOpModel::SetForceUseNnapi(
delegate_providers->ConstParams().Get<bool>("use_nnapi"));
}
} // namespace
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
InitKernelTest(&argc, argv);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
@ -39,12 +40,13 @@ limitations under the License.
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/acceleration_test_util.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/logging.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
#include "tensorflow/lite/version.h"
@ -219,14 +221,27 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
}
TfLiteStatus SingleOpModel::ApplyDelegate() {
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
if (force_use_nnapi) {
delegate_ = TestNnApiDelegate();
// As we currently have special handling of nnapi delegate in kernel tests,
// we turn off the nnapi delegate provider to avoid re-applying it later.
// TODO(b/160764491): remove this special handling for NNAPI delegate test.
delegate_providers->MutableParams()->Set<bool>("use_nnapi", false);
}
if (delegate_) {
TFLITE_LOG(WARN) << "Having a manually-set TfLite delegate, and bypassing "
"KernelTestDelegateProviders";
return interpreter_->ModifyGraphWithDelegate(delegate_);
}
for (auto& one : delegate_providers->CreateAllDelegates()) {
TF_LITE_ENSURE_STATUS(
interpreter_->ModifyGraphWithDelegate(std::move(one)));
}
return kTfLiteOk;
}
@ -327,7 +342,7 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
return;
}
TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Validating acceleration");
TFLITE_LOG(INFO) << "Validating acceleration";
const NnApi* nnapi = NnApiImplementation();
if (nnapi && nnapi->nnapi_exists &&
nnapi->android_sdk_version >=
@ -379,4 +394,39 @@ void MultiOpModel::AddCustomOp(
CustomOptionsFormat_FLEXBUFFERS));
}
/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() {
static KernelTestDelegateProviders* const providers =
new KernelTestDelegateProviders();
return providers;
}
KernelTestDelegateProviders::KernelTestDelegateProviders() {
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
params_.Merge(one->DefaultParams());
}
}
bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc,
const char** argv) {
std::vector<tflite::Flag> flags;
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
auto one_flags = one->CreateFlags(&params_);
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
}
return tflite::Flags::Parse(argc, argv, flags);
}
std::vector<tools::TfLiteDelegatePtr>
KernelTestDelegateProviders::CreateAllDelegates() const {
std::vector<tools::TfLiteDelegatePtr> delegates;
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
auto ptr = one->CreateTfLiteDelegate(params_);
// It's possible that a delegate of certain type won't be created as
// user-specified benchmark params tells not to.
if (ptr == nullptr) continue;
delegates.emplace_back(std::move(ptr));
TFLITE_LOG(INFO) << one->GetName() << " delegate is created.";
}
return delegates;
}
} // namespace tflite

View File

@ -45,8 +45,10 @@ limitations under the License.
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
#include "tensorflow/lite/tools/tool_params.h"
#include "tensorflow/lite/type_to_tflitetype.h"
namespace tflite {
@ -898,6 +900,36 @@ class MultiOpModel : public SingleOpModel {
}
};
// A utility class to provide TfLite delegate creations for kernel tests. The
// options of a particular delegate could be specified from commandline flags by
// using the delegate provider registrar as implemented in lite/tools/delegates
// directory.
class KernelTestDelegateProviders {
public:
// Returns a global KernelTestDelegateProviders instance.
static KernelTestDelegateProviders* Get();
KernelTestDelegateProviders();
// Initialize delegate-related parameters from commandline arguments and
// returns true if successful.
bool InitFromCmdlineArgs(int* argc, const char** argv);
// This provides a way to overwrite parameter values programmatically before
// creating TfLite delegates.
tools::ToolParams* MutableParams() { return &params_; }
const tools::ToolParams& ConstParams() const { return params_; }
// Create a list of TfLite delegates based on what have been initialized (i.e.
// 'params_').
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates() const;
private:
// Contain delegate-related parameters that are initialized from command-line
// flags.
tools::ToolParams params_;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_

View File

@ -47,6 +47,28 @@ TEST(TestUtilTest, QuantizeVectorScalingUp) {
EXPECT_THAT(q_data, ElementsAreArray(expected));
}
TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) {
KernelTestDelegateProviders providers;
const auto& params = providers.ConstParams();
EXPECT_TRUE(params.HasParam("use_xnnpack"));
EXPECT_TRUE(params.HasParam("use_nnapi"));
int argc = 3;
const char* argv[] = {"program_name", "--use_nnapi=true",
"--other_undefined_flag=1"};
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
EXPECT_TRUE(params.Get<bool>("use_nnapi"));
EXPECT_EQ(2, argc);
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
}
TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) {
#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK)
KernelTestDelegateProviders providers;
providers.MutableParams()->Set<bool>("use_xnnpack", true);
EXPECT_GE(providers.CreateAllDelegates().size(), 1);
#endif
}
} // namespace
} // namespace tflite