Support running TFLite kernel unit tests with NNAPI

PiperOrigin-RevId: 248757895
This commit is contained in:
Jared Duke 2019-05-17 11:34:19 -07:00 committed by TensorFlower Gardener
parent 406f7ff0dc
commit 5195395077
15 changed files with 116 additions and 48 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
@ -283,6 +284,16 @@ TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate,
TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegate* delegate) {
// Ignore empty node replacement sets.
if (!nodes_to_replace->size) {
return kTfLiteOk;
}
TFLITE_LOG(tflite::TFLITE_LOG_INFO,
"Replacing %d node(s) with delegate (%s) node.",
nodes_to_replace->size,
registration.custom_name ? registration.custom_name : "unknown");
// Annotate the registration as DELEGATE op.
registration.builtin_code = BuiltinOperator_DELEGATE;

View File

@ -22,6 +22,7 @@ cc_library(
hdrs = ["nnapi_delegate.h"],
deps = [
"//tensorflow/lite:kernel_api",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/nnapi:nnapi_implementation",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
#ifdef __ANDROID__
@ -1584,6 +1585,8 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(Options options)
if (options.accelerator_name) {
delegate_data_.accelerator_name = options.accelerator_name;
}
TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
"Created TensorFlow Lite delegate for NNAPI.");
Prepare = DoPrepare;
data_ = &delegate_data_;
}
@ -1657,6 +1660,11 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
// First element in vector must be the number of actual nodes.
supported_nodes[0] = supported_nodes.size() - 1;
// If there are no delegated nodes, short-circuit node replacement.
if (!supported_nodes[0]) {
return kTfLiteOk;
}
// NN API Delegate Registration (the pseudo kernel that will invoke NN
// API node sub sets)
static const TfLiteRegistration nnapi_delegate_kernel = {

View File

@ -51,6 +51,7 @@ cc_library(
"//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/kernels/internal:tensor_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
@ -59,6 +60,19 @@ cc_library(
],
)
# TODO(b/132204084): Update all kernel tests to use this main lib.
cc_library(
name = "test_main",
testonly = 1,
srcs = ["test_main.cc"],
deps = [
":test_util",
"//tensorflow/lite/testing:util",
"//tensorflow/lite/tools:command_line_flags",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "eigen_support",
srcs = [
@ -518,8 +532,10 @@ cc_test(
name = "div_test",
size = "small",
srcs = ["div_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
@ -530,8 +546,10 @@ cc_test(
name = "sub_test",
size = "small",
srcs = ["sub_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
@ -542,8 +560,10 @@ cc_test(
name = "transpose_test",
size = "small",
srcs = ["transpose_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"//tensorflow/lite/kernels/internal:reference",
@ -630,8 +650,10 @@ cc_test(
name = "dequantize_test",
size = "small",
srcs = ["dequantize_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"//tensorflow/lite/kernels/internal:types",
@ -669,8 +691,10 @@ cc_test(
name = "floor_test",
size = "small",
srcs = ["floor_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
@ -969,8 +993,10 @@ cc_test(
name = "local_response_norm_test",
size = "small",
srcs = ["local_response_norm_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
@ -993,8 +1019,10 @@ cc_test(
name = "softmax_test",
size = "small",
srcs = ["softmax_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"//tensorflow/lite/kernels/internal:reference_base",
@ -1019,8 +1047,10 @@ cc_test(
name = "lsh_projection_test",
size = "small",
srcs = ["lsh_projection_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",

View File

@ -75,9 +75,3 @@ TEST(DequantizeOpTest, INT8) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -167,9 +167,3 @@ TEST(IntegerDivOpTest, WithBroadcast) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -75,9 +75,3 @@ TEST(FloorOpTest, MultiDims) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -93,9 +93,3 @@ TEST(LocalResponseNormOpTest, SmallRadius) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -115,9 +115,3 @@ TEST(LSHProjectionOpTest2, Sparse3DInputs) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -136,9 +136,3 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -417,8 +417,3 @@ TEST(QuantizedSubOpModel, QuantizedTestsReluActivationBroadcastInt16) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
namespace {
void InitKernelTest(int* argc, char** argv) {
bool use_nnapi = false;
std::vector<tflite::Flag> flags = {
tflite::Flag::CreateFlag("--use_nnapi", &use_nnapi, "Use NNAPI"),
};
tflite::Flags::Parse(argc, const_cast<const char**>(argv), flags);
if (use_nnapi) {
tflite::SingleOpModel::SetForceUseNnapi(true);
}
}
} // namespace
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
InitKernelTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -14,14 +14,23 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/version.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/version.h"
namespace tflite {
using ::testing::FloatNear;
using ::testing::Matcher;
namespace {
// Whether to enable (global) use of NNAPI. Note that this will typically
// be set via a command-line flag.
static bool force_use_nnapi = false;
} // namespace
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
float max_abs_error) {
std::vector<Matcher<float>> matchers;
@ -138,6 +147,11 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
<< "Cannot allocate tensors";
interpreter_->ResetVariableTensors();
if (force_use_nnapi) {
// TODO(b/124505407): Check the result and fail accordingly.
interpreter_->ModifyGraphWithDelegate(NnApiDelegate());
}
// Modify delegate with function.
if (apply_delegate_fn_) {
apply_delegate_fn_(interpreter_.get());
@ -146,6 +160,11 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
// static
void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
force_use_nnapi = use_nnapi;
}
int32_t SingleOpModel::GetTensorSize(int index) const {
TfLiteTensor* t = interpreter_->tensor(index);
CHECK(t);

View File

@ -335,6 +335,9 @@ class SingleOpModel {
resolver_ = std::move(resolver);
}
// Enables NNAPI delegate application during interpreter creation.
static void SetForceUseNnapi(bool use_nnapi);
protected:
int32_t GetTensorSize(int index) const;

View File

@ -354,9 +354,3 @@ TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}