From f38c4d866b7a1a1ac0c81334f2c4b421c4122731 Mon Sep 17 00:00:00 2001
From: Taehee Jeong <taeheej@google.com>
Date: Mon, 28 Dec 2020 18:25:05 -0800
Subject: [PATCH] Add reference_tflite_model option to DiffOptions

This makes comparison between two tflite models possible.

PiperOrigin-RevId: 349357549
Change-Id: Ibeb8e206be2a997be8974b4ec51eee21bcc45a24
---
 tensorflow/lite/testing/tflite_diff_flags.h | 8 +++++++-
 tensorflow/lite/testing/tflite_diff_util.cc | 5 ++++-
 tensorflow/lite/testing/tflite_diff_util.h  | 2 ++
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/testing/tflite_diff_flags.h b/tensorflow/lite/testing/tflite_diff_flags.h
index 7022cb03ad1..e94f2aae45f 100644
--- a/tensorflow/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/lite/testing/tflite_diff_flags.h
@@ -36,6 +36,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
     string output_layer;
     int32_t num_runs_per_pass = 100;
     string delegate_name;
+    string reference_tflite_model;
   } values;
 
   std::string delegate_name;
@@ -61,6 +62,10 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
       tensorflow::Flag("delegate", &values.delegate_name,
                        "[optional] Delegate to use for executing ops. Must be "
                        "`{\"\", NNAPI, GPU, FLEX}`"),
+      tensorflow::Flag("reference_tflite_model", &values.reference_tflite_model,
+                       "[optional] Path of the TensorFlow Lite model to "
+                       "compare inference results against the model given in "
+                       "`tflite_model`."),
   };
 
   bool no_inputs = *argc == 1;
@@ -96,7 +101,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
           Split<string>(values.input_layer_shape, ":"),
           Split<string>(values.output_layer, ","),
           values.num_runs_per_pass,
-          delegate};
+          delegate,
+          values.reference_tflite_model};
 }
 
 }  // namespace testing
diff --git a/tensorflow/lite/testing/tflite_diff_util.cc b/tensorflow/lite/testing/tflite_diff_util.cc
index 2e628fd710d..628233aeb1d 100644
--- a/tensorflow/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/lite/testing/tflite_diff_util.cc
@@ -30,8 +30,11 @@ bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options,
                                          int num_invocations,
                                          TestRunner* (*runner_factory)()) {
   std::stringstream tflite_stream;
+  std::string reference_tflite_model = options.reference_tflite_model.empty()
+                                           ? options.tflite_model
+                                           : options.reference_tflite_model;
   if (!GenerateTestSpecFromTFLiteModel(
-          tflite_stream, options.tflite_model, num_invocations,
+          tflite_stream, reference_tflite_model, num_invocations,
           options.input_layer, options.input_layer_type,
           options.input_layer_shape, options.output_layer)) {
     return false;
diff --git a/tensorflow/lite/testing/tflite_diff_util.h b/tensorflow/lite/testing/tflite_diff_util.h
index 3cf4342b810..16ca24a7539 100644
--- a/tensorflow/lite/testing/tflite_diff_util.h
+++ b/tensorflow/lite/testing/tflite_diff_util.h
@@ -47,6 +47,8 @@ struct DiffOptions {
   int num_runs_per_pass;
   // The type of delegate to apply during inference.
   TfLiteDriver::DelegateType delegate;
+  // Path of tflite model used to generate golden values.
+  std::string reference_tflite_model = "";
 };
 
 // Run a single TensorFLow Lite diff test with a given options.