Add some utility functions for supporting an alternate pbtxt format

that supports multi-line text without uncomfortable escaping.  So:

  description: "A `SparseTensor` ... `sparse_indices`,\n`sparse_values`, and `sparse_shape`, where\n\n```sparse_indices.shape[1] == sparse_shape.shape[0] == R```\n\nAn `N`-minibatch ..."

would become:

  description: <<END
A `SparseTensor` ... `sparse_indices`,
`sparse_values`, and `sparse_shape`, where

```sparse_indices.shape[1] == sparse_shape.shape[0] == R```

An `N`-minibatch ...
END

PiperOrigin-RevId: 161008382
This commit is contained in:
A. Unique TensorFlower 2017-07-05 14:47:12 -07:00 committed by TensorFlower Gardener
parent eccd162119
commit 1857e187c9
8 changed files with 515 additions and 0 deletions

View File

@ -372,6 +372,7 @@ filegroup(
"//tensorflow/tools/docker/notebooks:all_files",
"//tensorflow/tools/docs:all_files",
"//tensorflow/tools/git:all_files",
"//tensorflow/tools/mlpbtxt:all_files",
"//tensorflow/tools/proto_text:all_files",
"//tensorflow/tools/quantization:all_files",
"//tensorflow/tools/test:all_files",

View File

@ -2109,6 +2109,17 @@ cc_test(
],
)
tf_cc_test(
name = "framework_op_gen_lib_test",
size = "small",
srcs = ["framework/op_gen_lib_test.cc"],
deps = [
":op_gen_lib",
":test",
":test_main",
],
)
tf_cc_test(
name = "quantize_training_test",
srcs = ["graph/quantize_training_test.cc"],

View File

@ -73,6 +73,178 @@ bool ConsumeEquals(StringPiece* description) {
return false;
}
// Split `*orig` into two pieces at the first occurence of `split_ch`.
// Returns whether `split_ch` was found. Afterwards, `*before_split`
// contains the maximum prefix of the input `*orig` that doesn't
// contain `split_ch`, and `*orig` contains everything after the
// first `split_ch`.
static bool SplitAt(char split_ch, StringPiece* orig,
StringPiece* before_split) {
auto pos = orig->find(split_ch);
if (pos == StringPiece::npos) {
*before_split = *orig;
orig->clear();
return false;
} else {
*before_split = orig->substr(0, pos);
orig->remove_prefix(pos + 1);
return true;
}
}
// Does this line start with "<spaces><field>:" where "<field>" is
// in multi_line_fields? Sets *colon_pos to the position of the colon.
static bool StartsWithFieldName(StringPiece line,
const std::vector<string>& multi_line_fields) {
StringPiece up_to_colon;
if (!SplitAt(':', &line, &up_to_colon)) return false;
while (up_to_colon.Consume(" "))
; // Remove leading spaces.
for (const auto& field : multi_line_fields) {
if (up_to_colon == field) {
return true;
}
}
return false;
}
static bool ConvertLine(StringPiece line,
const std::vector<string>& multi_line_fields,
string* ml) {
// Is this a field we should convert?
if (!StartsWithFieldName(line, multi_line_fields)) {
return false;
}
// Has a matching field name, so look for "..." after the colon.
StringPiece up_to_colon;
StringPiece after_colon = line;
SplitAt(':', &after_colon, &up_to_colon);
while (after_colon.Consume(" "))
; // Remove leading spaces.
if (!after_colon.Consume("\"")) {
// We only convert string fields, so don't convert this line.
return false;
}
auto last_quote = after_colon.rfind('\"');
if (last_quote == StringPiece::npos) {
// Error: we don't see the expected matching quote, abort the conversion.
return false;
}
StringPiece escaped = after_colon.substr(0, last_quote);
StringPiece suffix = after_colon.substr(last_quote + 1);
// We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
string unescaped;
if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
// Error unescaping, abort the conversion.
return false;
}
// No more errors possible at this point.
// Find a string to mark the end that isn't in unescaped.
string end = "END";
for (int s = 0; unescaped.find(end) != string::npos; ++s) {
end = strings::StrCat("END", s);
}
// Actually start writing the converted output.
strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
if (!suffix.empty()) {
// Output suffix, in case there was a trailing comment in the source.
strings::StrAppend(ml, suffix);
}
strings::StrAppend(ml, "\n");
return true;
}
string PBTxtToMultiline(StringPiece pbtxt,
const std::vector<string>& multi_line_fields) {
string ml;
// Probably big enough, since the input and output are about the
// same size, but just a guess.
ml.reserve(pbtxt.size() * (17. / 16));
StringPiece line;
while (!pbtxt.empty()) {
// Split pbtxt into its first line and everything after.
SplitAt('\n', &pbtxt, &line);
// Convert line or output it unchanged
if (!ConvertLine(line, multi_line_fields, &ml)) {
strings::StrAppend(&ml, line, "\n");
}
}
return ml;
}
// Given a single line of text `line` with first : at `colon`, determine if
// there is an "<<END" expression after the colon and if so return true and set
// `*end` to everything after the "<<".
static bool FindMultiline(StringPiece line, size_t colon, string* end) {
if (colon == StringPiece::npos) return false;
line.remove_prefix(colon + 1);
while (line.Consume(" ")) {
}
if (line.Consume("<<")) {
*end = line.ToString();
return true;
}
return false;
}
string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
string pbtxt;
// Probably big enough, since the input and output are about the
// same size, but just a guess.
pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
StringPiece line;
while (!multiline_pbtxt.empty()) {
// Split multiline_pbtxt into its first line and everything after.
if (!SplitAt('\n', &multiline_pbtxt, &line)) {
strings::StrAppend(&pbtxt, line);
break;
}
string end;
auto colon = line.find(':');
if (!FindMultiline(line, colon, &end)) {
// Normal case: not a multi-line string, just output the line as-is.
strings::StrAppend(&pbtxt, line, "\n");
continue;
}
// Multi-line case:
// something: <<END
// xx
// yy
// END
// Should be converted to:
// something: "xx\nyy"
// Output everything up to the colon (" something:").
strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
// Add every line to unescaped until we see the "END" string.
string unescaped;
bool first = true;
string suffix;
while (!multiline_pbtxt.empty()) {
SplitAt('\n', &multiline_pbtxt, &line);
if (line.Consume(end)) break;
if (first) {
first = false;
} else {
unescaped.push_back('\n');
}
strings::StrAppend(&unescaped, line);
line.clear();
}
// Escape what we extracted and then output it in quotes.
strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line,
"\n");
}
return pbtxt;
}
OpGenOverrideMap::OpGenOverrideMap() {}
OpGenOverrideMap::~OpGenOverrideMap() {}

View File

@ -43,6 +43,11 @@ string WordWrap(StringPiece prefix, StringPiece str, int width);
// returns false.
bool ConsumeEquals(StringPiece* description);
// Convert text-serialized protobufs to/from multiline format.
string PBTxtToMultiline(StringPiece pbtxt,
const std::vector<string>& multi_line_fields);
string PBTxtFromMultiline(StringPiece multiline_pbtxt);
// Takes a list of files with OpGenOverrides text protos, and allows you to
// look up the specific override for any given op.
class OpGenOverrideMap {

View File

@ -0,0 +1,131 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(OpGenLibTest, MultilinePBTxt) {
// Non-multiline pbtxt
const string pbtxt = R"(foo: "abc"
foo: ""
foo: "\n\n"
foo: "abc\nEND"
foo: "ghi\njkl\n"
bar: "quotes:\""
)";
// Field "foo" converted to multiline but not "bar".
const string ml_foo = R"(foo: <<END
abc
END
foo: <<END
END
foo: <<END
END
foo: <<END0
abc
END
END0
foo: <<END
ghi
jkl
END
bar: "quotes:\""
)";
// Both fields "foo" and "bar" converted to multiline.
const string ml_foo_bar = R"(foo: <<END
abc
END
foo: <<END
END
foo: <<END
END
foo: <<END0
abc
END
END0
foo: <<END
ghi
jkl
END
bar: <<END
quotes:"
END
)";
// ToMultiline
EXPECT_EQ(ml_foo, PBTxtToMultiline(pbtxt, {"foo"}));
EXPECT_EQ(pbtxt, PBTxtToMultiline(pbtxt, {"baz"}));
EXPECT_EQ(ml_foo_bar, PBTxtToMultiline(pbtxt, {"foo", "bar"}));
// FromMultiline
EXPECT_EQ(pbtxt, PBTxtFromMultiline(pbtxt));
EXPECT_EQ(pbtxt, PBTxtFromMultiline(ml_foo));
EXPECT_EQ(pbtxt, PBTxtFromMultiline(ml_foo_bar));
}
TEST(OpGenLibTest, PBTxtToMultilineErrorCases) {
// Everything correct.
EXPECT_EQ("f: <<END\n7\nEND\n", PBTxtToMultiline("f: \"7\"\n", {"f"}));
// In general, if there is a problem parsing in PBTxtToMultiline, it leaves
// the line alone.
// No colon
EXPECT_EQ("f \"7\"\n", PBTxtToMultiline("f \"7\"\n", {"f"}));
// Only converts strings.
EXPECT_EQ("f: 7\n", PBTxtToMultiline("f: 7\n", {"f"}));
// No quote after colon.
EXPECT_EQ("f: 7\"\n", PBTxtToMultiline("f: 7\"\n", {"f"}));
// Only one quote
EXPECT_EQ("f: \"7\n", PBTxtToMultiline("f: \"7\n", {"f"}));
// Illegal escaping
EXPECT_EQ("f: \"7\\\"\n", PBTxtToMultiline("f: \"7\\\"\n", {"f"}));
}
TEST(OpGenLibTest, PBTxtToMultilineComments) {
const string pbtxt = R"(f: "bar" # Comment 1
f: "\n" # Comment 2
)";
const string ml = R"(f: <<END
bar
END # Comment 1
f: <<END
END # Comment 2
)";
EXPECT_EQ(ml, PBTxtToMultiline(pbtxt, {"f"}));
EXPECT_EQ(pbtxt, PBTxtFromMultiline(ml));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,44 @@
# Description:
# This package provides binaries that convert between multi-line and standard
# pbtxt (text-serialization of protocol message) files.
package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
exports_files([
"LICENSE",
"placeholder.txt",
])
cc_binary(
name = "tomlpbtxt",
srcs = ["tomlpbtxt.cc"],
deps = [
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
],
)
cc_binary(
name = "frommlpbtxt",
srcs = ["frommlpbtxt.cc"],
deps = [
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,70 @@
/* Copyright 2017 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdio.h>
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
int Run(int argc, char** argv) {
string FLAGS_in = "";
string FLAGS_out = "";
std::vector<Flag> flag_list = {
Flag("in", &FLAGS_in, "Input multi-line proto text (.mlpbtxt) file name"),
Flag("out", &FLAGS_out, "Output proto text (.pbtxt) file name")};
// Parse the command-line.
const string usage = Flags::Usage(argv[0], flag_list);
const bool parse_ok = Flags::Parse(&argc, argv, flag_list);
if (argc != 1 || !parse_ok) {
printf("%s", usage.c_str());
return 2;
}
port::InitMain(argv[0], &argc, &argv);
// Read the input file --in.
string in_contents;
Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents);
if (!s.ok()) {
printf("Error reading file %s: %s\n", FLAGS_in.c_str(),
s.ToString().c_str());
return 1;
}
// Write the output file --out.
const string out_contents = PBTxtFromMultiline(in_contents);
s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents);
if (!s.ok()) {
printf("Error writing file %s: %s\n", FLAGS_out.c_str(),
s.ToString().c_str());
return 1;
}
return 0;
}
} // namespace
} // namespace tensorflow
int main(int argc, char** argv) { return tensorflow::Run(argc, argv); }

View File

@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdio.h>
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
int Run(int argc, char** argv) {
string FLAGS_in = "";
string FLAGS_out = "";
string FLAGS_fields = "description";
std::vector<Flag> flag_list = {
Flag("in", &FLAGS_in, "Input proto text (.pbtxt) file name"),
Flag("out", &FLAGS_out,
"Output multi-line proto text (.mlpbtxt) file name"),
Flag("fields", &FLAGS_fields, "Comma-separated list of field names")};
// Parse the command-line.
const string usage = Flags::Usage(argv[0], flag_list);
const bool parse_ok = Flags::Parse(&argc, argv, flag_list);
if (argc != 1 || !parse_ok) {
printf("%s", usage.c_str());
return 2;
}
// Parse the --fields option.
std::vector<string> fields =
str_util::Split(FLAGS_fields, ',', str_util::SkipEmpty());
if (fields.empty()) {
printf("--fields must be non-empty.\n%s", usage.c_str());
return 2;
}
port::InitMain(argv[0], &argc, &argv);
// Read the input file --in.
string in_contents;
Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents);
if (!s.ok()) {
printf("Error reading file %s: %s\n", FLAGS_in.c_str(),
s.ToString().c_str());
return 1;
}
// Write the output file --out.
const string out_contents = PBTxtToMultiline(in_contents, fields);
s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents);
if (!s.ok()) {
printf("Error writing file %s: %s\n", FLAGS_out.c_str(),
s.ToString().c_str());
return 1;
}
return 0;
}
} // namespace
} // namespace tensorflow
int main(int argc, char** argv) { return tensorflow::Run(argc, argv); }