diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index 2e8d56d6a2f..a42228313f9 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -219,16 +219,18 @@ cc_library( ":join", ":split", ":test_runner", + "@com_google_absl//absl/strings", "//tensorflow/lite:builtin_op_data", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", - "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:custom_ops", "//tensorflow/lite/kernels:reference_ops", "//tensorflow/lite/tools/evaluation:utils", - "@com_google_absl//absl/strings", - ], + ] + select({ + "//tensorflow:ios": [], + "//conditions:default": ["//tensorflow/lite/delegates/flex:delegate"], + }), ) tf_cc_test( @@ -355,6 +357,7 @@ cc_library( ":join", ":split", ":tf_driver", + ":tflite_driver", "//tensorflow/lite:string", ] + select({ "//conditions:default": [ @@ -403,6 +406,9 @@ cc_library( "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib", ], + "//tensorflow:ios": [ + "//tensorflow/core:ios_tensorflow_lib", + ], }), ) @@ -435,6 +441,9 @@ cc_library( "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib", ], + "//tensorflow:ios": [ + "//tensorflow/core:ios_tensorflow_lib", + ], }), ) diff --git a/tensorflow/lite/testing/generate_testspec.cc b/tensorflow/lite/testing/generate_testspec.cc index 99021c9f317..e7435e19f49 100644 --- a/tensorflow/lite/testing/generate_testspec.cc +++ b/tensorflow/lite/testing/generate_testspec.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/testing/join.h" #include "tensorflow/lite/testing/split.h" #include "tensorflow/lite/testing/tf_driver.h" +#include "tensorflow/lite/testing/tflite_driver.h" namespace tflite { namespace testing { @@ -83,6 +84,68 @@ std::vector GenerateInputValues( return input_values; } +bool GenerateTestSpecFromRunner(std::iostream& stream, int num_invocations, + const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer, + TestRunner* runner) { + stream << "reshape {\n"; + for (const auto& shape : input_layer_shape) { + stream << " input: \"" << shape << "\"\n"; + } + stream << "}\n"; + + // Generate inputs. + std::mt19937 random_engine; + for (int i = 0; i < num_invocations; ++i) { + // Note that the input values are random, so each invocation will have a + // different set. + std::vector input_values = GenerateInputValues( + &random_engine, input_layer, input_layer_type, input_layer_shape); + if (input_values.empty()) { + std::cerr << "Unable to generate input values for the TensorFlow model. " + "Make sure the correct values are defined for " + "input_layer, input_layer_type, and input_layer_shape." + << std::endl; + return false; + } + + // Run TensorFlow. + auto inputs = runner->GetInputs(); + for (int j = 0; j < input_values.size(); j++) { + runner->SetInput(inputs[j], input_values[j]); + if (!runner->IsValid()) { + std::cerr << runner->GetErrorMessage() << std::endl; + return false; + } + } + + runner->Invoke(); + if (!runner->IsValid()) { + std::cerr << runner->GetErrorMessage() << std::endl; + return false; + } + + // Write second part of test spec, with inputs and outputs. + stream << "invoke {\n"; + for (const auto& value : input_values) { + stream << " input: \"" << value << "\"\n"; + } + auto outputs = runner->GetOutputs(); + for (int j = 0; j < output_layer.size(); j++) { + stream << " output: \"" << runner->ReadOutput(outputs[j]) << "\"\n"; + if (!runner->IsValid()) { + std::cerr << runner->GetErrorMessage() << std::endl; + return false; + } + } + stream << "}\n"; + } + + return true; +} + } // namespace bool GenerateTestSpecFromTensorflowModel( @@ -108,61 +171,29 @@ bool GenerateTestSpecFromTensorflowModel( std::cerr << runner.GetErrorMessage() << std::endl; return false; } - // Write first part of test spec, defining model and input shapes. stream << "load_model: " << tflite_model_path << "\n"; - stream << "reshape {\n"; - for (const auto& shape : input_layer_shape) { - stream << " input: \"" << shape << "\"\n"; + return GenerateTestSpecFromRunner(stream, num_invocations, input_layer, + input_layer_type, input_layer_shape, + output_layer, &runner); +} + +bool GenerateTestSpecFromTFLiteModel( + std::iostream& stream, const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer) { + TfLiteDriver runner; + runner.LoadModel(tflite_model_path); + if (!runner.IsValid()) { + std::cerr << runner.GetErrorMessage() << std::endl; + return false; } - stream << "}\n"; - - // Generate inputs. - std::mt19937 random_engine; - for (int i = 0; i < num_invocations; ++i) { - // Note that the input values are random, so each invocation will have a - // different set. - std::vector input_values = GenerateInputValues( - &random_engine, input_layer, input_layer_type, input_layer_shape); - if (input_values.empty()) { - std::cerr << "Unable to generate input values for the TensorFlow model. " - "Make sure the correct values are defined for " - "input_layer, input_layer_type, and input_layer_shape." - << std::endl; - return false; - } - - // Run TensorFlow. - for (int j = 0; j < input_values.size(); j++) { - runner.SetInput(j, input_values[j]); - if (!runner.IsValid()) { - std::cerr << runner.GetErrorMessage() << std::endl; - return false; - } - } - - runner.Invoke(); - if (!runner.IsValid()) { - std::cerr << runner.GetErrorMessage() << std::endl; - return false; - } - - // Write second part of test spec, with inputs and outputs. - stream << "invoke {\n"; - for (const auto& value : input_values) { - stream << " input: \"" << value << "\"\n"; - } - for (int j = 0; j < output_layer.size(); j++) { - stream << " output: \"" << runner.ReadOutput(j) << "\"\n"; - if (!runner.IsValid()) { - std::cerr << runner.GetErrorMessage() << std::endl; - return false; - } - } - stream << "}\n"; - } - - return true; + runner.AllocateTensors(); + return GenerateTestSpecFromRunner(stream, num_invocations, input_layer, + input_layer_type, input_layer_shape, + output_layer, &runner); } } // namespace testing diff --git a/tensorflow/lite/testing/generate_testspec.h b/tensorflow/lite/testing/generate_testspec.h index 58f8065972b..79d0114ce8e 100644 --- a/tensorflow/lite/testing/generate_testspec.h +++ b/tensorflow/lite/testing/generate_testspec.h @@ -46,6 +46,14 @@ bool GenerateTestSpecFromTensorflowModel( const std::vector& input_layer_shape, const std::vector& output_layer); +// Generate test spec by executing TFLite model on random inputs. +bool GenerateTestSpecFromTFLiteModel( + std::iostream& stream, const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer); + // Generates random values that are filled into the tensor. template std::vector GenerateRandomTensor(const std::vector& shape, diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index 3d988eb624a..9aeba87bbea 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -18,9 +18,12 @@ limitations under the License. #include #include #include + #include "absl/strings/escaping.h" #include "tensorflow/lite/builtin_op_data.h" +#if !defined(__APPLE__) #include "tensorflow/lite/delegates/flex/delegate.h" +#endif #include "tensorflow/lite/kernels/custom_ops_register.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register_ref.h" @@ -331,10 +334,12 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel) delegate_ = evaluation::CreateGPUDelegate(/*model=*/nullptr); break; case DelegateType::kFlex: +#if !defined(__APPLE__) delegate_ = Interpreter::TfLiteDelegatePtr( FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) { delete static_cast(delegate); }); +#endif break; } } diff --git a/tensorflow/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h index 258902606a5..bce3e9c4c01 100644 --- a/tensorflow/lite/testing/tflite_driver.h +++ b/tensorflow/lite/testing/tflite_driver.h @@ -18,7 +18,9 @@ limitations under the License. #include #include +#if !defined(__APPLE__) #include "tensorflow/lite/delegates/flex/delegate.h" +#endif #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register_ref.h"