diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index a7a323be119..106b79bbafc 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -237,7 +237,7 @@ cc_library( srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], copts = tflite_copts(), - deps = ["//tensorflow/lite:minimal_logging"], + deps = ["//tensorflow/lite/tools:logging"], ) cc_test( @@ -247,8 +247,7 @@ cc_test( visibility = ["//visibility:private"], deps = [ ":command_line_flags", - "//tensorflow/lite/testing:util", - "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/tools/command_line_flags.cc b/tensorflow/lite/tools/command_line_flags.cc index 841424421e0..0db2d53df5a 100644 --- a/tensorflow/lite/tools/command_line_flags.cc +++ b/tensorflow/lite/tools/command_line_flags.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include #include +#include #include #include -#include "tensorflow/lite/minimal_logging.h" +#include "tensorflow/lite/tools/logging.h" namespace tflite { namespace { @@ -165,7 +166,12 @@ std::string Flag::GetTypeName() const { /*static*/ bool Flags::Parse(int* argc, const char** argv, const std::vector& flag_list) { bool result = true; - std::vector unknown_flags(*argc, true); + std::vector unknown_argvs(*argc, true); + // Record the list of flags that have been processed. key is the flag's name + // and the value is the corresponding argv index if there's one, or -1 when + // the argv list doesn't contain this flag. + std::unordered_map processed_flags; + // Stores indexes of flag_list in a sorted order. std::vector sorted_idx(flag_list.size()); std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0); @@ -174,53 +180,84 @@ std::string Flag::GetTypeName() const { }); int positional_count = 0; - for (int i = 0; i < sorted_idx.size(); ++i) { - const Flag& flag = flag_list[sorted_idx[i]]; + for (int idx = 0; idx < sorted_idx.size(); ++idx) { + const Flag& flag = flag_list[sorted_idx[idx]]; + + const auto it = processed_flags.find(flag.name_); + if (it != processed_flags.end()) { + TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_; + if (it->second != -1) { + bool value_parsing_ok; + flag.Parse(argv[it->second], &value_parsing_ok); + if (!value_parsing_ok) { + TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_ + << "' against argv '" << argv[it->second] << "'"; + result = false; + } + continue; + } else if (flag.flag_type_ == Flag::REQUIRED) { + TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_; + // If the required flag isn't found, we immediately stop the whole flag + // parsing. + result = false; + break; + } + } + // Parses positional flags. if (flag.flag_type_ == Flag::POSITIONAL) { if (++positional_count >= *argc) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Too few command line arguments"); + TFLITE_LOG(ERROR) << "Too few command line arguments."; return false; } bool value_parsing_ok; flag.Parse(argv[positional_count], &value_parsing_ok); if (!value_parsing_ok) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse positional flag: %s", - flag.name_.c_str()); + TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_; return false; } - unknown_flags[positional_count] = false; + unknown_argvs[positional_count] = false; + processed_flags[flag.name_] = positional_count; continue; } // Parse other flags. bool was_found = false; for (int i = positional_count + 1; i < *argc; ++i) { - if (!unknown_flags[i]) continue; + if (!unknown_argvs[i]) continue; bool value_parsing_ok; was_found = flag.Parse(argv[i], &value_parsing_ok); if (!value_parsing_ok) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse flag: %s", - flag.name_.c_str()); + TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_ + << "' against argv '" << argv[i] << "'"; result = false; } if (was_found) { - unknown_flags[i] = false; + unknown_argvs[i] = false; + processed_flags[flag.name_] = i; break; } } - // Check if required flag not found. - if (flag.flag_type_ == Flag::REQUIRED && !was_found) { - TFLITE_LOG(TFLITE_LOG_ERROR, "Required flag not provided: %s", - flag.name_.c_str()); + + // If the flag is found from the argv (i.e. the flag name appears in argv), + // continue to the next flag parsing. + if (was_found) continue; + + // The flag isn't found, do some bookkeeping work. + processed_flags[flag.name_] = -1; + if (flag.flag_type_ == Flag::REQUIRED) { + TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_; result = false; + // If the required flag isn't found, we immediately stop the whole flag + // parsing by breaking the outer-loop (i.e. the 'sorted_idx'-iteration + // loop). break; } } int dst = 1; // Skip argv[0] for (int i = 1; i < *argc; ++i) { - if (unknown_flags[i]) { + if (unknown_argvs[i]) { argv[dst++] = argv[i]; } } diff --git a/tensorflow/lite/tools/command_line_flags.h b/tensorflow/lite/tools/command_line_flags.h index 2808a12a489..941a1b8b59a 100644 --- a/tensorflow/lite/tools/command_line_flags.h +++ b/tensorflow/lite/tools/command_line_flags.h @@ -125,6 +125,14 @@ class Flags { // with matching flags, and remove the matching arguments from (*argc, argv). // Return true iff all recognized flag values were parsed correctly, and the // first remaining argument is not "--help". + // Note: + // 1. when there are duplicate args in argv for the same flag, the flag value + // and the parse result will be based on the 1st arg. + // 2. when there are duplicate flags in flag_list (i.e. two flags having the + // same name), all of them will be checked against the arg list and the parse + // result will be false if any of the parsing fails. + // See *Duplicate* unit tests in command_line_flags_test.cc for the + // illustration of such behaviors. static bool Parse(int* argc, const char** argv, const std::vector& flag_list); diff --git a/tensorflow/lite/tools/command_line_flags_test.cc b/tensorflow/lite/tools/command_line_flags_test.cc index 1354c6d503b..eb02379143f 100644 --- a/tensorflow/lite/tools/command_line_flags_test.cc +++ b/tensorflow/lite/tools/command_line_flags_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { @@ -60,12 +59,12 @@ TEST(CommandLineFlagsTest, BasicUsage) { Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL), }); - EXPECT_EQ(true, parsed_ok); + EXPECT_TRUE(parsed_ok); EXPECT_EQ(20, some_int32); EXPECT_EQ(8, some_int1); EXPECT_EQ(5, some_int2); EXPECT_EQ(214748364700, some_int64); - EXPECT_EQ(true, some_switch); + EXPECT_TRUE(some_switch); EXPECT_EQ("somethingelse", some_name); EXPECT_NEAR(42.0f, some_float, 1e-5f); EXPECT_NEAR(12.2f, float_1, 1e-5f); @@ -82,7 +81,7 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_string", &some_string, "some string")}); - EXPECT_EQ(true, parsed_ok); + EXPECT_TRUE(parsed_ok); EXPECT_EQ(some_string, ""); EXPECT_EQ(argc, 1); } @@ -95,7 +94,7 @@ TEST(CommandLineFlagsTest, BadIntValue) { Flags::Parse(&argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_int", &some_int, "some int")}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_EQ(10, some_int); EXPECT_EQ(argc, 1); } @@ -108,8 +107,8 @@ TEST(CommandLineFlagsTest, BadBoolValue) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_switch", &some_switch, "some switch")}); - EXPECT_EQ(false, parsed_ok); - EXPECT_EQ(false, some_switch); + EXPECT_FALSE(parsed_ok); + EXPECT_FALSE(some_switch); EXPECT_EQ(argc, 1); } @@ -121,7 +120,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) { Flags::Parse(&argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_float", &some_float, "some float")}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -134,7 +133,7 @@ TEST(CommandLineFlagsTest, RequiredFlagNotFound) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 2); } @@ -147,7 +146,7 @@ TEST(CommandLineFlagsTest, NoArguments) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -160,7 +159,7 @@ TEST(CommandLineFlagsTest, NotEnoughArguments) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -173,7 +172,7 @@ TEST(CommandLineFlagsTest, PositionalFlagFailed) { &argc, reinterpret_cast(argv_strings), {Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); - EXPECT_EQ(false, parsed_ok); + EXPECT_FALSE(parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_EQ(argc, 2); } @@ -235,11 +234,125 @@ TEST(CommandLineFlagsTest, UsageString) { << usage; } +// When there are duplicate args, the flag value and the parsing result will be +// based on the 1st arg. +TEST(CommandLineFlagsTest, DuplicateArgsParsableValues) { + int some_int = -23; + int argc = 3; + const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int, "some int")}); + + EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int); + EXPECT_EQ(argc, 2); + EXPECT_EQ("--some_int=2", argv_strings[1]); +} + +TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearFirst) { + int some_int = -23; + int argc = 3; + const char* argv_strings[] = {"program_name", "--some_int=value", + "--some_int=1"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int, "some int")}); + + EXPECT_FALSE(parsed_ok); + EXPECT_EQ(-23, some_int); + EXPECT_EQ(argc, 2); + EXPECT_EQ("--some_int=1", argv_strings[1]); +} + +TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearSecondly) { + int some_int = -23; + int argc = 3; + // Although the 2nd arg has non-parsable int value, the flag 'some_int' value + // could still be set and the parsing result is ok. + const char* argv_strings[] = {"program_name", "--some_int=1", + "--some_int=value"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int, "some int")}); + + EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int); + EXPECT_EQ(argc, 2); + EXPECT_EQ("--some_int=value", argv_strings[1]); +} + +// When there are duplicate flags, all of them will be checked against the arg +// list. +TEST(CommandLineFlagsTest, DuplicateFlags) { + int some_int1 = -23; + int some_int2 = -23; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_int=1"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int1, "some int1"), + Flag::CreateFlag("some_int", &some_int2, "some int2")}); + + EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int1); + EXPECT_EQ(1, some_int2); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) { + int some_int1 = -23; + int some_int2 = -23; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_float=1.0"}; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::OPTIONAL), + Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::REQUIRED)}); + + EXPECT_FALSE(parsed_ok); + EXPECT_EQ(-23, some_int1); + EXPECT_EQ(-23, some_int2); + EXPECT_EQ(argc, 2); +} + +TEST(CommandLineFlagsTest, DuplicateFlagNamesButDifferentTypes) { + int some_int = -23; + bool some_bool = true; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_val=20"}; + // In this case, the 2nd 'some_val' flag of bool type will cause a no-ok + // parsing result. + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_val", &some_int, "some val-int"), + Flag::CreateFlag("some_val", &some_bool, "some val-bool")}); + + EXPECT_FALSE(parsed_ok); + EXPECT_EQ(20, some_int); + EXPECT_TRUE(some_bool); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, DuplicateFlagsAndArgs) { + int some_int1 = -23; + int some_int2 = -23; + int argc = 3; + const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"}; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_int", &some_int1, "flag1: bind with some_int1"), + Flag::CreateFlag("some_int", &some_int2, "flag2: bind with some_int2")}); + + // Note, when there're duplicate args, the flag value and the parsing result + // will be based on the 1st arg (i.e. --some_int=1). And both duplicate flags + // (i.e. flag1 and flag2) are checked, thus resulting their associated values + // (some_int1 and some_int2) being set to 1. + EXPECT_TRUE(parsed_ok); + EXPECT_EQ(1, some_int1); + EXPECT_EQ(1, some_int2); + EXPECT_EQ(argc, 2); +} + } // namespace } // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -}