From 1857e187c98b4863669b62a469acc1251e1c1f04 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Jul 2017 14:47:12 -0700 Subject: [PATCH] Add some utility functions for supporting an alternate pbtxt format that supports multi-line text without uncomfortable escaping. So: description: "A `SparseTensor` ... `sparse_indices`,\n`sparse_values`, and `sparse_shape`, where\n\n```sparse_indices.shape[1] == sparse_shape.shape[0] == R```\n\nAn `N`-minibatch ..." would become: description: <find(split_ch); + if (pos == StringPiece::npos) { + *before_split = *orig; + orig->clear(); + return false; + } else { + *before_split = orig->substr(0, pos); + orig->remove_prefix(pos + 1); + return true; + } +} + +// Does this line start with ":" where "" is +// in multi_line_fields? Sets *colon_pos to the position of the colon. +static bool StartsWithFieldName(StringPiece line, + const std::vector& multi_line_fields) { + StringPiece up_to_colon; + if (!SplitAt(':', &line, &up_to_colon)) return false; + while (up_to_colon.Consume(" ")) + ; // Remove leading spaces. + for (const auto& field : multi_line_fields) { + if (up_to_colon == field) { + return true; + } + } + return false; +} + +static bool ConvertLine(StringPiece line, + const std::vector& multi_line_fields, + string* ml) { + // Is this a field we should convert? + if (!StartsWithFieldName(line, multi_line_fields)) { + return false; + } + // Has a matching field name, so look for "..." after the colon. + StringPiece up_to_colon; + StringPiece after_colon = line; + SplitAt(':', &after_colon, &up_to_colon); + while (after_colon.Consume(" ")) + ; // Remove leading spaces. + if (!after_colon.Consume("\"")) { + // We only convert string fields, so don't convert this line. + return false; + } + auto last_quote = after_colon.rfind('\"'); + if (last_quote == StringPiece::npos) { + // Error: we don't see the expected matching quote, abort the conversion. + return false; + } + StringPiece escaped = after_colon.substr(0, last_quote); + StringPiece suffix = after_colon.substr(last_quote + 1); + // We've now parsed line into ': ""' + + string unescaped; + if (!str_util::CUnescape(escaped, &unescaped, nullptr)) { + // Error unescaping, abort the conversion. + return false; + } + // No more errors possible at this point. + + // Find a string to mark the end that isn't in unescaped. + string end = "END"; + for (int s = 0; unescaped.find(end) != string::npos; ++s) { + end = strings::StrCat("END", s); + } + + // Actually start writing the converted output. + strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end); + if (!suffix.empty()) { + // Output suffix, in case there was a trailing comment in the source. + strings::StrAppend(ml, suffix); + } + strings::StrAppend(ml, "\n"); + return true; +} + +string PBTxtToMultiline(StringPiece pbtxt, + const std::vector& multi_line_fields) { + string ml; + // Probably big enough, since the input and output are about the + // same size, but just a guess. + ml.reserve(pbtxt.size() * (17. / 16)); + StringPiece line; + while (!pbtxt.empty()) { + // Split pbtxt into its first line and everything after. + SplitAt('\n', &pbtxt, &line); + // Convert line or output it unchanged + if (!ConvertLine(line, multi_line_fields, &ml)) { + strings::StrAppend(&ml, line, "\n"); + } + } + return ml; +} + +// Given a single line of text `line` with first : at `colon`, determine if +// there is an "<& multi_line_fields); +string PBTxtFromMultiline(StringPiece multiline_pbtxt); + // Takes a list of files with OpGenOverrides text protos, and allows you to // look up the specific override for any given op. class OpGenOverrideMap { diff --git a/tensorflow/core/framework/op_gen_lib_test.cc b/tensorflow/core/framework/op_gen_lib_test.cc new file mode 100644 index 00000000000..cc1d117f384 --- /dev/null +++ b/tensorflow/core/framework/op_gen_lib_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_gen_lib.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(OpGenLibTest, MultilinePBTxt) { + // Non-multiline pbtxt + const string pbtxt = R"(foo: "abc" +foo: "" +foo: "\n\n" +foo: "abc\nEND" + foo: "ghi\njkl\n" +bar: "quotes:\"" +)"; + + // Field "foo" converted to multiline but not "bar". + const string ml_foo = R"(foo: < + +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int Run(int argc, char** argv) { + string FLAGS_in = ""; + string FLAGS_out = ""; + + std::vector flag_list = { + Flag("in", &FLAGS_in, "Input multi-line proto text (.mlpbtxt) file name"), + Flag("out", &FLAGS_out, "Output proto text (.pbtxt) file name")}; + + // Parse the command-line. + const string usage = Flags::Usage(argv[0], flag_list); + const bool parse_ok = Flags::Parse(&argc, argv, flag_list); + if (argc != 1 || !parse_ok) { + printf("%s", usage.c_str()); + return 2; + } + + port::InitMain(argv[0], &argc, &argv); + + // Read the input file --in. + string in_contents; + Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents); + if (!s.ok()) { + printf("Error reading file %s: %s\n", FLAGS_in.c_str(), + s.ToString().c_str()); + return 1; + } + + // Write the output file --out. + const string out_contents = PBTxtFromMultiline(in_contents); + s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents); + if (!s.ok()) { + printf("Error writing file %s: %s\n", FLAGS_out.c_str(), + s.ToString().c_str()); + return 1; + } + + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::Run(argc, argv); } diff --git a/tensorflow/tools/mlpbtxt/tomlpbtxt.cc b/tensorflow/tools/mlpbtxt/tomlpbtxt.cc new file mode 100644 index 00000000000..469be49ed3c --- /dev/null +++ b/tensorflow/tools/mlpbtxt/tomlpbtxt.cc @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int Run(int argc, char** argv) { + string FLAGS_in = ""; + string FLAGS_out = ""; + string FLAGS_fields = "description"; + + std::vector flag_list = { + Flag("in", &FLAGS_in, "Input proto text (.pbtxt) file name"), + Flag("out", &FLAGS_out, + "Output multi-line proto text (.mlpbtxt) file name"), + Flag("fields", &FLAGS_fields, "Comma-separated list of field names")}; + + // Parse the command-line. + const string usage = Flags::Usage(argv[0], flag_list); + const bool parse_ok = Flags::Parse(&argc, argv, flag_list); + if (argc != 1 || !parse_ok) { + printf("%s", usage.c_str()); + return 2; + } + + // Parse the --fields option. + std::vector fields = + str_util::Split(FLAGS_fields, ',', str_util::SkipEmpty()); + if (fields.empty()) { + printf("--fields must be non-empty.\n%s", usage.c_str()); + return 2; + } + + port::InitMain(argv[0], &argc, &argv); + + // Read the input file --in. + string in_contents; + Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents); + if (!s.ok()) { + printf("Error reading file %s: %s\n", FLAGS_in.c_str(), + s.ToString().c_str()); + return 1; + } + + // Write the output file --out. + const string out_contents = PBTxtToMultiline(in_contents, fields); + s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents); + if (!s.ok()) { + printf("Error writing file %s: %s\n", FLAGS_out.c_str(), + s.ToString().c_str()); + return 1; + } + + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::Run(argc, argv); }