Rollback of breaking f255fc6854
PiperOrigin-RevId: 306568160 Change-Id: I44dcddceb5a886b6337151ada0f7b758370c8d0c
This commit is contained in:
parent
6964061dc0
commit
7741464810
|
@ -237,7 +237,7 @@ cc_library(
|
|||
srcs = ["command_line_flags.cc"],
|
||||
hdrs = ["command_line_flags.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = ["//tensorflow/lite:minimal_logging"],
|
||||
deps = ["//tensorflow/lite/tools:logging"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
|
@ -247,8 +247,7 @@ cc_test(
|
|||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":command_line_flags",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -18,10 +18,11 @@ limitations under the License.
|
|||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
@ -165,7 +166,12 @@ std::string Flag::GetTypeName() const {
|
|||
/*static*/ bool Flags::Parse(int* argc, const char** argv,
|
||||
const std::vector<Flag>& flag_list) {
|
||||
bool result = true;
|
||||
std::vector<bool> unknown_flags(*argc, true);
|
||||
std::vector<bool> unknown_argvs(*argc, true);
|
||||
// Record the list of flags that have been processed. key is the flag's name
|
||||
// and the value is the corresponding argv index if there's one, or -1 when
|
||||
// the argv list doesn't contain this flag.
|
||||
std::unordered_map<std::string, int> processed_flags;
|
||||
|
||||
// Stores indexes of flag_list in a sorted order.
|
||||
std::vector<int> sorted_idx(flag_list.size());
|
||||
std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
|
||||
|
@ -174,53 +180,84 @@ std::string Flag::GetTypeName() const {
|
|||
});
|
||||
int positional_count = 0;
|
||||
|
||||
for (int i = 0; i < sorted_idx.size(); ++i) {
|
||||
const Flag& flag = flag_list[sorted_idx[i]];
|
||||
for (int idx = 0; idx < sorted_idx.size(); ++idx) {
|
||||
const Flag& flag = flag_list[sorted_idx[idx]];
|
||||
|
||||
const auto it = processed_flags.find(flag.name_);
|
||||
if (it != processed_flags.end()) {
|
||||
TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_;
|
||||
if (it->second != -1) {
|
||||
bool value_parsing_ok;
|
||||
flag.Parse(argv[it->second], &value_parsing_ok);
|
||||
if (!value_parsing_ok) {
|
||||
TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
|
||||
<< "' against argv '" << argv[it->second] << "'";
|
||||
result = false;
|
||||
}
|
||||
continue;
|
||||
} else if (flag.flag_type_ == Flag::REQUIRED) {
|
||||
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
|
||||
// If the required flag isn't found, we immediately stop the whole flag
|
||||
// parsing.
|
||||
result = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Parses positional flags.
|
||||
if (flag.flag_type_ == Flag::POSITIONAL) {
|
||||
if (++positional_count >= *argc) {
|
||||
TFLITE_LOG(TFLITE_LOG_ERROR, "Too few command line arguments");
|
||||
TFLITE_LOG(ERROR) << "Too few command line arguments.";
|
||||
return false;
|
||||
}
|
||||
bool value_parsing_ok;
|
||||
flag.Parse(argv[positional_count], &value_parsing_ok);
|
||||
if (!value_parsing_ok) {
|
||||
TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse positional flag: %s",
|
||||
flag.name_.c_str());
|
||||
TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
|
||||
return false;
|
||||
}
|
||||
unknown_flags[positional_count] = false;
|
||||
unknown_argvs[positional_count] = false;
|
||||
processed_flags[flag.name_] = positional_count;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse other flags.
|
||||
bool was_found = false;
|
||||
for (int i = positional_count + 1; i < *argc; ++i) {
|
||||
if (!unknown_flags[i]) continue;
|
||||
if (!unknown_argvs[i]) continue;
|
||||
bool value_parsing_ok;
|
||||
was_found = flag.Parse(argv[i], &value_parsing_ok);
|
||||
if (!value_parsing_ok) {
|
||||
TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse flag: %s",
|
||||
flag.name_.c_str());
|
||||
TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
|
||||
<< "' against argv '" << argv[i] << "'";
|
||||
result = false;
|
||||
}
|
||||
if (was_found) {
|
||||
unknown_flags[i] = false;
|
||||
unknown_argvs[i] = false;
|
||||
processed_flags[flag.name_] = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Check if required flag not found.
|
||||
if (flag.flag_type_ == Flag::REQUIRED && !was_found) {
|
||||
TFLITE_LOG(TFLITE_LOG_ERROR, "Required flag not provided: %s",
|
||||
flag.name_.c_str());
|
||||
|
||||
// If the flag is found from the argv (i.e. the flag name appears in argv),
|
||||
// continue to the next flag parsing.
|
||||
if (was_found) continue;
|
||||
|
||||
// The flag isn't found, do some bookkeeping work.
|
||||
processed_flags[flag.name_] = -1;
|
||||
if (flag.flag_type_ == Flag::REQUIRED) {
|
||||
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
|
||||
result = false;
|
||||
// If the required flag isn't found, we immediately stop the whole flag
|
||||
// parsing by breaking the outer-loop (i.e. the 'sorted_idx'-iteration
|
||||
// loop).
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int dst = 1; // Skip argv[0]
|
||||
for (int i = 1; i < *argc; ++i) {
|
||||
if (unknown_flags[i]) {
|
||||
if (unknown_argvs[i]) {
|
||||
argv[dst++] = argv[i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,6 +125,14 @@ class Flags {
|
|||
// with matching flags, and remove the matching arguments from (*argc, argv).
|
||||
// Return true iff all recognized flag values were parsed correctly, and the
|
||||
// first remaining argument is not "--help".
|
||||
// Note:
|
||||
// 1. when there are duplicate args in argv for the same flag, the flag value
|
||||
// and the parse result will be based on the 1st arg.
|
||||
// 2. when there are duplicate flags in flag_list (i.e. two flags having the
|
||||
// same name), all of them will be checked against the arg list and the parse
|
||||
// result will be false if any of the parsing fails.
|
||||
// See *Duplicate* unit tests in command_line_flags_test.cc for the
|
||||
// illustration of such behaviors.
|
||||
static bool Parse(int* argc, const char** argv,
|
||||
const std::vector<Flag>& flag_list);
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ limitations under the License.
|
|||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
@ -60,12 +59,12 @@ TEST(CommandLineFlagsTest, BasicUsage) {
|
|||
Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL),
|
||||
});
|
||||
|
||||
EXPECT_EQ(true, parsed_ok);
|
||||
EXPECT_TRUE(parsed_ok);
|
||||
EXPECT_EQ(20, some_int32);
|
||||
EXPECT_EQ(8, some_int1);
|
||||
EXPECT_EQ(5, some_int2);
|
||||
EXPECT_EQ(214748364700, some_int64);
|
||||
EXPECT_EQ(true, some_switch);
|
||||
EXPECT_TRUE(some_switch);
|
||||
EXPECT_EQ("somethingelse", some_name);
|
||||
EXPECT_NEAR(42.0f, some_float, 1e-5f);
|
||||
EXPECT_NEAR(12.2f, float_1, 1e-5f);
|
||||
|
@ -82,7 +81,7 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) {
|
|||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_string", &some_string, "some string")});
|
||||
|
||||
EXPECT_EQ(true, parsed_ok);
|
||||
EXPECT_TRUE(parsed_ok);
|
||||
EXPECT_EQ(some_string, "");
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
@ -95,7 +94,7 @@ TEST(CommandLineFlagsTest, BadIntValue) {
|
|||
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int, "some int")});
|
||||
|
||||
EXPECT_EQ(false, parsed_ok);
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_EQ(10, some_int);
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
@ -108,8 +107,8 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
|
|||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_switch", &some_switch, "some switch")});
|
||||
|
||||
EXPECT_EQ(false, parsed_ok);
|
||||
EXPECT_EQ(false, some_switch);
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_FALSE(some_switch);
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
||||
|
@ -121,7 +120,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) {
|
|||
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_float", &some_float, "some float")});
|
||||
|
||||
EXPECT_EQ(false, parsed_ok);
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
@ -134,7 +133,7 @@ TEST(CommandLineFlagsTest, RequiredFlagNotFound) {
|
|||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
|
||||
|
||||
EXPECT_EQ(false, parsed_ok);
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
EXPECT_EQ(argc, 2);
|
||||
}
|
||||
|
@ -147,7 +146,7 @@ TEST(CommandLineFlagsTest, NoArguments) {
|
|||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
|
||||
|
||||
EXPECT_EQ(false, parsed_ok);
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
@ -160,7 +159,7 @@ TEST(CommandLineFlagsTest, NotEnoughArguments) {
|
|||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
|
||||
|
||||
EXPECT_EQ(false, parsed_ok);
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
@ -173,7 +172,7 @@ TEST(CommandLineFlagsTest, PositionalFlagFailed) {
|
|||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
|
||||
|
||||
EXPECT_EQ(false, parsed_ok);
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
EXPECT_EQ(argc, 2);
|
||||
}
|
||||
|
@ -235,11 +234,125 @@ TEST(CommandLineFlagsTest, UsageString) {
|
|||
<< usage;
|
||||
}
|
||||
|
||||
// When there are duplicate args, the flag value and the parsing result will be
|
||||
// based on the 1st arg.
|
||||
TEST(CommandLineFlagsTest, DuplicateArgsParsableValues) {
|
||||
int some_int = -23;
|
||||
int argc = 3;
|
||||
const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
|
||||
bool parsed_ok =
|
||||
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int, "some int")});
|
||||
|
||||
EXPECT_TRUE(parsed_ok);
|
||||
EXPECT_EQ(1, some_int);
|
||||
EXPECT_EQ(argc, 2);
|
||||
EXPECT_EQ("--some_int=2", argv_strings[1]);
|
||||
}
|
||||
|
||||
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearFirst) {
|
||||
int some_int = -23;
|
||||
int argc = 3;
|
||||
const char* argv_strings[] = {"program_name", "--some_int=value",
|
||||
"--some_int=1"};
|
||||
bool parsed_ok =
|
||||
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int, "some int")});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_EQ(-23, some_int);
|
||||
EXPECT_EQ(argc, 2);
|
||||
EXPECT_EQ("--some_int=1", argv_strings[1]);
|
||||
}
|
||||
|
||||
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearSecondly) {
|
||||
int some_int = -23;
|
||||
int argc = 3;
|
||||
// Although the 2nd arg has non-parsable int value, the flag 'some_int' value
|
||||
// could still be set and the parsing result is ok.
|
||||
const char* argv_strings[] = {"program_name", "--some_int=1",
|
||||
"--some_int=value"};
|
||||
bool parsed_ok =
|
||||
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int, "some int")});
|
||||
|
||||
EXPECT_TRUE(parsed_ok);
|
||||
EXPECT_EQ(1, some_int);
|
||||
EXPECT_EQ(argc, 2);
|
||||
EXPECT_EQ("--some_int=value", argv_strings[1]);
|
||||
}
|
||||
|
||||
// When there are duplicate flags, all of them will be checked against the arg
|
||||
// list.
|
||||
TEST(CommandLineFlagsTest, DuplicateFlags) {
|
||||
int some_int1 = -23;
|
||||
int some_int2 = -23;
|
||||
int argc = 2;
|
||||
const char* argv_strings[] = {"program_name", "--some_int=1"};
|
||||
bool parsed_ok =
|
||||
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int1, "some int1"),
|
||||
Flag::CreateFlag("some_int", &some_int2, "some int2")});
|
||||
|
||||
EXPECT_TRUE(parsed_ok);
|
||||
EXPECT_EQ(1, some_int1);
|
||||
EXPECT_EQ(1, some_int2);
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
||||
TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) {
|
||||
int some_int1 = -23;
|
||||
int some_int2 = -23;
|
||||
int argc = 2;
|
||||
const char* argv_strings[] = {"program_name", "--some_float=1.0"};
|
||||
bool parsed_ok = Flags::Parse(
|
||||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::OPTIONAL),
|
||||
Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::REQUIRED)});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_EQ(-23, some_int1);
|
||||
EXPECT_EQ(-23, some_int2);
|
||||
EXPECT_EQ(argc, 2);
|
||||
}
|
||||
|
||||
TEST(CommandLineFlagsTest, DuplicateFlagNamesButDifferentTypes) {
|
||||
int some_int = -23;
|
||||
bool some_bool = true;
|
||||
int argc = 2;
|
||||
const char* argv_strings[] = {"program_name", "--some_val=20"};
|
||||
// In this case, the 2nd 'some_val' flag of bool type will cause a no-ok
|
||||
// parsing result.
|
||||
bool parsed_ok =
|
||||
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_val", &some_int, "some val-int"),
|
||||
Flag::CreateFlag("some_val", &some_bool, "some val-bool")});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_EQ(20, some_int);
|
||||
EXPECT_TRUE(some_bool);
|
||||
EXPECT_EQ(argc, 1);
|
||||
}
|
||||
|
||||
TEST(CommandLineFlagsTest, DuplicateFlagsAndArgs) {
|
||||
int some_int1 = -23;
|
||||
int some_int2 = -23;
|
||||
int argc = 3;
|
||||
const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
|
||||
bool parsed_ok = Flags::Parse(
|
||||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int1, "flag1: bind with some_int1"),
|
||||
Flag::CreateFlag("some_int", &some_int2, "flag2: bind with some_int2")});
|
||||
|
||||
// Note, when there're duplicate args, the flag value and the parsing result
|
||||
// will be based on the 1st arg (i.e. --some_int=1). And both duplicate flags
|
||||
// (i.e. flag1 and flag2) are checked, thus resulting their associated values
|
||||
// (some_int1 and some_int2) being set to 1.
|
||||
EXPECT_TRUE(parsed_ok);
|
||||
EXPECT_EQ(1, some_int1);
|
||||
EXPECT_EQ(1, some_int2);
|
||||
EXPECT_EQ(argc, 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue