From 40655bfd6572997c1bbe73ac2ed7e2814e3c48e9 Mon Sep 17 00:00:00 2001 From: Chao Mei Date: Fri, 5 Mar 2021 00:46:45 -0800 Subject: [PATCH] Support to print out cmdline flags that are supported by the TfLite test delegate providers. PiperOrigin-RevId: 361091579 Change-Id: I9cbe6b694ff6fc79bfa98e9f86385b338bc3a47f --- .../lite/kernels/test_delegate_providers.cc | 11 ++++++++++- tensorflow/lite/kernels/test_main.cc | 17 ++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/kernels/test_delegate_providers.cc b/tensorflow/lite/kernels/test_delegate_providers.cc index d2cb2d1021d..d1ab5e365ee 100644 --- a/tensorflow/lite/kernels/test_delegate_providers.cc +++ b/tensorflow/lite/kernels/test_delegate_providers.cc @@ -37,7 +37,16 @@ bool KernelTestDelegateProviders::InitFromCmdlineArgs(int* argc, auto one_flags = one->CreateFlags(¶ms_); flags.insert(flags.end(), one_flags.begin(), one_flags.end()); } - return tflite::Flags::Parse(argc, argv, flags); + + // Note: when "--help" is passed, the 'Parse' function will return false. + // TODO(b/181868587): The above logic to print out the all supported flags is + // not intuitive, so considering adding the "--help" flag explicitly. + const bool parse_result = tflite::Flags::Parse(argc, argv, flags); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flags); + TFLITE_LOG(ERROR) << usage; + } + return parse_result; } std::vector diff --git a/tensorflow/lite/kernels/test_main.cc b/tensorflow/lite/kernels/test_main.cc index e29e0e93525..e36a33be2f1 100644 --- a/tensorflow/lite/kernels/test_main.cc +++ b/tensorflow/lite/kernels/test_main.cc @@ -22,10 +22,13 @@ limitations under the License. namespace { -void InitKernelTest(int* argc, char** argv) { +bool InitKernelTest(int* argc, char** argv) { tflite::KernelTestDelegateProviders* const delegate_providers = tflite::KernelTestDelegateProviders::Get(); - delegate_providers->InitFromCmdlineArgs(argc, const_cast(argv)); + if (!delegate_providers->InitFromCmdlineArgs( + argc, const_cast(argv))) { + return false; + } if (delegate_providers->ConstParams().Get("use_nnapi")) { // In Android Q, the NNAPI delegate avoids delegation if the only device @@ -37,13 +40,17 @@ void InitKernelTest(int* argc, char** argv) { params->Set("disable_nnapi_cpu", false); } } + return true; } } // namespace int main(int argc, char** argv) { ::tflite::LogToStderr(); - InitKernelTest(&argc, argv); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + if (InitKernelTest(&argc, argv)) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); + } else { + return EXIT_FAILURE; + } }