Support stringmaplist parsing.
PiperOrigin-RevId: 231337720
This commit is contained in:
		
							parent
							
								
									90901d4454
								
							
						
					
					
						commit
						d2cb5c3308
					
				| @ -265,7 +265,7 @@ def generated_test_models(): | ||||
|         "logical_and", | ||||
|         "logical_or", | ||||
|         "logical_xor", | ||||
|         #"lstm", TODO(b/122889684): Resolve toco structured line parsing in oss. | ||||
|         "lstm", | ||||
|         "max_pool", | ||||
|         "maximum", | ||||
|         "mean", | ||||
|  | ||||
| @ -133,6 +133,7 @@ cc_library( | ||||
| cc_library( | ||||
|     name = "model_cmdline_flags", | ||||
|     srcs = [ | ||||
|         "args.cc", | ||||
|         "model_cmdline_flags.cc", | ||||
|     ], | ||||
|     hdrs = [ | ||||
| @ -478,3 +479,16 @@ tf_cc_test( | ||||
|         "@com_google_googletest//:gtest", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| tf_cc_test( | ||||
|     name = "model_cmdline_flags_test", | ||||
|     srcs = [ | ||||
|         "model_cmdline_flags_test.cc", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":model_cmdline_flags", | ||||
|         ":model_flags_proto_cc", | ||||
|         "//tensorflow/lite/testing:util", | ||||
|         "@com_google_googletest//:gtest", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
							
								
								
									
										169
									
								
								tensorflow/lite/toco/args.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								tensorflow/lite/toco/args.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,169 @@ | ||||
| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/lite/toco/args.h" | ||||
| #include "absl/strings/str_split.h" | ||||
| 
 | ||||
| namespace toco { | ||||
| namespace { | ||||
| 
 | ||||
| // Helper class for SplitStructuredLine parsing.
 | ||||
| class ClosingSymbolLookup { | ||||
|  public: | ||||
|   explicit ClosingSymbolLookup(const char* symbol_pairs) | ||||
|       : closing_(), valid_closing_() { | ||||
|     // Initialize the opening/closing arrays.
 | ||||
|     for (const char* symbol = symbol_pairs; *symbol != 0; ++symbol) { | ||||
|       unsigned char opening = *symbol; | ||||
|       ++symbol; | ||||
|       // If the string ends before the closing character has been found,
 | ||||
|       // use the opening character as the closing character.
 | ||||
|       unsigned char closing = *symbol != 0 ? *symbol : opening; | ||||
|       closing_[opening] = closing; | ||||
|       valid_closing_[closing] = true; | ||||
|       if (*symbol == 0) break; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   ClosingSymbolLookup(const ClosingSymbolLookup&) = delete; | ||||
|   ClosingSymbolLookup& operator=(const ClosingSymbolLookup&) = delete; | ||||
| 
 | ||||
|   // Returns the closing character corresponding to an opening one,
 | ||||
|   // or 0 if the argument is not an opening character.
 | ||||
|   char GetClosingChar(char opening) const { | ||||
|     return closing_[static_cast<unsigned char>(opening)]; | ||||
|   } | ||||
| 
 | ||||
|   // Returns true if the argument is a closing character.
 | ||||
|   bool IsClosing(char c) const { | ||||
|     return valid_closing_[static_cast<unsigned char>(c)]; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   // Maps an opening character to its closing. If the entry contains 0,
 | ||||
|   // the character is not in the opening set.
 | ||||
|   char closing_[256]; | ||||
|   // Valid closing characters.
 | ||||
|   bool valid_closing_[256]; | ||||
| }; | ||||
| 
 | ||||
| bool SplitStructuredLine(absl::string_view line, char delimiter, | ||||
|                          const char* symbol_pairs, | ||||
|                          std::vector<absl::string_view>* cols) { | ||||
|   ClosingSymbolLookup lookup(symbol_pairs); | ||||
| 
 | ||||
|   // Stack of symbols expected to close the current opened expressions.
 | ||||
|   std::vector<char> expected_to_close; | ||||
| 
 | ||||
|   ABSL_RAW_CHECK(cols != nullptr, ""); | ||||
|   cols->push_back(line); | ||||
|   for (size_t i = 0; i < line.size(); ++i) { | ||||
|     char c = line[i]; | ||||
|     if (expected_to_close.empty() && c == delimiter) { | ||||
|       // We don't have any open expression, this is a valid separator.
 | ||||
|       cols->back().remove_suffix(line.size() - i); | ||||
|       cols->push_back(line.substr(i + 1)); | ||||
|     } else if (!expected_to_close.empty() && c == expected_to_close.back()) { | ||||
|       // Can we close the currently open expression?
 | ||||
|       expected_to_close.pop_back(); | ||||
|     } else if (lookup.GetClosingChar(c)) { | ||||
|       // If this is an opening symbol, we open a new expression and push
 | ||||
|       // the expected closing symbol on the stack.
 | ||||
|       expected_to_close.push_back(lookup.GetClosingChar(c)); | ||||
|     } else if (lookup.IsClosing(c)) { | ||||
|       // Error: mismatched closing symbol.
 | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|   if (!expected_to_close.empty()) { | ||||
|     return false;  // Missing closing symbol(s)
 | ||||
|   } | ||||
|   return true;  // Success
 | ||||
| } | ||||
| 
 | ||||
| inline bool TryStripPrefixString(absl::string_view str, | ||||
|                                  absl::string_view prefix, string* result) { | ||||
|   bool res = absl::ConsumePrefix(&str, prefix); | ||||
|   result->assign(str.begin(), str.end()); | ||||
|   return res; | ||||
| } | ||||
| 
 | ||||
| inline bool TryStripSuffixString(absl::string_view str, | ||||
|                                  absl::string_view suffix, string* result) { | ||||
|   bool res = absl::ConsumeSuffix(&str, suffix); | ||||
|   result->assign(str.begin(), str.end()); | ||||
|   return res; | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| bool Arg<toco::IntList>::Parse(string text) { | ||||
|   parsed_value_.elements.clear(); | ||||
|   specified_ = true; | ||||
|   // strings::Split("") produces {""}, but we need {} on empty input.
 | ||||
|   // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
 | ||||
|   // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements)
 | ||||
|   if (!text.empty()) { | ||||
|     int32 element; | ||||
|     for (absl::string_view part : absl::StrSplit(text, ',')) { | ||||
|       if (!SimpleAtoi(part, &element)) return false; | ||||
|       parsed_value_.elements.push_back(element); | ||||
|     } | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| bool Arg<toco::StringMapList>::Parse(string text) { | ||||
|   parsed_value_.elements.clear(); | ||||
|   specified_ = true; | ||||
| 
 | ||||
|   if (text.empty()) { | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   std::vector<absl::string_view> outer_vector; | ||||
|   absl::string_view text_disposable_copy = text; | ||||
|   // TODO(aselle): Change argument parsing when absl supports structuredline.
 | ||||
|   SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector); | ||||
|   for (const absl::string_view& outer_member_stringpiece : outer_vector) { | ||||
|     string outer_member(outer_member_stringpiece); | ||||
|     if (outer_member.empty()) { | ||||
|       continue; | ||||
|     } | ||||
|     string outer_member_copy = outer_member; | ||||
|     absl::StripAsciiWhitespace(&outer_member); | ||||
|     if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; | ||||
|     if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; | ||||
|     const std::vector<string> inner_fields_vector = | ||||
|         absl::StrSplit(outer_member, ','); | ||||
| 
 | ||||
|     std::unordered_map<string, string> element; | ||||
|     for (const string& member_field : inner_fields_vector) { | ||||
|       std::vector<string> outer_member_key_value = | ||||
|           absl::StrSplit(member_field, ':'); | ||||
|       if (outer_member_key_value.size() != 2) return false; | ||||
|       string& key = outer_member_key_value[0]; | ||||
|       string& value = outer_member_key_value[1]; | ||||
|       absl::StripAsciiWhitespace(&key); | ||||
|       absl::StripAsciiWhitespace(&value); | ||||
|       if (element.count(key) != 0) return false; | ||||
|       element[key] = value; | ||||
|     } | ||||
|     parsed_value_.elements.push_back(element); | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| }  // namespace toco
 | ||||
| @ -22,10 +22,6 @@ limitations under the License. | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
| #include "tensorflow/lite/toco/toco_port.h" | ||||
| #if defined(PLATFORM_GOOGLE) | ||||
| #include "strings/split.h" | ||||
| #include "strings/strip.h" | ||||
| #endif | ||||
| #include "absl/strings/numbers.h" | ||||
| #include "absl/strings/str_split.h" | ||||
| #include "tensorflow/lite/toco/toco_types.h" | ||||
| @ -64,7 +60,7 @@ class Arg final { | ||||
|   const T& value() const { return value_; } | ||||
| 
 | ||||
|   // Parsing callback for the tensorflow::Flags code
 | ||||
|   bool parse(T value_in) { | ||||
|   bool Parse(T value_in) { | ||||
|     value_ = value_in; | ||||
|     specified_ = true; | ||||
|     return true; | ||||
| @ -72,7 +68,7 @@ class Arg final { | ||||
| 
 | ||||
|   // Bind the parse member function so tensorflow::Flags can call it.
 | ||||
|   std::function<bool(T)> bind() { | ||||
|     return std::bind(&Arg::parse, this, std::placeholders::_1); | ||||
|     return std::bind(&Arg::Parse, this, std::placeholders::_1); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
| @ -90,24 +86,10 @@ class Arg<toco::IntList> final { | ||||
|   // Return true if the command line argument was specified on the command line.
 | ||||
|   bool specified() const { return specified_; } | ||||
|   // Bind the parse member function so tensorflow::Flags can call it.
 | ||||
|   bool parse(string text) { | ||||
|     parsed_value_.elements.clear(); | ||||
|     specified_ = true; | ||||
|     // strings::Split("") produces {""}, but we need {} on empty input.
 | ||||
|     // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
 | ||||
|     // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements)
 | ||||
|     if (!text.empty()) { | ||||
|       int32 element; | ||||
|       for (absl::string_view part : absl::StrSplit(text, ',')) { | ||||
|         if (!SimpleAtoi(part, &element)) return false; | ||||
|         parsed_value_.elements.push_back(element); | ||||
|       } | ||||
|     } | ||||
|     return true; | ||||
|   } | ||||
|   bool Parse(string text); | ||||
| 
 | ||||
|   std::function<bool(string)> bind() { | ||||
|     return std::bind(&Arg::parse, this, std::placeholders::_1); | ||||
|     return std::bind(&Arg::Parse, this, std::placeholders::_1); | ||||
|   } | ||||
| 
 | ||||
|   const toco::IntList& value() const { return parsed_value_; } | ||||
| @ -126,57 +108,10 @@ class Arg<toco::StringMapList> final { | ||||
|   bool specified() const { return specified_; } | ||||
|   // Bind the parse member function so tensorflow::Flags can call it.
 | ||||
| 
 | ||||
|   bool parse(string text) { | ||||
|     parsed_value_.elements.clear(); | ||||
|     specified_ = true; | ||||
| 
 | ||||
|     if (text.empty()) { | ||||
|       return true; | ||||
|     } | ||||
| 
 | ||||
| #if defined(PLATFORM_GOOGLE) | ||||
|     std::vector<absl::string_view> outer_vector; | ||||
|     absl::string_view text_disposable_copy = text; | ||||
|     SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector); | ||||
|     for (const absl::string_view& outer_member_stringpiece : outer_vector) { | ||||
|       string outer_member(outer_member_stringpiece); | ||||
|       if (outer_member.empty()) { | ||||
|         continue; | ||||
|       } | ||||
|       string outer_member_copy = outer_member; | ||||
|       absl::StripAsciiWhitespace(&outer_member); | ||||
|       if (!strings::TryStripPrefixString(outer_member, "{", &outer_member)) | ||||
|         return false; | ||||
|       if (!strings::TryStripSuffixString(outer_member, "}", &outer_member)) | ||||
|         return false; | ||||
|       const std::vector<string> inner_fields_vector = | ||||
|           absl::StrSplit(outer_member, ','); | ||||
| 
 | ||||
|       std::unordered_map<string, string> element; | ||||
|       for (const string& member_field : inner_fields_vector) { | ||||
|         std::vector<string> outer_member_key_value = | ||||
|             absl::StrSplit(member_field, ':'); | ||||
|         if (outer_member_key_value.size() != 2) return false; | ||||
|         string& key = outer_member_key_value[0]; | ||||
|         string& value = outer_member_key_value[1]; | ||||
|         absl::StripAsciiWhitespace(&key); | ||||
|         absl::StripAsciiWhitespace(&value); | ||||
|         if (element.count(key) != 0) return false; | ||||
|         element[key] = value; | ||||
|       } | ||||
|       parsed_value_.elements.push_back(element); | ||||
|     } | ||||
|     return true; | ||||
| #else | ||||
|     // TODO(aselle): Fix argument parsing when absl supports structuredline
 | ||||
|     fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__, | ||||
|             __LINE__); | ||||
|     abort(); | ||||
| #endif | ||||
|   } | ||||
|   bool Parse(string text); | ||||
| 
 | ||||
|   std::function<bool(string)> bind() { | ||||
|     return std::bind(&Arg::parse, this, std::placeholders::_1); | ||||
|     return std::bind(&Arg::Parse, this, std::placeholders::_1); | ||||
|   } | ||||
| 
 | ||||
|   const toco::StringMapList& value() const { return parsed_value_; } | ||||
|  | ||||
							
								
								
									
										67
									
								
								tensorflow/lite/toco/model_cmdline_flags_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								tensorflow/lite/toco/model_cmdline_flags_test.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,67 @@ | ||||
| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include <string> | ||||
| #include <unordered_map> | ||||
| 
 | ||||
| #include <gmock/gmock.h> | ||||
| #include <gtest/gtest.h> | ||||
| #include "tensorflow/lite/testing/util.h" | ||||
| #include "tensorflow/lite/toco/args.h" | ||||
| #include "tensorflow/lite/toco/model_cmdline_flags.h" | ||||
| 
 | ||||
| namespace toco { | ||||
| namespace { | ||||
| 
 | ||||
| TEST(ModelCmdlineFlagsTest, ParseArgsStringMapList) { | ||||
|   int args_count = 3; | ||||
|   const char* args[] = { | ||||
|       "toco", | ||||
|       "--input_arrays=input_1", | ||||
|       "--rnn_states={state_array:rnn/BasicLSTMCellZeroState/zeros," | ||||
|       "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4}," | ||||
|       "{state_array:rnn/BasicLSTMCellZeroState/zeros_1," | ||||
|       "back_edge_source_array:rnn/basic_lstm_cell/Mul_2,size:4}", | ||||
|   }; | ||||
| 
 | ||||
|   string expected_input_arrays = "input_1"; | ||||
|   std::vector<std::unordered_map<string, string>> expected_rnn_states; | ||||
|   expected_rnn_states.push_back( | ||||
|       {{"state_array", "rnn/BasicLSTMCellZeroState/zeros"}, | ||||
|        {"back_edge_source_array", "rnn/basic_lstm_cell/Add_1"}, | ||||
|        {"size", "4"}}); | ||||
|   expected_rnn_states.push_back( | ||||
|       {{"state_array", "rnn/BasicLSTMCellZeroState/zeros_1"}, | ||||
|        {"back_edge_source_array", "rnn/basic_lstm_cell/Mul_2"}, | ||||
|        {"size", "4"}}); | ||||
| 
 | ||||
|   string message; | ||||
|   ParsedModelFlags result_flags; | ||||
| 
 | ||||
|   EXPECT_TRUE(ParseModelFlagsFromCommandLineFlags( | ||||
|       &args_count, const_cast<char**>(args), &message, &result_flags)); | ||||
|   EXPECT_EQ(result_flags.input_arrays.value(), expected_input_arrays); | ||||
|   EXPECT_EQ(result_flags.rnn_states.value().elements, expected_rnn_states); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace toco
 | ||||
| 
 | ||||
| int main(int argc, char** argv) { | ||||
|   ::tflite::LogToStderr(); | ||||
|   ::testing::InitGoogleTest(&argc, argv); | ||||
|   ::toco::port::InitGoogleWasDoneElsewhere(); | ||||
|   return RUN_ALL_TESTS(); | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user