Add reference_tflite_model option to DiffOptions

This makes comparison between two tflite models possible.

PiperOrigin-RevId: 349357549
Change-Id: Ibeb8e206be2a997be8974b4ec51eee21bcc45a24
This commit is contained in:
Taehee Jeong 2020-12-28 18:25:05 -08:00 committed by TensorFlower Gardener
parent 5688e7ef42
commit f38c4d866b
3 changed files with 13 additions and 2 deletions

View File

@ -36,6 +36,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string output_layer;
int32_t num_runs_per_pass = 100;
string delegate_name;
string reference_tflite_model;
} values;
std::string delegate_name;
@ -61,6 +62,10 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
tensorflow::Flag("delegate", &values.delegate_name,
"[optional] Delegate to use for executing ops. Must be "
"`{\"\", NNAPI, GPU, FLEX}`"),
tensorflow::Flag("reference_tflite_model", &values.reference_tflite_model,
"[optional] Path of the TensorFlow Lite model to "
"compare inference results against the model given in "
"`tflite_model`."),
};
bool no_inputs = *argc == 1;
@ -96,7 +101,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer_shape, ":"),
Split<string>(values.output_layer, ","),
values.num_runs_per_pass,
delegate};
delegate,
values.reference_tflite_model};
}
} // namespace testing

View File

@ -30,8 +30,11 @@ bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options,
int num_invocations,
TestRunner* (*runner_factory)()) {
std::stringstream tflite_stream;
std::string reference_tflite_model = options.reference_tflite_model.empty()
? options.tflite_model
: options.reference_tflite_model;
if (!GenerateTestSpecFromTFLiteModel(
tflite_stream, options.tflite_model, num_invocations,
tflite_stream, reference_tflite_model, num_invocations,
options.input_layer, options.input_layer_type,
options.input_layer_shape, options.output_layer)) {
return false;

View File

@ -47,6 +47,8 @@ struct DiffOptions {
int num_runs_per_pass;
// The type of delegate to apply during inference.
TfLiteDriver::DelegateType delegate;
// Path of tflite model used to generate golden values.
std::string reference_tflite_model = "";
};
// Run a single TensorFLow Lite diff test with a given options.