From 8abcbc77a0bc2af38b1f2b0b270504cf1bc9a225 Mon Sep 17 00:00:00 2001 From: Karim Nosir <karimnosseir@google.com> Date: Mon, 16 Dec 2019 18:17:30 -0800 Subject: [PATCH] Add new methods for diff tests that compares TFLite vs any custom TestRunner PiperOrigin-RevId: 285891574 Change-Id: I67da00e81f6ca60e939b5f77502be1249d59a930 --- tensorflow/lite/testing/tflite_diff_util.cc | 54 ++++++++++++++++++++- tensorflow/lite/testing/tflite_diff_util.h | 8 +++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/testing/tflite_diff_util.cc b/tensorflow/lite/testing/tflite_diff_util.cc index 721830adc4d..2e628fd710d 100644 --- a/tensorflow/lite/testing/tflite_diff_util.cc +++ b/tensorflow/lite/testing/tflite_diff_util.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/testing/tflite_diff_util.h" + #include <cstdarg> #include <cstdio> #include <cstdlib> @@ -19,11 +21,27 @@ limitations under the License. #include "tensorflow/lite/testing/generate_testspec.h" #include "tensorflow/lite/testing/parse_testdata.h" -#include "tensorflow/lite/testing/tflite_diff_util.h" #include "tensorflow/lite/testing/tflite_driver.h" namespace tflite { namespace testing { +namespace { +bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options, + int num_invocations, + TestRunner* (*runner_factory)()) { + std::stringstream tflite_stream; + if (!GenerateTestSpecFromTFLiteModel( + tflite_stream, options.tflite_model, num_invocations, + options.input_layer, options.input_layer_type, + options.input_layer_shape, options.output_layer)) { + return false; + } + + std::unique_ptr<TestRunner> runner(runner_factory()); + runner->LoadModel(options.tflite_model); + return ParseAndRunTests(&tflite_stream, runner.get()); +} +} // namespace bool RunDiffTest(const DiffOptions& options, int num_invocations) { std::stringstream tflite_stream; @@ -35,7 +53,39 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) { } TfLiteDriver tflite_driver(options.delegate); tflite_driver.LoadModel(options.tflite_model); - return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver); + return ParseAndRunTests(&tflite_stream, &tflite_driver); +} + +bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options, + TestRunner* (*runner_factory)()) { + int failure_count = 0; + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!SingleRunDiffTestWithProvidedRunner(options, + /*num_invocations=*/1, + runner_factory)) { + ++failure_count; + } + } + int failures_in_first_pass = failure_count; + + if (failure_count == 0) { + // Let's try again with num_invocations > 1 to make sure we can do multiple + // invocations without resetting the interpreter. + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!SingleRunDiffTestWithProvidedRunner(options, + /*num_invocations=*/2, + runner_factory)) { + ++failure_count; + } + } + } + + fprintf(stderr, "Num errors in single-inference pass: %d\n", + failures_in_first_pass); + fprintf(stderr, "Num errors in multi-inference pass : %d\n", + failure_count - failures_in_first_pass); + + return failure_count == 0; } } // namespace testing diff --git a/tensorflow/lite/testing/tflite_diff_util.h b/tensorflow/lite/testing/tflite_diff_util.h index 362bc64a6bc..3cf4342b810 100644 --- a/tensorflow/lite/testing/tflite_diff_util.h +++ b/tensorflow/lite/testing/tflite_diff_util.h @@ -52,6 +52,14 @@ struct DiffOptions { // Run a single TensorFLow Lite diff test with a given options. bool RunDiffTest(const DiffOptions& options, int num_invocations); +// Runs diff test for custom TestRunner identified by the factory methiodd +// 'runner_factory' against TFLite CPU given 'options' 'runner_factory' should +// return instance of TestRunner, caller will take ownership of the returned +// object. +// Function returns True if test pass, false otherwise. +bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options, + TestRunner* (*runner_factory)()); + } // namespace testing } // namespace tflite