Support running TFLite kernel unit tests with NNAPI
PiperOrigin-RevId: 248757895
This commit is contained in:
parent
406f7ff0dc
commit
5195395077
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#include "tensorflow/lite/graph_info.h"
|
#include "tensorflow/lite/graph_info.h"
|
||||||
|
#include "tensorflow/lite/minimal_logging.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -283,6 +284,16 @@ TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate,
|
|||||||
TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
|
TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
|
||||||
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
|
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
|
||||||
TfLiteDelegate* delegate) {
|
TfLiteDelegate* delegate) {
|
||||||
|
// Ignore empty node replacement sets.
|
||||||
|
if (!nodes_to_replace->size) {
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
TFLITE_LOG(tflite::TFLITE_LOG_INFO,
|
||||||
|
"Replacing %d node(s) with delegate (%s) node.",
|
||||||
|
nodes_to_replace->size,
|
||||||
|
registration.custom_name ? registration.custom_name : "unknown");
|
||||||
|
|
||||||
// Annotate the registration as DELEGATE op.
|
// Annotate the registration as DELEGATE op.
|
||||||
registration.builtin_code = BuiltinOperator_DELEGATE;
|
registration.builtin_code = BuiltinOperator_DELEGATE;
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ cc_library(
|
|||||||
hdrs = ["nnapi_delegate.h"],
|
hdrs = ["nnapi_delegate.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite:kernel_api",
|
"//tensorflow/lite:kernel_api",
|
||||||
|
"//tensorflow/lite:minimal_logging",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/kernels:kernel_util",
|
"//tensorflow/lite/kernels:kernel_util",
|
||||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/minimal_logging.h"
|
||||||
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#ifdef __ANDROID__
|
||||||
@ -1584,6 +1585,8 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(Options options)
|
|||||||
if (options.accelerator_name) {
|
if (options.accelerator_name) {
|
||||||
delegate_data_.accelerator_name = options.accelerator_name;
|
delegate_data_.accelerator_name = options.accelerator_name;
|
||||||
}
|
}
|
||||||
|
TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
|
||||||
|
"Created TensorFlow Lite delegate for NNAPI.");
|
||||||
Prepare = DoPrepare;
|
Prepare = DoPrepare;
|
||||||
data_ = &delegate_data_;
|
data_ = &delegate_data_;
|
||||||
}
|
}
|
||||||
@ -1657,6 +1660,11 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
|
|||||||
// First element in vector must be the number of actual nodes.
|
// First element in vector must be the number of actual nodes.
|
||||||
supported_nodes[0] = supported_nodes.size() - 1;
|
supported_nodes[0] = supported_nodes.size() - 1;
|
||||||
|
|
||||||
|
// If there are no delegated nodes, short-circuit node replacement.
|
||||||
|
if (!supported_nodes[0]) {
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
// NN API Delegate Registration (the pseudo kernel that will invoke NN
|
// NN API Delegate Registration (the pseudo kernel that will invoke NN
|
||||||
// API node sub sets)
|
// API node sub sets)
|
||||||
static const TfLiteRegistration nnapi_delegate_kernel = {
|
static const TfLiteRegistration nnapi_delegate_kernel = {
|
||||||
|
@ -51,6 +51,7 @@ cc_library(
|
|||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:schema_fbs_version",
|
"//tensorflow/lite:schema_fbs_version",
|
||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
|
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||||
"//tensorflow/lite/kernels/internal:tensor_utils",
|
"//tensorflow/lite/kernels/internal:tensor_utils",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"//tensorflow/lite/testing:util",
|
"//tensorflow/lite/testing:util",
|
||||||
@ -59,6 +60,19 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(b/132204084): Update all kernel tests to use this main lib.
|
||||||
|
cc_library(
|
||||||
|
name = "test_main",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["test_main.cc"],
|
||||||
|
deps = [
|
||||||
|
":test_util",
|
||||||
|
"//tensorflow/lite/testing:util",
|
||||||
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "eigen_support",
|
name = "eigen_support",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -518,8 +532,10 @@ cc_test(
|
|||||||
name = "div_test",
|
name = "div_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["div_test.cc"],
|
srcs = ["div_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
@ -530,8 +546,10 @@ cc_test(
|
|||||||
name = "sub_test",
|
name = "sub_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["sub_test.cc"],
|
srcs = ["sub_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
@ -542,8 +560,10 @@ cc_test(
|
|||||||
name = "transpose_test",
|
name = "transpose_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["transpose_test.cc"],
|
srcs = ["transpose_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"//tensorflow/lite/kernels/internal:reference",
|
"//tensorflow/lite/kernels/internal:reference",
|
||||||
@ -630,8 +650,10 @@ cc_test(
|
|||||||
name = "dequantize_test",
|
name = "dequantize_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["dequantize_test.cc"],
|
srcs = ["dequantize_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"//tensorflow/lite/kernels/internal:types",
|
"//tensorflow/lite/kernels/internal:types",
|
||||||
@ -669,8 +691,10 @@ cc_test(
|
|||||||
name = "floor_test",
|
name = "floor_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["floor_test.cc"],
|
srcs = ["floor_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
@ -969,8 +993,10 @@ cc_test(
|
|||||||
name = "local_response_norm_test",
|
name = "local_response_norm_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["local_response_norm_test.cc"],
|
srcs = ["local_response_norm_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
@ -993,8 +1019,10 @@ cc_test(
|
|||||||
name = "softmax_test",
|
name = "softmax_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["softmax_test.cc"],
|
srcs = ["softmax_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"//tensorflow/lite/kernels/internal:reference_base",
|
"//tensorflow/lite/kernels/internal:reference_base",
|
||||||
@ -1019,8 +1047,10 @@ cc_test(
|
|||||||
name = "lsh_projection_test",
|
name = "lsh_projection_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["lsh_projection_test.cc"],
|
srcs = ["lsh_projection_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
|
":test_main",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/kernels:test_util",
|
"//tensorflow/lite/kernels:test_util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
|
@ -75,9 +75,3 @@ TEST(DequantizeOpTest, INT8) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
@ -167,9 +167,3 @@ TEST(IntegerDivOpTest, WithBroadcast) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
@ -75,9 +75,3 @@ TEST(FloorOpTest, MultiDims) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
@ -93,9 +93,3 @@ TEST(LocalResponseNormOpTest, SmallRadius) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
@ -115,9 +115,3 @@ TEST(LSHProjectionOpTest2, Sparse3DInputs) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
@ -136,9 +136,3 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
@ -417,8 +417,3 @@ TEST(QuantizedSubOpModel, QuantizedTestsReluActivationBroadcastInt16) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
43
tensorflow/lite/kernels/test_main.cc
Normal file
43
tensorflow/lite/kernels/test_main.cc
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/testing/util.h"
|
||||||
|
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void InitKernelTest(int* argc, char** argv) {
|
||||||
|
bool use_nnapi = false;
|
||||||
|
std::vector<tflite::Flag> flags = {
|
||||||
|
tflite::Flag::CreateFlag("--use_nnapi", &use_nnapi, "Use NNAPI"),
|
||||||
|
};
|
||||||
|
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flags);
|
||||||
|
|
||||||
|
if (use_nnapi) {
|
||||||
|
tflite::SingleOpModel::SetForceUseNnapi(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::tflite::LogToStderr();
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
InitKernelTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
@ -14,14 +14,23 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/kernels/test_util.h"
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
|
||||||
#include "tensorflow/lite/version.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
|
#include "tensorflow/lite/version.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
using ::testing::FloatNear;
|
using ::testing::FloatNear;
|
||||||
using ::testing::Matcher;
|
using ::testing::Matcher;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Whether to enable (global) use of NNAPI. Note that this will typically
|
||||||
|
// be set via a command-line flag.
|
||||||
|
static bool force_use_nnapi = false;
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
|
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
|
||||||
float max_abs_error) {
|
float max_abs_error) {
|
||||||
std::vector<Matcher<float>> matchers;
|
std::vector<Matcher<float>> matchers;
|
||||||
@ -138,6 +147,11 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
|||||||
<< "Cannot allocate tensors";
|
<< "Cannot allocate tensors";
|
||||||
interpreter_->ResetVariableTensors();
|
interpreter_->ResetVariableTensors();
|
||||||
|
|
||||||
|
if (force_use_nnapi) {
|
||||||
|
// TODO(b/124505407): Check the result and fail accordingly.
|
||||||
|
interpreter_->ModifyGraphWithDelegate(NnApiDelegate());
|
||||||
|
}
|
||||||
|
|
||||||
// Modify delegate with function.
|
// Modify delegate with function.
|
||||||
if (apply_delegate_fn_) {
|
if (apply_delegate_fn_) {
|
||||||
apply_delegate_fn_(interpreter_.get());
|
apply_delegate_fn_(interpreter_.get());
|
||||||
@ -146,6 +160,11 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
|||||||
|
|
||||||
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
|
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
|
||||||
|
|
||||||
|
// static
|
||||||
|
void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
|
||||||
|
force_use_nnapi = use_nnapi;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t SingleOpModel::GetTensorSize(int index) const {
|
int32_t SingleOpModel::GetTensorSize(int index) const {
|
||||||
TfLiteTensor* t = interpreter_->tensor(index);
|
TfLiteTensor* t = interpreter_->tensor(index);
|
||||||
CHECK(t);
|
CHECK(t);
|
||||||
|
@ -335,6 +335,9 @@ class SingleOpModel {
|
|||||||
resolver_ = std::move(resolver);
|
resolver_ = std::move(resolver);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enables NNAPI delegate application during interpreter creation.
|
||||||
|
static void SetForceUseNnapi(bool use_nnapi);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
int32_t GetTensorSize(int index) const;
|
int32_t GetTensorSize(int index) const;
|
||||||
|
|
||||||
|
@ -354,9 +354,3 @@ TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user