From 8abcbc77a0bc2af38b1f2b0b270504cf1bc9a225 Mon Sep 17 00:00:00 2001
From: Karim Nosir <karimnosseir@google.com>
Date: Mon, 16 Dec 2019 18:17:30 -0800
Subject: [PATCH] Add new methods for diff tests that compares TFLite vs any
 custom TestRunner

PiperOrigin-RevId: 285891574
Change-Id: I67da00e81f6ca60e939b5f77502be1249d59a930
---
 tensorflow/lite/testing/tflite_diff_util.cc | 54 ++++++++++++++++++++-
 tensorflow/lite/testing/tflite_diff_util.h  |  8 +++
 2 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/testing/tflite_diff_util.cc b/tensorflow/lite/testing/tflite_diff_util.cc
index 721830adc4d..2e628fd710d 100644
--- a/tensorflow/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/lite/testing/tflite_diff_util.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/lite/testing/tflite_diff_util.h"
+
 #include <cstdarg>
 #include <cstdio>
 #include <cstdlib>
@@ -19,11 +21,27 @@ limitations under the License.
 
 #include "tensorflow/lite/testing/generate_testspec.h"
 #include "tensorflow/lite/testing/parse_testdata.h"
-#include "tensorflow/lite/testing/tflite_diff_util.h"
 #include "tensorflow/lite/testing/tflite_driver.h"
 
 namespace tflite {
 namespace testing {
+namespace {
+bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options,
+                                         int num_invocations,
+                                         TestRunner* (*runner_factory)()) {
+  std::stringstream tflite_stream;
+  if (!GenerateTestSpecFromTFLiteModel(
+          tflite_stream, options.tflite_model, num_invocations,
+          options.input_layer, options.input_layer_type,
+          options.input_layer_shape, options.output_layer)) {
+    return false;
+  }
+
+  std::unique_ptr<TestRunner> runner(runner_factory());
+  runner->LoadModel(options.tflite_model);
+  return ParseAndRunTests(&tflite_stream, runner.get());
+}
+}  // namespace
 
 bool RunDiffTest(const DiffOptions& options, int num_invocations) {
   std::stringstream tflite_stream;
@@ -35,7 +53,39 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
   }
   TfLiteDriver tflite_driver(options.delegate);
   tflite_driver.LoadModel(options.tflite_model);
-  return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
+  return ParseAndRunTests(&tflite_stream, &tflite_driver);
+}
+
+bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
+                                   TestRunner* (*runner_factory)()) {
+  int failure_count = 0;
+  for (int i = 0; i < options.num_runs_per_pass; i++) {
+    if (!SingleRunDiffTestWithProvidedRunner(options,
+                                             /*num_invocations=*/1,
+                                             runner_factory)) {
+      ++failure_count;
+    }
+  }
+  int failures_in_first_pass = failure_count;
+
+  if (failure_count == 0) {
+    // Let's try again with num_invocations > 1 to make sure we can do multiple
+    // invocations without resetting the interpreter.
+    for (int i = 0; i < options.num_runs_per_pass; i++) {
+      if (!SingleRunDiffTestWithProvidedRunner(options,
+                                               /*num_invocations=*/2,
+                                               runner_factory)) {
+        ++failure_count;
+      }
+    }
+  }
+
+  fprintf(stderr, "Num errors in single-inference pass: %d\n",
+          failures_in_first_pass);
+  fprintf(stderr, "Num errors in multi-inference pass : %d\n",
+          failure_count - failures_in_first_pass);
+
+  return failure_count == 0;
 }
 }  // namespace testing
 
diff --git a/tensorflow/lite/testing/tflite_diff_util.h b/tensorflow/lite/testing/tflite_diff_util.h
index 362bc64a6bc..3cf4342b810 100644
--- a/tensorflow/lite/testing/tflite_diff_util.h
+++ b/tensorflow/lite/testing/tflite_diff_util.h
@@ -52,6 +52,14 @@ struct DiffOptions {
 // Run a single TensorFLow Lite diff test with a given options.
 bool RunDiffTest(const DiffOptions& options, int num_invocations);
 
+// Runs diff test for custom TestRunner identified by the factory methiodd
+// 'runner_factory' against TFLite CPU given 'options' 'runner_factory' should
+// return instance of TestRunner, caller will take ownership of the returned
+// object.
+// Function returns True if test pass, false otherwise.
+bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
+                                   TestRunner* (*runner_factory)());
+
 }  // namespace testing
 }  // namespace tflite