Support running TFLite kernel unit tests with NNAPI
PiperOrigin-RevId: 248757895
This commit is contained in:
parent
406f7ff0dc
commit
5195395077
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#include "tensorflow/lite/graph_info.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -283,6 +284,16 @@ TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate,
|
||||
TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
|
||||
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
|
||||
TfLiteDelegate* delegate) {
|
||||
// Ignore empty node replacement sets.
|
||||
if (!nodes_to_replace->size) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_INFO,
|
||||
"Replacing %d node(s) with delegate (%s) node.",
|
||||
nodes_to_replace->size,
|
||||
registration.custom_name ? registration.custom_name : "unknown");
|
||||
|
||||
// Annotate the registration as DELEGATE op.
|
||||
registration.builtin_code = BuiltinOperator_DELEGATE;
|
||||
|
||||
|
@ -22,6 +22,7 @@ cc_library(
|
||||
hdrs = ["nnapi_delegate.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
@ -1584,6 +1585,8 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(Options options)
|
||||
if (options.accelerator_name) {
|
||||
delegate_data_.accelerator_name = options.accelerator_name;
|
||||
}
|
||||
TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
|
||||
"Created TensorFlow Lite delegate for NNAPI.");
|
||||
Prepare = DoPrepare;
|
||||
data_ = &delegate_data_;
|
||||
}
|
||||
@ -1657,6 +1660,11 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
|
||||
// First element in vector must be the number of actual nodes.
|
||||
supported_nodes[0] = supported_nodes.size() - 1;
|
||||
|
||||
// If there are no delegated nodes, short-circuit node replacement.
|
||||
if (!supported_nodes[0]) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// NN API Delegate Registration (the pseudo kernel that will invoke NN
|
||||
// API node sub sets)
|
||||
static const TfLiteRegistration nnapi_delegate_kernel = {
|
||||
|
@ -51,6 +51,7 @@ cc_library(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||
"//tensorflow/lite/kernels/internal:tensor_utils",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
@ -59,6 +60,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(b/132204084): Update all kernel tests to use this main lib.
|
||||
cc_library(
|
||||
name = "test_main",
|
||||
testonly = 1,
|
||||
srcs = ["test_main.cc"],
|
||||
deps = [
|
||||
":test_util",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "eigen_support",
|
||||
srcs = [
|
||||
@ -518,8 +532,10 @@ cc_test(
|
||||
name = "div_test",
|
||||
size = "small",
|
||||
srcs = ["div_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
@ -530,8 +546,10 @@ cc_test(
|
||||
name = "sub_test",
|
||||
size = "small",
|
||||
srcs = ["sub_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
@ -542,8 +560,10 @@ cc_test(
|
||||
name = "transpose_test",
|
||||
size = "small",
|
||||
srcs = ["transpose_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"//tensorflow/lite/kernels/internal:reference",
|
||||
@ -630,8 +650,10 @@ cc_test(
|
||||
name = "dequantize_test",
|
||||
size = "small",
|
||||
srcs = ["dequantize_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"//tensorflow/lite/kernels/internal:types",
|
||||
@ -669,8 +691,10 @@ cc_test(
|
||||
name = "floor_test",
|
||||
size = "small",
|
||||
srcs = ["floor_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
@ -969,8 +993,10 @@ cc_test(
|
||||
name = "local_response_norm_test",
|
||||
size = "small",
|
||||
srcs = ["local_response_norm_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
@ -993,8 +1019,10 @@ cc_test(
|
||||
name = "softmax_test",
|
||||
size = "small",
|
||||
srcs = ["softmax_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"//tensorflow/lite/kernels/internal:reference_base",
|
||||
@ -1019,8 +1047,10 @@ cc_test(
|
||||
name = "lsh_projection_test",
|
||||
size = "small",
|
||||
srcs = ["lsh_projection_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
|
@ -75,9 +75,3 @@ TEST(DequantizeOpTest, INT8) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -167,9 +167,3 @@ TEST(IntegerDivOpTest, WithBroadcast) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -75,9 +75,3 @@ TEST(FloorOpTest, MultiDims) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -93,9 +93,3 @@ TEST(LocalResponseNormOpTest, SmallRadius) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -115,9 +115,3 @@ TEST(LSHProjectionOpTest2, Sparse3DInputs) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -136,9 +136,3 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -417,8 +417,3 @@ TEST(QuantizedSubOpModel, QuantizedTestsReluActivationBroadcastInt16) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
43
tensorflow/lite/kernels/test_main.cc
Normal file
43
tensorflow/lite/kernels/test_main.cc
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
||||
namespace {
|
||||
|
||||
void InitKernelTest(int* argc, char** argv) {
|
||||
bool use_nnapi = false;
|
||||
std::vector<tflite::Flag> flags = {
|
||||
tflite::Flag::CreateFlag("--use_nnapi", &use_nnapi, "Use NNAPI"),
|
||||
};
|
||||
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flags);
|
||||
|
||||
if (use_nnapi) {
|
||||
tflite::SingleOpModel::SetForceUseNnapi(true);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
InitKernelTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -14,14 +14,23 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
|
||||
#include "tensorflow/lite/version.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Matcher;
|
||||
|
||||
namespace {
|
||||
|
||||
// Whether to enable (global) use of NNAPI. Note that this will typically
|
||||
// be set via a command-line flag.
|
||||
static bool force_use_nnapi = false;
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
|
||||
float max_abs_error) {
|
||||
std::vector<Matcher<float>> matchers;
|
||||
@ -138,6 +147,11 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
<< "Cannot allocate tensors";
|
||||
interpreter_->ResetVariableTensors();
|
||||
|
||||
if (force_use_nnapi) {
|
||||
// TODO(b/124505407): Check the result and fail accordingly.
|
||||
interpreter_->ModifyGraphWithDelegate(NnApiDelegate());
|
||||
}
|
||||
|
||||
// Modify delegate with function.
|
||||
if (apply_delegate_fn_) {
|
||||
apply_delegate_fn_(interpreter_.get());
|
||||
@ -146,6 +160,11 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
|
||||
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
|
||||
|
||||
// static
|
||||
void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
|
||||
force_use_nnapi = use_nnapi;
|
||||
}
|
||||
|
||||
int32_t SingleOpModel::GetTensorSize(int index) const {
|
||||
TfLiteTensor* t = interpreter_->tensor(index);
|
||||
CHECK(t);
|
||||
|
@ -335,6 +335,9 @@ class SingleOpModel {
|
||||
resolver_ = std::move(resolver);
|
||||
}
|
||||
|
||||
// Enables NNAPI delegate application during interpreter creation.
|
||||
static void SetForceUseNnapi(bool use_nnapi);
|
||||
|
||||
protected:
|
||||
int32_t GetTensorSize(int index) const;
|
||||
|
||||
|
@ -354,9 +354,3 @@ TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user