Support stringmaplist parsing.

PiperOrigin-RevId: 231337720
This commit is contained in:
A. Unique TensorFlower 2019-01-28 20:30:28 -08:00 committed by TensorFlower Gardener
parent 90901d4454
commit d2cb5c3308
5 changed files with 257 additions and 72 deletions

View File

@ -265,7 +265,7 @@ def generated_test_models():
"logical_and",
"logical_or",
"logical_xor",
#"lstm", TODO(b/122889684): Resolve toco structured line parsing in oss.
"lstm",
"max_pool",
"maximum",
"mean",

View File

@ -133,6 +133,7 @@ cc_library(
cc_library(
name = "model_cmdline_flags",
srcs = [
"args.cc",
"model_cmdline_flags.cc",
],
hdrs = [
@ -478,3 +479,16 @@ tf_cc_test(
"@com_google_googletest//:gtest",
],
)
# C++ test target built from model_cmdline_flags_test.cc. It links against the
# :model_cmdline_flags library and the model flags proto, plus the TFLite
# testing utilities and googletest.
tf_cc_test(
name = "model_cmdline_flags_test",
srcs = [
"model_cmdline_flags_test.cc",
],
deps = [
":model_cmdline_flags",
":model_flags_proto_cc",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)

View File

@ -0,0 +1,169 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/args.h"
#include "absl/strings/str_split.h"
namespace toco {
namespace {
// Lookup table used by SplitStructuredLine: pairs every opening symbol with
// the symbol that closes it, and records which characters are closers.
class ClosingSymbolLookup {
 public:
  // `symbol_pairs` is a NUL-terminated sequence of (opening, closing)
  // character pairs, e.g. "{}[]". If the sequence has an odd length, the
  // final, unpaired opening character is treated as closing itself.
  explicit ClosingSymbolLookup(const char* symbol_pairs)
      : closing_(), valid_closing_() {
    const char* cursor = symbol_pairs;
    while (*cursor != 0) {
      const unsigned char open_sym = *cursor;
      const char* partner = cursor + 1;
      const bool has_partner = (*partner != 0);
      const unsigned char close_sym = has_partner ? *partner : open_sym;
      closing_[open_sym] = close_sym;
      valid_closing_[close_sym] = true;
      if (!has_partner) break;
      // Step over the pair that was just recorded.
      cursor = partner + 1;
    }
  }

  // The tables are fixed once built; copying is disabled.
  ClosingSymbolLookup(const ClosingSymbolLookup&) = delete;
  ClosingSymbolLookup& operator=(const ClosingSymbolLookup&) = delete;

  // Yields the closer registered for `opening`, or 0 when `opening` is not
  // an opening symbol.
  char GetClosingChar(char opening) const {
    const unsigned char slot = static_cast<unsigned char>(opening);
    return closing_[slot];
  }

  // Tells whether `c` was registered as a closing symbol.
  bool IsClosing(char c) const {
    const unsigned char slot = static_cast<unsigned char>(c);
    return valid_closing_[slot];
  }

 private:
  // Indexed by opening character; holds its closer, or 0 when the index is
  // not an opening character.
  char closing_[256];
  // Indexed by character; true for every registered closer.
  bool valid_closing_[256];
};
// Splits `line` on every `delimiter` that is not nested inside an expression
// opened by one of the symbols in `symbol_pairs` (a string of
// opening/closing character pairs such as "{}"). The resulting pieces are
// appended to `*cols` as views into `line`; `cols` must not be null.
//
// Returns false as soon as a closing symbol shows up that does not match the
// innermost open expression, or if the line ends while expressions are still
// open. In both cases `*cols` keeps the pieces produced so far, the last of
// which runs to the end of `line`.
bool SplitStructuredLine(absl::string_view line, char delimiter,
                         const char* symbol_pairs,
                         std::vector<absl::string_view>* cols) {
  ClosingSymbolLookup lookup(symbol_pairs);
  // Closers still owed, innermost expression last.
  std::vector<char> pending_closers;
  ABSL_RAW_CHECK(cols != nullptr, "");
  // Start with one column covering the whole line; it gets trimmed whenever
  // a top-level delimiter is found.
  cols->push_back(line);
  for (size_t pos = 0; pos < line.size(); ++pos) {
    const char current = line[pos];
    if (pending_closers.empty()) {
      if (current == delimiter) {
        // Top-level delimiter: end the current column right before it and
        // open a new one right after it.
        cols->back().remove_suffix(line.size() - pos);
        cols->push_back(line.substr(pos + 1));
        continue;
      }
    } else if (current == pending_closers.back()) {
      // The innermost open expression is closed by this character.
      pending_closers.pop_back();
      continue;
    }
    const char closer = lookup.GetClosingChar(current);
    if (closer) {
      // Opening symbol: remember which character has to close it.
      pending_closers.push_back(closer);
      continue;
    }
    if (lookup.IsClosing(current)) {
      // A closer that does not match the innermost open expression.
      return false;
    }
  }
  // Success only if every opened expression was closed.
  return pending_closers.empty();
}
// Copies `str` into `*result`, dropping a leading `prefix` when present.
// Returns whether the prefix was found; when it was not, `*result` receives
// `str` unchanged.
inline bool TryStripPrefixString(absl::string_view str,
                                 absl::string_view prefix, string* result) {
  const bool had_prefix = absl::ConsumePrefix(&str, prefix);
  // Materialize the remainder before storing it: callers in this file pass
  // the same string as both `str` and `result`.
  *result = string(str.begin(), str.end());
  return had_prefix;
}
// Copies `str` into `*result`, dropping a trailing `suffix` when present.
// Returns whether the suffix was found; when it was not, `*result` receives
// `str` unchanged.
inline bool TryStripSuffixString(absl::string_view str,
                                 absl::string_view suffix, string* result) {
  const bool had_suffix = absl::ConsumeSuffix(&str, suffix);
  // Materialize the remainder before storing it: callers in this file pass
  // the same string as both `str` and `result`.
  *result = string(str.begin(), str.end());
  return had_suffix;
}
} // namespace
// Parses `text` as a comma-separated list of integers and stores them in
// `parsed_value_.elements`, replacing any previous contents. The flag is
// marked as specified even when parsing fails.
//
// Returns true on success (an empty `text` yields an empty list). Returns
// false at the first part that absl::SimpleAtoi cannot convert to an int32;
// the elements converted before that part remain stored.
bool Arg<toco::IntList>::Parse(string text) {
  parsed_value_.elements.clear();
  specified_ = true;
  // absl::StrSplit("") produces {""}, but we need {} on empty input.
  // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
  // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_value_.elements)
  if (text.empty()) {
    return true;
  }
  for (absl::string_view part : absl::StrSplit(text, ',')) {
    int32 element;
    // Explicitly qualified: relying on argument-dependent lookup through
    // absl::string_view fails when that type is an alias of
    // std::string_view.
    if (!absl::SimpleAtoi(part, &element)) return false;
    parsed_value_.elements.push_back(element);
  }
  return true;
}
// Parses `text` as a comma-separated list of brace-delimited maps, e.g.
//   {key1:value1,key2:value2},{key3:value3}
// and stores one unordered_map per brace group in `parsed_value_.elements`,
// replacing any previous contents. The flag is marked as specified even when
// parsing fails.
//
// Whitespace around each group, key and value is stripped, and empty groups
// between top-level commas are skipped. Keys and values cannot contain ','
// or ':' because those characters separate fields and key/value pairs.
//
// Returns true on success (an empty `text` yields an empty list). Returns
// false when the braces in `text` are unbalanced or mismatched, a group is
// not wrapped in "{" and "}", a field is not exactly "key:value", or a key
// is repeated within one group. On failure the groups parsed before the
// error remain stored.
bool Arg<toco::StringMapList>::Parse(string text) {
  parsed_value_.elements.clear();
  specified_ = true;
  if (text.empty()) {
    return true;
  }
  std::vector<absl::string_view> outer_vector;
  // TODO(aselle): Change argument parsing when absl supports structuredline.
  // A failed split means the braces do not pair up (e.g. "{a:b}}"); reject
  // the input instead of parsing the partially split pieces, which would
  // let stray braces leak into values.
  if (!SplitStructuredLine(text, ',', "{}", &outer_vector)) {
    return false;
  }
  for (const absl::string_view& outer_member_stringpiece : outer_vector) {
    string outer_member(outer_member_stringpiece);
    if (outer_member.empty()) {
      continue;
    }
    absl::StripAsciiWhitespace(&outer_member);
    // Every group must be wrapped in a single pair of braces.
    if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
    if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
    const std::vector<string> inner_fields_vector =
        absl::StrSplit(outer_member, ',');
    std::unordered_map<string, string> element;
    for (const string& member_field : inner_fields_vector) {
      std::vector<string> outer_member_key_value =
          absl::StrSplit(member_field, ':');
      // Exactly one ':' per field, i.e. one key and one value.
      if (outer_member_key_value.size() != 2) return false;
      string& key = outer_member_key_value[0];
      string& value = outer_member_key_value[1];
      absl::StripAsciiWhitespace(&key);
      absl::StripAsciiWhitespace(&value);
      // Duplicate keys within one group are an error.
      if (element.count(key) != 0) return false;
      element[key] = value;
    }
    parsed_value_.elements.push_back(element);
  }
  return true;
}
} // namespace toco

View File

@ -22,10 +22,6 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/lite/toco/toco_port.h"
#if defined(PLATFORM_GOOGLE)
#include "strings/split.h"
#include "strings/strip.h"
#endif
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/toco/toco_types.h"
@ -64,7 +60,7 @@ class Arg final {
const T& value() const { return value_; }
// Parsing callback for the tensorflow::Flags code
bool parse(T value_in) {
bool Parse(T value_in) {
value_ = value_in;
specified_ = true;
return true;
@ -72,7 +68,7 @@ class Arg final {
// Bind the parse member function so tensorflow::Flags can call it.
std::function<bool(T)> bind() {
return std::bind(&Arg::parse, this, std::placeholders::_1);
return std::bind(&Arg::Parse, this, std::placeholders::_1);
}
private:
@ -90,24 +86,10 @@ class Arg<toco::IntList> final {
// Return true if the command line argument was specified on the command line.
bool specified() const { return specified_; }
// Bind the parse member function so tensorflow::Flags can call it.
bool parse(string text) {
parsed_value_.elements.clear();
specified_ = true;
// strings::Split("") produces {""}, but we need {} on empty input.
// TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
// use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements)
if (!text.empty()) {
int32 element;
for (absl::string_view part : absl::StrSplit(text, ',')) {
if (!SimpleAtoi(part, &element)) return false;
parsed_value_.elements.push_back(element);
}
}
return true;
}
bool Parse(string text);
std::function<bool(string)> bind() {
return std::bind(&Arg::parse, this, std::placeholders::_1);
return std::bind(&Arg::Parse, this, std::placeholders::_1);
}
const toco::IntList& value() const { return parsed_value_; }
@ -126,57 +108,10 @@ class Arg<toco::StringMapList> final {
bool specified() const { return specified_; }
// Bind the parse member function so tensorflow::Flags can call it.
bool parse(string text) {
parsed_value_.elements.clear();
specified_ = true;
if (text.empty()) {
return true;
}
#if defined(PLATFORM_GOOGLE)
std::vector<absl::string_view> outer_vector;
absl::string_view text_disposable_copy = text;
SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector);
for (const absl::string_view& outer_member_stringpiece : outer_vector) {
string outer_member(outer_member_stringpiece);
if (outer_member.empty()) {
continue;
}
string outer_member_copy = outer_member;
absl::StripAsciiWhitespace(&outer_member);
if (!strings::TryStripPrefixString(outer_member, "{", &outer_member))
return false;
if (!strings::TryStripSuffixString(outer_member, "}", &outer_member))
return false;
const std::vector<string> inner_fields_vector =
absl::StrSplit(outer_member, ',');
std::unordered_map<string, string> element;
for (const string& member_field : inner_fields_vector) {
std::vector<string> outer_member_key_value =
absl::StrSplit(member_field, ':');
if (outer_member_key_value.size() != 2) return false;
string& key = outer_member_key_value[0];
string& value = outer_member_key_value[1];
absl::StripAsciiWhitespace(&key);
absl::StripAsciiWhitespace(&value);
if (element.count(key) != 0) return false;
element[key] = value;
}
parsed_value_.elements.push_back(element);
}
return true;
#else
// TODO(aselle): Fix argument parsing when absl supports structuredline
fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__,
__LINE__);
abort();
#endif
}
bool Parse(string text);
std::function<bool(string)> bind() {
return std::bind(&Arg::parse, this, std::placeholders::_1);
return std::bind(&Arg::Parse, this, std::placeholders::_1);
}
const toco::StringMapList& value() const { return parsed_value_; }

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <unordered_map>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/toco/args.h"
#include "tensorflow/lite/toco/model_cmdline_flags.h"
namespace toco {
namespace {
// Feeds a command line carrying --input_arrays and a two-group --rnn_states
// value through ParseModelFlagsFromCommandLineFlags and checks that both
// flags come back with the expected parsed contents.
TEST(ModelCmdlineFlagsTest, ParseArgsStringMapList) {
  const char* args[] = {
      "toco",
      "--input_arrays=input_1",
      "--rnn_states={state_array:rnn/BasicLSTMCellZeroState/zeros,"
      "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4},"
      "{state_array:rnn/BasicLSTMCellZeroState/zeros_1,"
      "back_edge_source_array:rnn/basic_lstm_cell/Mul_2,size:4}",
  };
  // Three entries: program name plus the two flags above.
  int args_count = static_cast<int>(sizeof(args) / sizeof(args[0]));

  const string expected_input_arrays = "input_1";

  using StringMap = std::unordered_map<string, string>;
  const std::vector<StringMap> expected_rnn_states = {
      StringMap{{"state_array", "rnn/BasicLSTMCellZeroState/zeros"},
                {"back_edge_source_array", "rnn/basic_lstm_cell/Add_1"},
                {"size", "4"}},
      StringMap{{"state_array", "rnn/BasicLSTMCellZeroState/zeros_1"},
                {"back_edge_source_array", "rnn/basic_lstm_cell/Mul_2"},
                {"size", "4"}},
  };

  string message;
  ParsedModelFlags result_flags;
  const bool parsed_ok = ParseModelFlagsFromCommandLineFlags(
      &args_count, const_cast<char**>(args), &message, &result_flags);
  EXPECT_TRUE(parsed_ok);

  EXPECT_EQ(result_flags.input_arrays.value(), expected_input_arrays);
  EXPECT_EQ(result_flags.rnn_states.value().elements, expected_rnn_states);
}
} // namespace
} // namespace toco
// Test binary entry point: sets up logging and googletest, then runs every
// registered test and returns the googletest result as the exit status.
int main(int argc, char** argv) {
// Route TFLite log output to stderr before any test runs.
::tflite::LogToStderr();
// Let googletest consume its own command-line flags from argc/argv.
::testing::InitGoogleTest(&argc, argv);
// NOTE(review): presumably tells toco's port layer that process-level
// initialization has already happened — confirm against toco_port.h.
::toco::port::InitGoogleWasDoneElsewhere();
return RUN_ALL_TESTS();
}