Add new methods for diff tests that compares TFLite vs any custom TestRunner
PiperOrigin-RevId: 285891574 Change-Id: I67da00e81f6ca60e939b5f77502be1249d59a930
This commit is contained in:
parent
10c882dfbd
commit
8abcbc77a0
tensorflow/lite/testing
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/testing/tflite_diff_util.h"
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
@ -19,11 +21,27 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/testing/generate_testspec.h"
|
||||
#include "tensorflow/lite/testing/parse_testdata.h"
|
||||
#include "tensorflow/lite/testing/tflite_diff_util.h"
|
||||
#include "tensorflow/lite/testing/tflite_driver.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options,
|
||||
int num_invocations,
|
||||
TestRunner* (*runner_factory)()) {
|
||||
std::stringstream tflite_stream;
|
||||
if (!GenerateTestSpecFromTFLiteModel(
|
||||
tflite_stream, options.tflite_model, num_invocations,
|
||||
options.input_layer, options.input_layer_type,
|
||||
options.input_layer_shape, options.output_layer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unique_ptr<TestRunner> runner(runner_factory());
|
||||
runner->LoadModel(options.tflite_model);
|
||||
return ParseAndRunTests(&tflite_stream, runner.get());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool RunDiffTest(const DiffOptions& options, int num_invocations) {
|
||||
std::stringstream tflite_stream;
|
||||
@ -35,7 +53,39 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
|
||||
}
|
||||
TfLiteDriver tflite_driver(options.delegate);
|
||||
tflite_driver.LoadModel(options.tflite_model);
|
||||
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
|
||||
return ParseAndRunTests(&tflite_stream, &tflite_driver);
|
||||
}
|
||||
|
||||
bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
|
||||
TestRunner* (*runner_factory)()) {
|
||||
int failure_count = 0;
|
||||
for (int i = 0; i < options.num_runs_per_pass; i++) {
|
||||
if (!SingleRunDiffTestWithProvidedRunner(options,
|
||||
/*num_invocations=*/1,
|
||||
runner_factory)) {
|
||||
++failure_count;
|
||||
}
|
||||
}
|
||||
int failures_in_first_pass = failure_count;
|
||||
|
||||
if (failure_count == 0) {
|
||||
// Let's try again with num_invocations > 1 to make sure we can do multiple
|
||||
// invocations without resetting the interpreter.
|
||||
for (int i = 0; i < options.num_runs_per_pass; i++) {
|
||||
if (!SingleRunDiffTestWithProvidedRunner(options,
|
||||
/*num_invocations=*/2,
|
||||
runner_factory)) {
|
||||
++failure_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "Num errors in single-inference pass: %d\n",
|
||||
failures_in_first_pass);
|
||||
fprintf(stderr, "Num errors in multi-inference pass : %d\n",
|
||||
failure_count - failures_in_first_pass);
|
||||
|
||||
return failure_count == 0;
|
||||
}
|
||||
} // namespace testing
|
||||
|
||||
|
@ -52,6 +52,14 @@ struct DiffOptions {
|
||||
// Run a single TensorFLow Lite diff test with a given options.
|
||||
bool RunDiffTest(const DiffOptions& options, int num_invocations);
|
||||
|
||||
// Runs diff test for custom TestRunner identified by the factory methiodd
|
||||
// 'runner_factory' against TFLite CPU given 'options' 'runner_factory' should
|
||||
// return instance of TestRunner, caller will take ownership of the returned
|
||||
// object.
|
||||
// Function returns True if test pass, false otherwise.
|
||||
bool RunDiffTestWithProvidedRunner(const tflite::testing::DiffOptions& options,
|
||||
TestRunner* (*runner_factory)());
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user