Rollback of breaking f255fc6854

PiperOrigin-RevId: 306568160
Change-Id: I44dcddceb5a886b6337151ada0f7b758370c8d0c
This commit is contained in:
Chao Mei 2020-04-14 20:25:13 -07:00 committed by TensorFlower Gardener
parent 6964061dc0
commit 7741464810
4 changed files with 195 additions and 38 deletions

View File

@ -237,7 +237,7 @@ cc_library(
srcs = ["command_line_flags.cc"],
hdrs = ["command_line_flags.h"],
copts = tflite_copts(),
deps = ["//tensorflow/lite:minimal_logging"],
deps = ["//tensorflow/lite/tools:logging"],
)
cc_test(
@ -247,8 +247,7 @@ cc_test(
visibility = ["//visibility:private"],
deps = [
":command_line_flags",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -18,10 +18,11 @@ limitations under the License.
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/tools/logging.h"
namespace tflite {
namespace {
@ -165,7 +166,12 @@ std::string Flag::GetTypeName() const {
/*static*/ bool Flags::Parse(int* argc, const char** argv,
const std::vector<Flag>& flag_list) {
bool result = true;
std::vector<bool> unknown_flags(*argc, true);
std::vector<bool> unknown_argvs(*argc, true);
// Record the list of flags that have been processed. key is the flag's name
// and the value is the corresponding argv index if there's one, or -1 when
// the argv list doesn't contain this flag.
std::unordered_map<std::string, int> processed_flags;
// Stores indexes of flag_list in a sorted order.
std::vector<int> sorted_idx(flag_list.size());
std::iota(std::begin(sorted_idx), std::end(sorted_idx), 0);
@ -174,53 +180,84 @@ std::string Flag::GetTypeName() const {
});
int positional_count = 0;
for (int i = 0; i < sorted_idx.size(); ++i) {
const Flag& flag = flag_list[sorted_idx[i]];
for (int idx = 0; idx < sorted_idx.size(); ++idx) {
const Flag& flag = flag_list[sorted_idx[idx]];
const auto it = processed_flags.find(flag.name_);
if (it != processed_flags.end()) {
TFLITE_LOG(WARN) << "Duplicate flags: " << flag.name_;
if (it->second != -1) {
bool value_parsing_ok;
flag.Parse(argv[it->second], &value_parsing_ok);
if (!value_parsing_ok) {
TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
<< "' against argv '" << argv[it->second] << "'";
result = false;
}
continue;
} else if (flag.flag_type_ == Flag::REQUIRED) {
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
// If the required flag isn't found, we immediately stop the whole flag
// parsing.
result = false;
break;
}
}
// Parses positional flags.
if (flag.flag_type_ == Flag::POSITIONAL) {
if (++positional_count >= *argc) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Too few command line arguments");
TFLITE_LOG(ERROR) << "Too few command line arguments.";
return false;
}
bool value_parsing_ok;
flag.Parse(argv[positional_count], &value_parsing_ok);
if (!value_parsing_ok) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse positional flag: %s",
flag.name_.c_str());
TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
return false;
}
unknown_flags[positional_count] = false;
unknown_argvs[positional_count] = false;
processed_flags[flag.name_] = positional_count;
continue;
}
// Parse other flags.
bool was_found = false;
for (int i = positional_count + 1; i < *argc; ++i) {
if (!unknown_flags[i]) continue;
if (!unknown_argvs[i]) continue;
bool value_parsing_ok;
was_found = flag.Parse(argv[i], &value_parsing_ok);
if (!value_parsing_ok) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Failed to parse flag: %s",
flag.name_.c_str());
TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
<< "' against argv '" << argv[i] << "'";
result = false;
}
if (was_found) {
unknown_flags[i] = false;
unknown_argvs[i] = false;
processed_flags[flag.name_] = i;
break;
}
}
// Check if required flag not found.
if (flag.flag_type_ == Flag::REQUIRED && !was_found) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Required flag not provided: %s",
flag.name_.c_str());
// If the flag is found from the argv (i.e. the flag name appears in argv),
// continue to the next flag parsing.
if (was_found) continue;
// The flag isn't found, do some bookkeeping work.
processed_flags[flag.name_] = -1;
if (flag.flag_type_ == Flag::REQUIRED) {
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
result = false;
// If the required flag isn't found, we immediately stop the whole flag
// parsing by breaking the outer-loop (i.e. the 'sorted_idx'-iteration
// loop).
break;
}
}
int dst = 1; // Skip argv[0]
for (int i = 1; i < *argc; ++i) {
if (unknown_flags[i]) {
if (unknown_argvs[i]) {
argv[dst++] = argv[i];
}
}

View File

@ -125,6 +125,14 @@ class Flags {
// with matching flags, and remove the matching arguments from (*argc, argv).
// Return true iff all recognized flag values were parsed correctly, and the
// first remaining argument is not "--help".
// Note:
// 1. when there are duplicate args in argv for the same flag, the flag value
// and the parse result will be based on the 1st arg.
// 2. when there are duplicate flags in flag_list (i.e. two flags having the
// same name), all of them will be checked against the arg list and the parse
// result will be false if any of the parsing fails.
// See *Duplicate* unit tests in command_line_flags_test.cc for the
// illustration of such behaviors.
static bool Parse(int* argc, const char** argv,
const std::vector<Flag>& flag_list);

View File

@ -17,7 +17,6 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace {
@ -60,12 +59,12 @@ TEST(CommandLineFlagsTest, BasicUsage) {
Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL),
});
EXPECT_EQ(true, parsed_ok);
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(20, some_int32);
EXPECT_EQ(8, some_int1);
EXPECT_EQ(5, some_int2);
EXPECT_EQ(214748364700, some_int64);
EXPECT_EQ(true, some_switch);
EXPECT_TRUE(some_switch);
EXPECT_EQ("somethingelse", some_name);
EXPECT_NEAR(42.0f, some_float, 1e-5f);
EXPECT_NEAR(12.2f, float_1, 1e-5f);
@ -82,7 +81,7 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) {
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_string", &some_string, "some string")});
EXPECT_EQ(true, parsed_ok);
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(some_string, "");
EXPECT_EQ(argc, 1);
}
@ -95,7 +94,7 @@ TEST(CommandLineFlagsTest, BadIntValue) {
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_EQ(false, parsed_ok);
EXPECT_FALSE(parsed_ok);
EXPECT_EQ(10, some_int);
EXPECT_EQ(argc, 1);
}
@ -108,8 +107,8 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_switch", &some_switch, "some switch")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(false, some_switch);
EXPECT_FALSE(parsed_ok);
EXPECT_FALSE(some_switch);
EXPECT_EQ(argc, 1);
}
@ -121,7 +120,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) {
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_float", &some_float, "some float")});
EXPECT_EQ(false, parsed_ok);
EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 1);
}
@ -134,7 +133,7 @@ TEST(CommandLineFlagsTest, RequiredFlagNotFound) {
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
EXPECT_EQ(false, parsed_ok);
EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 2);
}
@ -147,7 +146,7 @@ TEST(CommandLineFlagsTest, NoArguments) {
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
EXPECT_EQ(false, parsed_ok);
EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 1);
}
@ -160,7 +159,7 @@ TEST(CommandLineFlagsTest, NotEnoughArguments) {
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
EXPECT_EQ(false, parsed_ok);
EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 1);
}
@ -173,7 +172,7 @@ TEST(CommandLineFlagsTest, PositionalFlagFailed) {
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
EXPECT_EQ(false, parsed_ok);
EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 2);
}
@ -235,11 +234,125 @@ TEST(CommandLineFlagsTest, UsageString) {
<< usage;
}
// When there are duplicate args, the flag value and the parsing result will be
// based on the 1st arg.
TEST(CommandLineFlagsTest, DuplicateArgsParsableValues) {
int some_int = -23;
int argc = 3;
const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int);
EXPECT_EQ(argc, 2);
EXPECT_EQ("--some_int=2", argv_strings[1]);
}
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearFirst) {
int some_int = -23;
int argc = 3;
const char* argv_strings[] = {"program_name", "--some_int=value",
"--some_int=1"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_FALSE(parsed_ok);
EXPECT_EQ(-23, some_int);
EXPECT_EQ(argc, 2);
EXPECT_EQ("--some_int=1", argv_strings[1]);
}
TEST(CommandLineFlagsTest, DuplicateArgsBadValueAppearSecondly) {
int some_int = -23;
int argc = 3;
// Although the 2nd arg has non-parsable int value, the flag 'some_int' value
// could still be set and the parsing result is ok.
const char* argv_strings[] = {"program_name", "--some_int=1",
"--some_int=value"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int);
EXPECT_EQ(argc, 2);
EXPECT_EQ("--some_int=value", argv_strings[1]);
}
// When there are duplicate flags, all of them will be checked against the arg
// list.
TEST(CommandLineFlagsTest, DuplicateFlags) {
int some_int1 = -23;
int some_int2 = -23;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_int=1"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int1, "some int1"),
Flag::CreateFlag("some_int", &some_int2, "some int2")});
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int1);
EXPECT_EQ(1, some_int2);
EXPECT_EQ(argc, 1);
}
TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) {
int some_int1 = -23;
int some_int2 = -23;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_float=1.0"};
bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::OPTIONAL),
Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::REQUIRED)});
EXPECT_FALSE(parsed_ok);
EXPECT_EQ(-23, some_int1);
EXPECT_EQ(-23, some_int2);
EXPECT_EQ(argc, 2);
}
TEST(CommandLineFlagsTest, DuplicateFlagNamesButDifferentTypes) {
int some_int = -23;
bool some_bool = true;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_val=20"};
// In this case, the 2nd 'some_val' flag of bool type will cause a no-ok
// parsing result.
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_val", &some_int, "some val-int"),
Flag::CreateFlag("some_val", &some_bool, "some val-bool")});
EXPECT_FALSE(parsed_ok);
EXPECT_EQ(20, some_int);
EXPECT_TRUE(some_bool);
EXPECT_EQ(argc, 1);
}
TEST(CommandLineFlagsTest, DuplicateFlagsAndArgs) {
int some_int1 = -23;
int some_int2 = -23;
int argc = 3;
const char* argv_strings[] = {"program_name", "--some_int=1", "--some_int=2"};
bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int1, "flag1: bind with some_int1"),
Flag::CreateFlag("some_int", &some_int2, "flag2: bind with some_int2")});
// Note, when there're duplicate args, the flag value and the parsing result
// will be based on the 1st arg (i.e. --some_int=1). And both duplicate flags
// (i.e. flag1 and flag2) are checked, thus resulting their associated values
// (some_int1 and some_int2) being set to 1.
EXPECT_TRUE(parsed_ok);
EXPECT_EQ(1, some_int1);
EXPECT_EQ(1, some_int2);
EXPECT_EQ(argc, 2);
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}