1. Allow parsing duplicate flags.

2. Document the behavior when there're duplicate flags or duplicate args during parsing, and add corresponding unit tests for such behaviors.

PiperOrigin-RevId: 306344156
Change-Id: I1601d4b2ab9efd9fb624a5b8e3c13a1dd74972de
This commit is contained in:
A. Unique TensorFlower 2020-04-13 17:28:01 -07:00 committed by TensorFlower Gardener
parent c3c62c31c5
commit f255fc6854
4 changed files with 186 additions and 36 deletions

View File

@ -235,7 +235,7 @@ cc_library(
srcs = ["command_line_flags.cc"], srcs = ["command_line_flags.cc"],
hdrs = ["command_line_flags.h"], hdrs = ["command_line_flags.h"],
copts = tflite_copts(), copts = tflite_copts(),
deps = ["//tensorflow/lite:minimal_logging"], deps = ["//tensorflow/lite/tools:logging"],
) )
cc_test( cc_test(
@ -245,8 +245,7 @@ cc_test(
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = [ deps = [
":command_line_flags", ":command_line_flags",
"//tensorflow/lite/testing:util", "@com_google_googletest//:gtest_main",
"@com_google_googletest//:gtest",
], ],
) )

View File

@ -18,10 +18,11 @@ limitations under the License.
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/tools/logging.h"
namespace tflite { namespace tflite {
namespace { namespace {
@ -165,7 +166,12 @@ std::string Flag::GetTypeName() const {
/*static*/ bool Flags::Parse(int* argc, const char** argv, /*static*/ bool Flags::Parse(int* argc, const char** argv,
const std::vector<Flag>& flag_list) { const std::vector<Flag>& flag_list) {
bool result = true; bool result = true;
std::vector<bool> unknown_flags(*argc, true); std::vector<bool> unknown_argvs(*argc, true);
// Record the list of flags that have been processed. key is the flag's name
// and the value is the corresponding argv index if there's one, or -1 when
// the argv list doesn't contain this flag.
std::unordered_map<std::string, int> processed_flags;
// Stores indexes of flag_list in a sorted order. // Stores indexes of flag_list in a sorted order.
std::vector<int> sorted_idx(flag_list.size()); std::vector<int> sorted_idx(flag_list.size());
std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0); std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
@ -174,45 +180,69 @@ std::string Flag::GetTypeName() const {
}); });
int positional_count = 0; int positional_count = 0;
for (int i = 0; i < sorted_idx.size(); ++i) { for (int idx = 0; idx < sorted_idx.size(); ++idx) {
const Flag& flag = flag_list[sorted_idx[i]]; const Flag& flag = flag_list[sorted_idx[idx]];
const auto it = processed_flags.find(flag.name_);
if (it != processed_flags.end()) {
TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_;
if (it->second != -1) {
bool value_parsing_ok;
flag.Parse(argv[it->second], &value_parsing_ok);
if (!value_parsing_ok) {
TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
<< "' against argv '" << argv[it->second] << "'";
result = false;
}
continue;
} else if (flag.flag_type_ == Flag::REQUIRED) {
// Check if required flag not found.
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
result = false;
break;
}
}
// Parses positional flags. // Parses positional flags.
if (flag.flag_type_ == Flag::POSITIONAL) { if (flag.flag_type_ == Flag::POSITIONAL) {
if (++positional_count >= *argc) { if (++positional_count >= *argc) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Too few command line arguments"); TFLITE_LOG(ERROR) << "Too few command line arguments.";
return false; return false;
} }
bool value_parsing_ok; bool value_parsing_ok;
flag.Parse(argv[positional_count], &value_parsing_ok); flag.Parse(argv[positional_count], &value_parsing_ok);
if (!value_parsing_ok) { if (!value_parsing_ok) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse positional flag: %s", TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
flag.name_.c_str());
return false; return false;
} }
unknown_flags[positional_count] = false; unknown_argvs[positional_count] = false;
processed_flags[flag.name_] = positional_count;
continue; continue;
} }
// Parse other flags. // Parse other flags.
bool was_found = false; bool was_found = false;
for (int i = positional_count + 1; i < *argc; ++i) { for (int i = positional_count + 1; i < *argc; ++i) {
if (!unknown_flags[i]) continue; if (!unknown_argvs[i]) continue;
bool value_parsing_ok; bool value_parsing_ok;
was_found = flag.Parse(argv[i], &value_parsing_ok); was_found = flag.Parse(argv[i], &value_parsing_ok);
if (!value_parsing_ok) { if (!value_parsing_ok) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse flag: %s", TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
flag.name_.c_str()); << "' against argv '" << argv[i] << "'";
result = false; result = false;
} }
if (was_found) { if (was_found) {
unknown_flags[i] = false; unknown_argvs[i] = false;
processed_flags[flag.name_] = i;
break; break;
} }
} }
if (!was_found) {
processed_flags[flag.name_] = -1;
}
// Check if required flag not found. // Check if required flag not found.
if (flag.flag_type_ == Flag::REQUIRED && !was_found) { if (flag.flag_type_ == Flag::REQUIRED && !was_found) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Required flag not provided: %s", TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
flag.name_.c_str());
result = false; result = false;
break; break;
} }
@ -220,7 +250,7 @@ std::string Flag::GetTypeName() const {
int dst = 1; // Skip argv[0] int dst = 1; // Skip argv[0]
for (int i = 1; i < *argc; ++i) { for (int i = 1; i < *argc; ++i) {
if (unknown_flags[i]) { if (unknown_argvs[i]) {
argv[dst++] = argv[i]; argv[dst++] = argv[i];
} }
} }

View File

@ -125,6 +125,14 @@ class Flags {
// with matching flags, and remove the matching arguments from (*argc, argv). // with matching flags, and remove the matching arguments from (*argc, argv).
// Return true iff all recognized flag values were parsed correctly, and the // Return true iff all recognized flag values were parsed correctly, and the
// first remaining argument is not "--help". // first remaining argument is not "--help".
// Note:
// 1. when there are duplicate args in argv for the same flag, the flag value
// and the parse result will be based on the 1st arg.
// 2. when there are duplicate flags in flag_list (i.e. two flags having the
// same name), all of them will be checked against the arg list and the parse
// result will be false if any of the parsing fails.
// See *Duplicate* unit tests in command_line_flags_test.cc for the
// illustration of such behaviors.
static bool Parse(int* argc, const char** argv, static bool Parse(int* argc, const char** argv,
const std::vector<Flag>& flag_list); const std::vector<Flag>& flag_list);

View File

@ -17,7 +17,6 @@ limitations under the License.
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"
namespace tflite { namespace tflite {
namespace { namespace {
@ -60,12 +59,12 @@ TEST(CommandLineFlagsTest, BasicUsage) {
Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL), Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL),
}); });
EXPECT_EQ(true, parsed_ok); EXPECT_TRUE(parsed_ok);
EXPECT_EQ(20, some_int32); EXPECT_EQ(20, some_int32);
EXPECT_EQ(8, some_int1); EXPECT_EQ(8, some_int1);
EXPECT_EQ(5, some_int2); EXPECT_EQ(5, some_int2);
EXPECT_EQ(214748364700, some_int64); EXPECT_EQ(214748364700, some_int64);
EXPECT_EQ(true, some_switch); EXPECT_TRUE(some_switch);
EXPECT_EQ("somethingelse", some_name); EXPECT_EQ("somethingelse", some_name);
EXPECT_NEAR(42.0f, some_float, 1e-5f); EXPECT_NEAR(42.0f, some_float, 1e-5f);
EXPECT_NEAR(12.2f, float_1, 1e-5f); EXPECT_NEAR(12.2f, float_1, 1e-5f);
@ -82,7 +81,7 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) {
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_string", &some_string, "some string")}); {Flag::CreateFlag("some_string", &some_string, "some string")});
EXPECT_EQ(true, parsed_ok); EXPECT_TRUE(parsed_ok);
EXPECT_EQ(some_string, ""); EXPECT_EQ(some_string, "");
EXPECT_EQ(argc, 1); EXPECT_EQ(argc, 1);
} }
@ -95,7 +94,7 @@ TEST(CommandLineFlagsTest, BadIntValue) {
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings), Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")}); {Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_EQ(false, parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_EQ(10, some_int); EXPECT_EQ(10, some_int);
EXPECT_EQ(argc, 1); EXPECT_EQ(argc, 1);
} }
@ -108,8 +107,8 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_switch", &some_switch, "some switch")}); {Flag::CreateFlag("some_switch", &some_switch, "some switch")});
EXPECT_EQ(false, parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_EQ(false, some_switch); EXPECT_FALSE(some_switch);
EXPECT_EQ(argc, 1); EXPECT_EQ(argc, 1);
} }
@ -121,7 +120,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) {
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings), Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_float", &some_float, "some float")}); {Flag::CreateFlag("some_float", &some_float, "some float")});
EXPECT_EQ(false, parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 1); EXPECT_EQ(argc, 1);
} }
@ -134,7 +133,7 @@ TEST(CommandLineFlagsTest, RequiredFlagNotFound) {
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
EXPECT_EQ(false, parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 2); EXPECT_EQ(argc, 2);
} }
@ -147,7 +146,7 @@ TEST(CommandLineFlagsTest, NoArguments) {
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
EXPECT_EQ(false, parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 1); EXPECT_EQ(argc, 1);
} }
@ -160,7 +159,7 @@ TEST(CommandLineFlagsTest, NotEnoughArguments) {
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
EXPECT_EQ(false, parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 1); EXPECT_EQ(argc, 1);
} }
@ -173,7 +172,7 @@ TEST(CommandLineFlagsTest, PositionalFlagFailed) {
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
EXPECT_EQ(false, parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 2); EXPECT_EQ(argc, 2);
} }
@ -235,11 +234,125 @@ TEST(CommandLineFlagsTest, UsageString) {
<< usage; << usage;
} }
// When there are duplicate args, the flag value and the parsing result will be
// based on the 1st arg.
TEST(CommandLineFlagsTest, DuplicateArgsParsableValues) {
int some_int = -23;
int argc = 3;
const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int);
EXPECT_EQ(argc, 2);
EXPECT_EQ("--some_int=2", argv_strings[1]);
}
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearFirst) {
int some_int = -23;
int argc = 3;
const char* argv_strings[] = {"program_name", "--some_int=value",
"--some_int=1"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_FALSE(parsed_ok);
EXPECT_EQ(-23, some_int);
EXPECT_EQ(argc, 2);
EXPECT_EQ("--some_int=1", argv_strings[1]);
}
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearSecondly) {
int some_int = -23;
int argc = 3;
// Although the 2nd arg has non-parsable int value, the flag 'some_int' value
// could still be set and the parsing result is ok.
const char* argv_strings[] = {"program_name", "--some_int=1",
"--some_int=value"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int);
EXPECT_EQ(argc, 2);
EXPECT_EQ("--some_int=value", argv_strings[1]);
}
// When there are duplicate flags, all of them will be checked against the arg
// list.
TEST(CommandLineFlagsTest, DuplicateFlags) {
int some_int1 = -23;
int some_int2 = -23;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_int=1"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int1, "some int1"),
Flag::CreateFlag("some_int", &some_int2, "some int2")});
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int1);
EXPECT_EQ(1, some_int2);
EXPECT_EQ(argc, 1);
}
TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) {
int some_int1 = -23;
int some_int2 = -23;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_float=1.0"};
bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::OPTIONAL),
Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::REQUIRED)});
EXPECT_FALSE(parsed_ok);
EXPECT_EQ(-23, some_int1);
EXPECT_EQ(-23, some_int2);
EXPECT_EQ(argc, 2);
}
TEST(CommandLineFlagsTest, DuplicateFlagNamesButDifferentTypes) {
int some_int = -23;
bool some_bool = true;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_val=20"};
// In this case, the 2nd 'some_val' flag of bool type will cause a no-ok
// parsing result.
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_val", &some_int, "some val-int"),
Flag::CreateFlag("some_val", &some_bool, "some val-bool")});
EXPECT_FALSE(parsed_ok);
EXPECT_EQ(20, some_int);
EXPECT_TRUE(some_bool);
EXPECT_EQ(argc, 1);
}
TEST(CommandLineFlagsTest, DuplicateFlagsAndArgs) {
int some_int1 = -23;
int some_int2 = -23;
int argc = 3;
const char* argv_strings[] = {"program_name", "--some_int=1 --some_int=2"};
bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int1, "flag1: bind with some_int1"),
Flag::CreateFlag("some_int", &some_int2, "flag2: bind with some_int2")});
// Note, when there're duplicate args, the flag value and the parsing result
// will be based on the 1st arg (i.e. --some_int=1). And both duplicate flags
// (i.e. flag1 and flag2) are checked, thus resulting their associated values
// (some_int1 and some_int2) being set to 1.
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int1);
EXPECT_EQ(1, some_int2);
EXPECT_EQ(argc, 2);
}
} // namespace } // namespace
} // namespace tflite } // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}