Support more delegate options in kernel tests by utilizing TFLite tooling's delegate registrar.
PiperOrigin-RevId: 320527595 Change-Id: I9caa7b4cf0a961c4aef90fd73840069b8ed1a32e
This commit is contained in:
parent
6c292dc103
commit
eddb295677
@ -171,9 +171,9 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":acceleration_test_util",
|
":acceleration_test_util",
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_util_delegate_providers",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:minimal_logging",
|
|
||||||
"//tensorflow/lite:schema_fbs_version",
|
"//tensorflow/lite:schema_fbs_version",
|
||||||
"//tensorflow/lite:string",
|
"//tensorflow/lite:string",
|
||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
@ -186,6 +186,10 @@ cc_library(
|
|||||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"//tensorflow/lite/testing:util",
|
"//tensorflow/lite/testing:util",
|
||||||
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
|
"//tensorflow/lite/tools:logging",
|
||||||
|
"//tensorflow/lite/tools:tool_params",
|
||||||
|
"//tensorflow/lite/tools/delegates:delegate_provider_hdr",
|
||||||
"//tensorflow/lite/tools/optimize:quantization_utils",
|
"//tensorflow/lite/tools/optimize:quantization_utils",
|
||||||
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
||||||
"//tensorflow/lite/tools/versioning",
|
"//tensorflow/lite/tools/versioning",
|
||||||
@ -194,6 +198,28 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A convenient library for tflite delegate execution providers.
|
||||||
|
cc_library(
|
||||||
|
name = "test_util_delegate_providers",
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/tools/delegates:coreml_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:default_execution_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:external_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:hexagon_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:nnapi_delegate_provider",
|
||||||
|
"//tensorflow/lite/tools/delegates:xnnpack_delegate_provider",
|
||||||
|
] + select({
|
||||||
|
# Metal GPU delegate for iOS has its own setups for kernel tests, so
|
||||||
|
# skipping linking w/ the gpu_delegate_provider.
|
||||||
|
"//tensorflow:ios": [],
|
||||||
|
"//conditions:default": [
|
||||||
|
"//tensorflow/lite/tools/delegates:gpu_delegate_provider",
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion.
|
# TODO(b/132204084): Create tflite_cc_test rule to automate test_main inclusion.
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "test_main",
|
name = "test_main",
|
||||||
|
@ -22,22 +22,20 @@ limitations under the License.
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void InitKernelTest(int* argc, char** argv) {
|
void InitKernelTest(int* argc, char** argv) {
|
||||||
bool use_nnapi = false;
|
tflite::KernelTestDelegateProviders* const delegate_providers =
|
||||||
std::vector<tflite::Flag> flags = {
|
tflite::KernelTestDelegateProviders::Get();
|
||||||
tflite::Flag::CreateFlag("use_nnapi", &use_nnapi, "Use NNAPI"),
|
delegate_providers->InitFromCmdlineArgs(argc, const_cast<const char**>(argv));
|
||||||
};
|
|
||||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flags);
|
|
||||||
|
|
||||||
if (use_nnapi) {
|
// TODO(b/160764491): remove the special handling of NNAPI delegate test.
|
||||||
tflite::SingleOpModel::SetForceUseNnapi(true);
|
tflite::SingleOpModel::SetForceUseNnapi(
|
||||||
}
|
delegate_providers->ConstParams().Get<bool>("use_nnapi"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
::tflite::LogToStderr();
|
::tflite::LogToStderr();
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
InitKernelTest(&argc, argv);
|
InitKernelTest(&argc, argv);
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
return RUN_ALL_TESTS();
|
return RUN_ALL_TESTS();
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/subgraph.h"
|
#include "tensorflow/lite/core/subgraph.h"
|
||||||
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
|
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
|
||||||
@ -39,12 +40,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
#include "tensorflow/lite/kernels/acceleration_test_util.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/minimal_logging.h"
|
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/string_type.h"
|
#include "tensorflow/lite/string_type.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||||
|
#include "tensorflow/lite/tools/logging.h"
|
||||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||||
#include "tensorflow/lite/version.h"
|
#include "tensorflow/lite/version.h"
|
||||||
|
|
||||||
@ -219,14 +221,27 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
|||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus SingleOpModel::ApplyDelegate() {
|
TfLiteStatus SingleOpModel::ApplyDelegate() {
|
||||||
|
auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
|
||||||
|
|
||||||
if (force_use_nnapi) {
|
if (force_use_nnapi) {
|
||||||
delegate_ = TestNnApiDelegate();
|
delegate_ = TestNnApiDelegate();
|
||||||
|
|
||||||
|
// As we currently have special handling of nnapi delegate in kernel tests,
|
||||||
|
// we turn off the nnapi delegate provider to avoid re-applying it later.
|
||||||
|
// TODO(b/160764491): remove this special handling for NNAPI delegate test.
|
||||||
|
delegate_providers->MutableParams()->Set<bool>("use_nnapi", false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (delegate_) {
|
if (delegate_) {
|
||||||
|
TFLITE_LOG(WARN) << "Having a manually-set TfLite delegate, and bypassing "
|
||||||
|
"KernelTestDelegateProviders";
|
||||||
return interpreter_->ModifyGraphWithDelegate(delegate_);
|
return interpreter_->ModifyGraphWithDelegate(delegate_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto& one : delegate_providers->CreateAllDelegates()) {
|
||||||
|
TF_LITE_ENSURE_STATUS(
|
||||||
|
interpreter_->ModifyGraphWithDelegate(std::move(one)));
|
||||||
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,7 +342,7 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Validating acceleration");
|
TFLITE_LOG(INFO) << "Validating acceleration";
|
||||||
const NnApi* nnapi = NnApiImplementation();
|
const NnApi* nnapi = NnApiImplementation();
|
||||||
if (nnapi && nnapi->nnapi_exists &&
|
if (nnapi && nnapi->nnapi_exists &&
|
||||||
nnapi->android_sdk_version >=
|
nnapi->android_sdk_version >=
|
||||||
@ -379,4 +394,39 @@ void MultiOpModel::AddCustomOp(
|
|||||||
CustomOptionsFormat_FLEXBUFFERS));
|
CustomOptionsFormat_FLEXBUFFERS));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*static*/ KernelTestDelegateProviders* KernelTestDelegateProviders::Get() {
|
||||||
|
static KernelTestDelegateProviders* const providers =
|
||||||
|
new KernelTestDelegateProviders();
|
||||||
|
return providers;
|
||||||
|
}
|
||||||
|
|
||||||
|
KernelTestDelegateProviders::KernelTestDelegateProviders() {
|
||||||
|
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||||
|
params_.Merge(one->DefaultParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc,
|
||||||
|
const char** argv) {
|
||||||
|
std::vector<tflite::Flag> flags;
|
||||||
|
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||||
|
auto one_flags = one->CreateFlags(¶ms_);
|
||||||
|
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||||
|
}
|
||||||
|
return tflite::Flags::Parse(argc, argv, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<tools::TfLiteDelegatePtr>
|
||||||
|
KernelTestDelegateProviders::CreateAllDelegates() const {
|
||||||
|
std::vector<tools::TfLiteDelegatePtr> delegates;
|
||||||
|
for (const auto& one : tools::GetRegisteredDelegateProviders()) {
|
||||||
|
auto ptr = one->CreateTfLiteDelegate(params_);
|
||||||
|
// It's possible that a delegate of certain type won't be created as
|
||||||
|
// user-specified benchmark params tells not to.
|
||||||
|
if (ptr == nullptr) continue;
|
||||||
|
delegates.emplace_back(std::move(ptr));
|
||||||
|
TFLITE_LOG(INFO) << one->GetName() << " delegate is created.";
|
||||||
|
}
|
||||||
|
return delegates;
|
||||||
|
}
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -45,8 +45,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/string_type.h"
|
#include "tensorflow/lite/string_type.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
|
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
|
||||||
|
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
||||||
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
||||||
|
#include "tensorflow/lite/tools/tool_params.h"
|
||||||
#include "tensorflow/lite/type_to_tflitetype.h"
|
#include "tensorflow/lite/type_to_tflitetype.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -898,6 +900,36 @@ class MultiOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// A utility class to provide TfLite delegate creations for kernel tests. The
|
||||||
|
// options of a particular delegate could be specified from commandline flags by
|
||||||
|
// using the delegate provider registrar as implemented in lite/tools/delegates
|
||||||
|
// directory.
|
||||||
|
class KernelTestDelegateProviders {
|
||||||
|
public:
|
||||||
|
// Returns a global KernelTestDelegateProviders instance.
|
||||||
|
static KernelTestDelegateProviders* Get();
|
||||||
|
|
||||||
|
KernelTestDelegateProviders();
|
||||||
|
|
||||||
|
// Initialize delegate-related parameters from commandline arguments and
|
||||||
|
// returns true if successful.
|
||||||
|
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
||||||
|
|
||||||
|
// This provides a way to overwrite parameter values programmatically before
|
||||||
|
// creating TfLite delegates.
|
||||||
|
tools::ToolParams* MutableParams() { return ¶ms_; }
|
||||||
|
const tools::ToolParams& ConstParams() const { return params_; }
|
||||||
|
|
||||||
|
// Create a list of TfLite delegates based on what have been initialized (i.e.
|
||||||
|
// 'params_').
|
||||||
|
std::vector<tools::TfLiteDelegatePtr> CreateAllDelegates() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Contain delegate-related parameters that are initialized from command-line
|
||||||
|
// flags.
|
||||||
|
tools::ToolParams params_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
|
#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_
|
||||||
|
@ -47,6 +47,28 @@ TEST(TestUtilTest, QuantizeVectorScalingUp) {
|
|||||||
EXPECT_THAT(q_data, ElementsAreArray(expected));
|
EXPECT_THAT(q_data, ElementsAreArray(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(KernelTestDelegateProvidersTest, DelegateProvidersParams) {
|
||||||
|
KernelTestDelegateProviders providers;
|
||||||
|
const auto& params = providers.ConstParams();
|
||||||
|
EXPECT_TRUE(params.HasParam("use_xnnpack"));
|
||||||
|
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
||||||
|
|
||||||
|
int argc = 3;
|
||||||
|
const char* argv[] = {"program_name", "--use_nnapi=true",
|
||||||
|
"--other_undefined_flag=1"};
|
||||||
|
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
||||||
|
EXPECT_TRUE(params.Get<bool>("use_nnapi"));
|
||||||
|
EXPECT_EQ(2, argc);
|
||||||
|
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(KernelTestDelegateProvidersTest, CreateTfLiteDelegates) {
|
||||||
|
#if !defined(__Fuchsia__) && !defined(TFLITE_WITHOUT_XNNPACK)
|
||||||
|
KernelTestDelegateProviders providers;
|
||||||
|
providers.MutableParams()->Set<bool>("use_xnnpack", true);
|
||||||
|
EXPECT_GE(providers.CreateAllDelegates().size(), 1);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user