Add new methods for diff tests that compares TFLite vs any custom TestRunner

PiperOrigin-RevId: 285891574
Change-Id: I67da00e81f6ca60e939b5f77502be1249d59a930
This commit is contained in:
Karim Nosir 2019-12-16 18:17:30 -08:00 committed by TensorFlower Gardener
parent 10c882dfbd
commit 8abcbc77a0
2 changed files with 60 additions and 2 deletions

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/tflite_diff_util.h"
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
@ -19,11 +21,27 @@ limitations under the License.
#include "tensorflow/lite/testing/generate_testspec.h"
#include "tensorflow/lite/testing/parse_testdata.h"
#include "tensorflow/lite/testing/tflite_diff_util.h"
#include "tensorflow/lite/testing/tflite_driver.h"
namespace tflite {
namespace testing {
namespace {
bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options,
int num_invocations,
TestRunner* (*runner_factory)()) {
std::stringstream tflite_stream;
if (!GenerateTestSpecFromTFLiteModel(
tflite_stream, options.tflite_model, num_invocations,
options.input_layer, options.input_layer_type,
options.input_layer_shape, options.output_layer)) {
return false;
}
std::unique_ptr<TestRunner> runner(runner_factory());
runner->LoadModel(options.tflite_model);
return ParseAndRunTests(&tflite_stream, runner.get());
}
} // namespace
bool RunDiffTest(const DiffOptions& options, int num_invocations) {
std::stringstream tflite_stream;
@ -35,7 +53,39 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
}
TfLiteDriver tflite_driver(options.delegate);
tflite_driver.LoadModel(options.tflite_model);
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
return ParseAndRunTests(&tflite_stream, &tflite_driver);
}
bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
TestRunner* (*runner_factory)()) {
int failure_count = 0;
for (int i = 0; i < options.num_runs_per_pass; i++) {
if (!SingleRunDiffTestWithProvidedRunner(options,
/*num_invocations=*/1,
runner_factory)) {
++failure_count;
}
}
int failures_in_first_pass = failure_count;
if (failure_count == 0) {
// Let's try again with num_invocations > 1 to make sure we can do multiple
// invocations without resetting the interpreter.
for (int i = 0; i < options.num_runs_per_pass; i++) {
if (!SingleRunDiffTestWithProvidedRunner(options,
/*num_invocations=*/2,
runner_factory)) {
++failure_count;
}
}
}
fprintf(stderr, "Num errors in single-inference pass: %d\n",
failures_in_first_pass);
fprintf(stderr, "Num errors in multi-inference pass : %d\n",
failure_count - failures_in_first_pass);
return failure_count == 0;
}
} // namespace testing

View File

@ -52,6 +52,14 @@ struct DiffOptions {
// Run a single TensorFLow Lite diff test with a given options.
bool RunDiffTest(const DiffOptions& options, int num_invocations);
// Runs diff test for custom TestRunner identified by the factory methiodd
// 'runner_factory' against TFLite CPU given 'options' 'runner_factory' should
// return instance of TestRunner, caller will take ownership of the returned
// object.
// Function returns True if test pass, false otherwise.
bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
TestRunner* (*runner_factory)());
} // namespace testing
} // namespace tflite