From 33096914a7e1afa84007451faa25ad66ab78f15d Mon Sep 17 00:00:00 2001 From: Yunlu Li Date: Wed, 19 Dec 2018 16:00:11 -0800 Subject: [PATCH] Make tflite_driver able to run single op model with reference kernels. PiperOrigin-RevId: 226248707 --- tensorflow/lite/kernels/BUILD | 15 + tensorflow/lite/kernels/register_ref.cc | 297 ++++++++++++++++++ tensorflow/lite/kernels/register_ref.h | 39 +++ tensorflow/lite/testing/BUILD | 1 + tensorflow/lite/testing/tflite_driver.cc | 14 +- tensorflow/lite/testing/tflite_driver.h | 6 +- tensorflow/lite/testing/tflite_driver_test.cc | 34 ++ 7 files changed, 402 insertions(+), 4 deletions(-) create mode 100644 tensorflow/lite/kernels/register_ref.cc create mode 100644 tensorflow/lite/kernels/register_ref.h diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index bad1c4aebf1..5cc06c7a633 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -285,6 +285,21 @@ cc_library( ], ) +# The builtin_ops target will resolve to optimized kernels when available. This +# target uses reference kernels only, and is useful for validation and testing. +# It should *not* generally be used in production. +cc_library( + name = "reference_ops", + srcs = ["register_ref.cc"], + hdrs = ["register_ref.h"], + deps = [ + ":builtin_op_kernels", + "//tensorflow/lite:framework", + "//tensorflow/lite:util", + "//tensorflow/lite/c:c_api_internal", + ], +) + tf_cc_test( name = "audio_spectrogram_test", size = "small", diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc new file mode 100644 index 00000000000..584e044b98b --- /dev/null +++ b/tensorflow/lite/kernels/register_ref.cc @@ -0,0 +1,297 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/register_ref.h" +#include "tensorflow/lite/util.h" + +namespace tflite { +namespace ops { + +namespace custom { + +TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); +TfLiteRegistration* Register_LAYER_NORM_LSTM(); +TfLiteRegistration* Register_MFCC(); +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); +TfLiteRegistration* Register_RELU_1(); + +} // namespace custom + +namespace builtin { + +// TODO(yunluli): Some of the registries, e.g. Tanh(), could only invoke +// optimized kernels. Add a _REF() variant for them. +TfLiteRegistration* Register_ABS(); +TfLiteRegistration* Register_RELU(); +TfLiteRegistration* Register_RELU_N1_TO_1(); +TfLiteRegistration* Register_RELU6(); +TfLiteRegistration* Register_TANH(); +TfLiteRegistration* Register_LOGISTIC(); +TfLiteRegistration* Register_AVERAGE_POOL_REF(); +TfLiteRegistration* Register_MAX_POOL_REF(); +TfLiteRegistration* Register_L2_POOL_REF(); +TfLiteRegistration* Register_CONVOLUTION_REF(); +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF(); +TfLiteRegistration* Register_SVDF(); +TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN(); +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); +TfLiteRegistration* Register_FULLY_CONNECTED_REF(); +TfLiteRegistration* Register_LSH_PROJECTION(); +TfLiteRegistration* Register_HASHTABLE_LOOKUP(); +TfLiteRegistration* Register_SOFTMAX(); +TfLiteRegistration* Register_CONCATENATION_REF(); +TfLiteRegistration* Register_ADD_REF(); +TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF(); +TfLiteRegistration* Register_DIV_REF(); +TfLiteRegistration* Register_SUB_REF(); +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF(); +TfLiteRegistration* Register_MUL_REF(); +TfLiteRegistration* Register_L2NORM_REF(); +TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF(); +TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM(); +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM(); +TfLiteRegistration* Register_PAD_REF(); +TfLiteRegistration* Register_PADV2_REF(); +TfLiteRegistration* Register_RESHAPE(); +TfLiteRegistration* Register_RESIZE_BILINEAR_REF(); +TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR_REF(); +TfLiteRegistration* Register_SKIP_GRAM(); +TfLiteRegistration* Register_SPACE_TO_DEPTH_REF(); +TfLiteRegistration* Register_GATHER(); +TfLiteRegistration* Register_TRANSPOSE_REF(); +TfLiteRegistration* Register_MEAN_REF(); +TfLiteRegistration* Register_SPLIT(); +TfLiteRegistration* Register_SPLIT_V(); +TfLiteRegistration* Register_SQUEEZE(); +TfLiteRegistration* Register_STRIDED_SLICE_REF(); +TfLiteRegistration* Register_EXP(); +TfLiteRegistration* Register_TOPK_V2(); +TfLiteRegistration* Register_LOG(); +TfLiteRegistration* Register_LOG_SOFTMAX(); +TfLiteRegistration* Register_CAST(); +TfLiteRegistration* Register_DEQUANTIZE(); +TfLiteRegistration* Register_PRELU(); +TfLiteRegistration* Register_MAXIMUM(); +TfLiteRegistration* Register_MINIMUM(); +TfLiteRegistration* Register_ARG_MAX(); +TfLiteRegistration* Register_ARG_MIN(); +TfLiteRegistration* Register_GREATER(); +TfLiteRegistration* Register_GREATER_EQUAL(); +TfLiteRegistration* Register_LESS(); +TfLiteRegistration* Register_LESS_EQUAL(); +TfLiteRegistration* Register_FLOOR(); +TfLiteRegistration* Register_TILE(); +TfLiteRegistration* Register_NEG(); +TfLiteRegistration* Register_SUM(); +TfLiteRegistration* Register_REDUCE_PROD(); +TfLiteRegistration* Register_REDUCE_MAX(); +TfLiteRegistration* Register_REDUCE_MIN(); +TfLiteRegistration* Register_REDUCE_ANY(); +TfLiteRegistration* Register_SELECT(); +TfLiteRegistration* Register_SLICE(); +TfLiteRegistration* Register_SIN(); +TfLiteRegistration* Register_TRANSPOSECONV_REF(); +TfLiteRegistration* Register_EXPAND_DIMS(); +TfLiteRegistration* Register_SPARSE_TO_DENSE(); +TfLiteRegistration* Register_EQUAL(); +TfLiteRegistration* Register_NOT_EQUAL(); +TfLiteRegistration* Register_SQRT(); +TfLiteRegistration* Register_RSQRT(); +TfLiteRegistration* Register_SHAPE(); +TfLiteRegistration* Register_POW(); +TfLiteRegistration* Register_FAKE_QUANT(); +TfLiteRegistration* Register_PACK(); +TfLiteRegistration* Register_ONE_HOT(); +TfLiteRegistration* Register_LOGICAL_OR(); +TfLiteRegistration* Register_LOGICAL_AND(); +TfLiteRegistration* Register_LOGICAL_NOT(); +TfLiteRegistration* Register_UNPACK(); +TfLiteRegistration* Register_FLOOR_DIV(); +TfLiteRegistration* Register_SQUARE(); +TfLiteRegistration* Register_ZEROS_LIKE(); +TfLiteRegistration* Register_FLOOR_MOD(); +TfLiteRegistration* Register_RANGE(); +TfLiteRegistration* Register_LEAKY_RELU(); +TfLiteRegistration* Register_SQUARED_DIFFERENCE(); +TfLiteRegistration* Register_FILL(); +TfLiteRegistration* Register_MIRROR_PAD(); + +namespace { + +TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { + context->ReportError( + context, + "Regular TensorFlow ops are not supported by this interpreter. Make sure " + "you invoke the Flex delegate before inference."); + return kTfLiteError; +} + +} // namespace + +const TfLiteRegistration* BuiltinRefOpResolver::FindOp( + tflite::BuiltinOperator op, int version) const { + return MutableOpResolver::FindOp(op, version); +} + +const TfLiteRegistration* BuiltinRefOpResolver::FindOp(const char* op, + int version) const { + // Return the NULL Op for all ops whose name start with "Flex", allowing + // the interpreter to delegate their execution. + if (IsFlexOp(op)) { + static TfLiteRegistration null_op{ + nullptr, nullptr, &UnsupportedTensorFlowOp, + nullptr, nullptr, BuiltinOperator_CUSTOM, + "Flex", 1}; + return &null_op; + } + return MutableOpResolver::FindOp(op, version); +} + +BuiltinRefOpResolver::BuiltinRefOpResolver() { + AddBuiltin(BuiltinOperator_ABS, Register_ABS()); + AddBuiltin(BuiltinOperator_RELU, Register_RELU()); + AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1()); + AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); + AddBuiltin(BuiltinOperator_TANH, Register_TANH()); + AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); + AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_REF()); + AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_REF()); + AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_REF()); + AddBuiltin(BuiltinOperator_CONV_2D, Register_CONVOLUTION_REF()); + AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, + Register_DEPTHWISE_CONVOLUTION_REF(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); + AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + Register_BIDIRECTIONAL_SEQUENCE_RNN()); + AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + Register_UNIDIRECTIONAL_SEQUENCE_RNN()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + Register_EMBEDDING_LOOKUP_SPARSE()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED_REF(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); + AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); + AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); + AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION_REF()); + AddBuiltin(BuiltinOperator_ADD, Register_ADD_REF()); + AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, + Register_SPACE_TO_BATCH_ND_REF()); + AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, + Register_BATCH_TO_SPACE_ND_REF()); + AddBuiltin(BuiltinOperator_MUL, Register_MUL_REF()); + AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF()); + AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + Register_LOCAL_RESPONSE_NORM_REF()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + Register_BIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + Register_UNIDIRECTIONAL_SEQUENCE_LSTM()); + AddBuiltin(BuiltinOperator_PAD, Register_PAD_REF()); + AddBuiltin(BuiltinOperator_PADV2, Register_PADV2_REF()); + AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); + AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR_REF()); + AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + Register_RESIZE_NEAREST_NEIGHBOR_REF()); + AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); + AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH_REF()); + AddBuiltin(BuiltinOperator_GATHER, Register_GATHER()); + AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE_REF()); + AddBuiltin(BuiltinOperator_MEAN, Register_MEAN_REF()); + AddBuiltin(BuiltinOperator_DIV, Register_DIV_REF()); + AddBuiltin(BuiltinOperator_SUB, Register_SUB_REF()); + AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT()); + AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V()); + AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); + AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF()); + AddBuiltin(BuiltinOperator_EXP, Register_EXP()); + AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_LOG, Register_LOG()); + AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); + AddBuiltin(BuiltinOperator_CAST, Register_CAST()); + AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_PRELU, Register_PRELU()); + AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); + AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); + AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); + AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN()); + AddBuiltin(BuiltinOperator_GREATER, Register_GREATER()); + AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL()); + AddBuiltin(BuiltinOperator_LESS, Register_LESS()); + AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL()); + AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); + AddBuiltin(BuiltinOperator_NEG, Register_NEG()); + AddBuiltin(BuiltinOperator_SELECT, Register_SELECT()); + AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); + AddBuiltin(BuiltinOperator_SIN, Register_SIN()); + AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF()); + AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_SUM, Register_SUM()); + AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD()); + AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX()); + AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN()); + AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY()); + AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); + AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); + AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); + AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); + AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); + AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); + AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); + AddBuiltin(BuiltinOperator_POW, Register_POW()); + AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); + AddBuiltin(BuiltinOperator_PACK, Register_PACK()); + AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); + AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); + AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); + AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); + AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK()); + AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV()); + AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE()); + AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE()); + AddBuiltin(BuiltinOperator_FLOOR_MOD, Register_FLOOR_MOD()); + AddBuiltin(BuiltinOperator_RANGE, Register_RANGE()); + AddBuiltin(BuiltinOperator_LEAKY_RELU, Register_LEAKY_RELU()); + AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE()); + AddBuiltin(BuiltinOperator_FILL, Register_FILL()); + AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD()); + + // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that + // custom ops aren't always included by default. + AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); + AddCustom("AudioSpectrogram", + tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); + AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM()); + AddCustom("Relu1", tflite::ops::custom::Register_RELU_1()); + AddCustom("TFLite_Detection_PostProcess", + tflite::ops::custom::Register_DETECTION_POSTPROCESS()); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/register_ref.h b/tensorflow/lite/kernels/register_ref.h new file mode 100644 index 00000000000..c66d4a25bc4 --- /dev/null +++ b/tensorflow/lite/kernels/register_ref.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_REGISTER_REF_H_ +#define TENSORFLOW_LITE_KERNELS_REGISTER_REF_H_ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace builtin { + +class BuiltinRefOpResolver : public MutableOpResolver { + public: + BuiltinRefOpResolver(); + + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; +}; + +} // namespace builtin +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_REGISTER_REF_H_ diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index 22ffed43cc0..fa25cfaa69e 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -165,6 +165,7 @@ cc_library( "//tensorflow/lite:string_util", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:reference_ops", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index 4e11d49f252..ffe296432a4 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -19,6 +19,8 @@ limitations under the License. #include "absl/strings/escaping.h" #include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/delegates/flex/delegate.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/register_ref.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/testing/split.h" @@ -188,8 +190,15 @@ class TfLiteDriver::Expectation { size_t num_elements_; }; -TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name) +TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name, + bool reference_kernel) : use_nnapi_(use_nnapi) { + if (reference_kernel) { + resolver_.reset(new ops::builtin::BuiltinRefOpResolver); + } else { + resolver_.reset(new ops::builtin::BuiltinOpResolver); + } + if (delegate_name == "FLEX") { delegate_ = FlexDelegate::Create(); } @@ -221,8 +230,7 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { Invalidate("Failed to mmap model " + bin_file_path); return; } - ops::builtin::BuiltinOpResolver builtins; - InterpreterBuilder(*model_, builtins)(&interpreter_); + InterpreterBuilder(*model_, *resolver_)(&interpreter_); if (!interpreter_) { Invalidate("Failed build interpreter"); return; diff --git a/tensorflow/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h index 1da0533c57c..537f20dfbfd 100644 --- a/tensorflow/lite/testing/tflite_driver.h +++ b/tensorflow/lite/testing/tflite_driver.h @@ -16,10 +16,12 @@ limitations under the License. #define TENSORFLOW_LITE_TESTING_TFLITE_DRIVER_H_ #include +#include #include "tensorflow/lite/delegates/flex/delegate.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/register_ref.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/testing/test_runner.h" @@ -29,7 +31,8 @@ namespace testing { // A test runner that feeds inputs into TF Lite and verifies its outputs. class TfLiteDriver : public TestRunner { public: - explicit TfLiteDriver(bool use_nnapi, const string& delegate = ""); + explicit TfLiteDriver(bool use_nnapi, const string& delegate = "", + bool reference_kernel = false); ~TfLiteDriver() override; void LoadModel(const string& bin_file_path) override; @@ -65,6 +68,7 @@ class TfLiteDriver : public TestRunner { class Expectation; + std::unique_ptr resolver_; std::unique_ptr delegate_; bool use_nnapi_ = false; std::unique_ptr model_; diff --git a/tensorflow/lite/testing/tflite_driver_test.cc b/tensorflow/lite/testing/tflite_driver_test.cc index 6e953e5e19b..81bf6700cb8 100644 --- a/tensorflow/lite/testing/tflite_driver_test.cc +++ b/tensorflow/lite/testing/tflite_driver_test.cc @@ -56,6 +56,40 @@ TEST(TfliteDriverTest, SimpleTest) { ASSERT_TRUE(runner->CheckResults()); } +TEST(TfliteDriverTest, SingleAddOpTest) { + std::unique_ptr runner(new TfLiteDriver( + /*use_nnapi*/ false, /*delegate*/ "", /*reference_kernel*/ true)); + + runner->SetModelBaseDir("tensorflow/lite"); + runner->LoadModel("testdata/multi_add.bin"); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); + ASSERT_THAT(runner->GetOutputs(), ElementsAre(5, 6)); + + for (int i : {0, 1, 2, 3}) { + runner->ReshapeTensor(i, "1,2,2,1"); + } + ASSERT_TRUE(runner->IsValid()); + + runner->AllocateTensors(); + + runner->SetInput(0, "0.1,0.2,0.3,0.4"); + runner->SetInput(1, "0.001,0.002,0.003,0.004"); + runner->SetInput(2, "0.001,0.002,0.003,0.004"); + runner->SetInput(3, "0.01,0.02,0.03,0.04"); + + runner->ResetTensor(2); + + runner->SetExpectation(5, "0.101,0.202,0.303,0.404"); + runner->SetExpectation(6, "0.011,0.022,0.033,0.044"); + + runner->Invoke(); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_TRUE(runner->CheckResults()); +} + } // namespace } // namespace testing } // namespace tflite