Add reference_tflite_model option to DiffOptions
This makes comparison between two tflite models possible. PiperOrigin-RevId: 349357549 Change-Id: Ibeb8e206be2a997be8974b4ec51eee21bcc45a24
This commit is contained in:
parent
5688e7ef42
commit
f38c4d866b
tensorflow/lite/testing
@ -36,6 +36,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
|
||||
string output_layer;
|
||||
int32_t num_runs_per_pass = 100;
|
||||
string delegate_name;
|
||||
string reference_tflite_model;
|
||||
} values;
|
||||
|
||||
std::string delegate_name;
|
||||
@ -61,6 +62,10 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
|
||||
tensorflow::Flag("delegate", &values.delegate_name,
|
||||
"[optional] Delegate to use for executing ops. Must be "
|
||||
"`{\"\", NNAPI, GPU, FLEX}`"),
|
||||
tensorflow::Flag("reference_tflite_model", &values.reference_tflite_model,
|
||||
"[optional] Path of the TensorFlow Lite model to "
|
||||
"compare inference results against the model given in "
|
||||
"`tflite_model`."),
|
||||
};
|
||||
|
||||
bool no_inputs = *argc == 1;
|
||||
@ -96,7 +101,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
|
||||
Split<string>(values.input_layer_shape, ":"),
|
||||
Split<string>(values.output_layer, ","),
|
||||
values.num_runs_per_pass,
|
||||
delegate};
|
||||
delegate,
|
||||
values.reference_tflite_model};
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
|
@ -30,8 +30,11 @@ bool SingleRunDiffTestWithProvidedRunner(::tflite::testing::DiffOptions options,
|
||||
int num_invocations,
|
||||
TestRunner* (*runner_factory)()) {
|
||||
std::stringstream tflite_stream;
|
||||
std::string reference_tflite_model = options.reference_tflite_model.empty()
|
||||
? options.tflite_model
|
||||
: options.reference_tflite_model;
|
||||
if (!GenerateTestSpecFromTFLiteModel(
|
||||
tflite_stream, options.tflite_model, num_invocations,
|
||||
tflite_stream, reference_tflite_model, num_invocations,
|
||||
options.input_layer, options.input_layer_type,
|
||||
options.input_layer_shape, options.output_layer)) {
|
||||
return false;
|
||||
|
@ -47,6 +47,8 @@ struct DiffOptions {
|
||||
int num_runs_per_pass;
|
||||
// The type of delegate to apply during inference.
|
||||
TfLiteDriver::DelegateType delegate;
|
||||
// Path of tflite model used to generate golden values.
|
||||
std::string reference_tflite_model = "";
|
||||
};
|
||||
|
||||
// Run a single TensorFLow Lite diff test with a given options.
|
||||
|
Loading…
Reference in New Issue
Block a user