From f38c4d866b7a1a1ac0c81334f2c4b421c4122731 Mon Sep 17 00:00:00 2001 From: Taehee Jeong <taeheej@google.com> Date: Mon, 28 Dec 2020 18:25:05 -0800 Subject: [PATCH] Add reference_tflite_model option to DiffOptions This makes comparison between two tflite models possible. PiperOrigin-RevId: 349357549 Change-Id: Ibeb8e206be2a997be8974b4ec51eee21bcc45a24 --- tensorflow/lite/testing/tflite_diff_flags.h | 8 +++++++- tensorflow/lite/testing/tflite_diff_util.cc | 5 ++++- tensorflow/lite/testing/tflite_diff_util.h | 2 ++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/testing/tflite_diff_flags.h b/tensorflow/lite/testing/tflite_diff_flags.h index 7022cb03ad1..e94f2aae45f 100644 --- a/tensorflow/lite/testing/tflite_diff_flags.h +++ b/tensorflow/lite/testing/tflite_diff_flags.h @@ -36,6 +36,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { string output_layer; int32_t num_runs_per_pass = 100; string delegate_name; + string reference_tflite_model; } values; std::string delegate_name; @@ -61,6 +62,10 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { tensorflow::Flag("delegate", &values.delegate_name, "[optional] Delegate to use for executing ops. Must be " "`{\"\", NNAPI, GPU, FLEX}`"), + tensorflow::Flag("reference_tflite_model", &values.reference_tflite_model, + "[optional] Path of the TensorFlow Lite model to " + "compare inference results against the model given in " + "`tflite_model`."), }; bool no_inputs = *argc == 1; @@ -96,7 +101,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { Split<string>(values.input_layer_shape, ":"), Split<string>(values.output_layer, ","), values.num_runs_per_pass, - delegate}; + delegate, + values.reference_tflite_model}; } } // namespace testing diff --git a/tensorflow/lite/testing/tflite_diff_util.cc b/tensorflow/lite/testing/tflite_diff_util.cc index 2e628fd710d..628233aeb1d 100644 --- a/tensorflow/lite/testing/tflite_diff_util.cc +++ b/tensorflow/lite/testing/tflite_diff_util.cc @@ -30,8 +30,11 @@ bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options, int num_invocations, TestRunner* (*runner_factory)()) { std::stringstream tflite_stream; + std::string reference_tflite_model = options.reference_tflite_model.empty() + ? options.tflite_model + : options.reference_tflite_model; if (!GenerateTestSpecFromTFLiteModel( - tflite_stream, options.tflite_model, num_invocations, + tflite_stream, reference_tflite_model, num_invocations, options.input_layer, options.input_layer_type, options.input_layer_shape, options.output_layer)) { return false; diff --git a/tensorflow/lite/testing/tflite_diff_util.h b/tensorflow/lite/testing/tflite_diff_util.h index 3cf4342b810..16ca24a7539 100644 --- a/tensorflow/lite/testing/tflite_diff_util.h +++ b/tensorflow/lite/testing/tflite_diff_util.h @@ -47,6 +47,8 @@ struct DiffOptions { int num_runs_per_pass; // The type of delegate to apply during inference. TfLiteDriver::DelegateType delegate; + // Path of tflite model used to generate golden values. + std::string reference_tflite_model = ""; }; // Run a single TensorFLow Lite diff test with a given options.