Add new methods for diff tests that compares TFLite vs any custom TestRunner

PiperOrigin-RevId: 285891574
Change-Id: I67da00e81f6ca60e939b5f77502be1249d59a930
This commit is contained in:
Karim Nosir 2019-12-16 18:17:30 -08:00 committed by TensorFlower Gardener
parent 10c882dfbd
commit 8abcbc77a0
2 changed files with 60 additions and 2 deletions

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/testing/tflite_diff_util.h"
#include <cstdarg> #include <cstdarg>
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
@ -19,11 +21,27 @@ limitations under the License.
#include "tensorflow/lite/testing/generate_testspec.h" #include "tensorflow/lite/testing/generate_testspec.h"
#include "tensorflow/lite/testing/parse_testdata.h" #include "tensorflow/lite/testing/parse_testdata.h"
#include "tensorflow/lite/testing/tflite_diff_util.h"
#include "tensorflow/lite/testing/tflite_driver.h" #include "tensorflow/lite/testing/tflite_driver.h"
namespace tflite { namespace tflite {
namespace testing { namespace testing {
namespace {
bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options,
int num_invocations,
TestRunner* (*runner_factory)()) {
std::stringstream tflite_stream;
if (!GenerateTestSpecFromTFLiteModel(
tflite_stream, options.tflite_model, num_invocations,
options.input_layer, options.input_layer_type,
options.input_layer_shape, options.output_layer)) {
return false;
}
std::unique_ptr<TestRunner> runner(runner_factory());
runner->LoadModel(options.tflite_model);
return ParseAndRunTests(&tflite_stream, runner.get());
}
} // namespace
bool RunDiffTest(const DiffOptions& options, int num_invocations) { bool RunDiffTest(const DiffOptions& options, int num_invocations) {
std::stringstream tflite_stream; std::stringstream tflite_stream;
@ -35,7 +53,39 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
} }
TfLiteDriver tflite_driver(options.delegate); TfLiteDriver tflite_driver(options.delegate);
tflite_driver.LoadModel(options.tflite_model); tflite_driver.LoadModel(options.tflite_model);
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver); return ParseAndRunTests(&tflite_stream, &tflite_driver);
}
bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
TestRunner* (*runner_factory)()) {
int failure_count = 0;
for (int i = 0; i < options.num_runs_per_pass; i++) {
if (!SingleRunDiffTestWithProvidedRunner(options,
/*num_invocations=*/1,
runner_factory)) {
++failure_count;
}
}
int failures_in_first_pass = failure_count;
if (failure_count == 0) {
// Let's try again with num_invocations > 1 to make sure we can do multiple
// invocations without resetting the interpreter.
for (int i = 0; i < options.num_runs_per_pass; i++) {
if (!SingleRunDiffTestWithProvidedRunner(options,
/*num_invocations=*/2,
runner_factory)) {
++failure_count;
}
}
}
fprintf(stderr, "Num errors in single-inference pass: %d\n",
failures_in_first_pass);
fprintf(stderr, "Num errors in multi-inference pass : %d\n",
failure_count - failures_in_first_pass);
return failure_count == 0;
} }
} // namespace testing } // namespace testing

View File

@ -52,6 +52,14 @@ struct DiffOptions {
// Run a single TensorFLow Lite diff test with a given options. // Run a single TensorFLow Lite diff test with a given options.
bool RunDiffTest(const DiffOptions& options, int num_invocations); bool RunDiffTest(const DiffOptions& options, int num_invocations);
// Runs diff test for custom TestRunner identified by the factory methiodd
// 'runner_factory' against TFLite CPU given 'options' 'runner_factory' should
// return instance of TestRunner, caller will take ownership of the returned
// object.
// Function returns True if test pass, false otherwise.
bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
TestRunner* (*runner_factory)());
} // namespace testing } // namespace testing
} // namespace tflite } // namespace tflite