Add some utility functions for supporting an alternate pbtxt format
that supports multi-line text without uncomfortable escaping.

So:
  description: "A `SparseTensor` ... `sparse_indices`,\n`sparse_values`, and `sparse_shape`, where\n\n```sparse_indices.shape[1] == sparse_shape.shape[0] == R```\n\nAn `N`-minibatch ..."

would become:
  description: <<END
A `SparseTensor` ... `sparse_indices`,
`sparse_values`, and `sparse_shape`, where

```sparse_indices.shape[1] == sparse_shape.shape[0] == R```

An `N`-minibatch ...
END

PiperOrigin-RevId: 161008382
This commit is contained in:
parent
eccd162119
commit
1857e187c9
@ -372,6 +372,7 @@ filegroup(
|
||||
"//tensorflow/tools/docker/notebooks:all_files",
|
||||
"//tensorflow/tools/docs:all_files",
|
||||
"//tensorflow/tools/git:all_files",
|
||||
"//tensorflow/tools/mlpbtxt:all_files",
|
||||
"//tensorflow/tools/proto_text:all_files",
|
||||
"//tensorflow/tools/quantization:all_files",
|
||||
"//tensorflow/tools/test:all_files",
|
||||
|
@ -2109,6 +2109,17 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "framework_op_gen_lib_test",
|
||||
size = "small",
|
||||
srcs = ["framework/op_gen_lib_test.cc"],
|
||||
deps = [
|
||||
":op_gen_lib",
|
||||
":test",
|
||||
":test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "quantize_training_test",
|
||||
srcs = ["graph/quantize_training_test.cc"],
|
||||
|
@ -73,6 +73,178 @@ bool ConsumeEquals(StringPiece* description) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split `*orig` into two pieces at the first occurrence of `split_ch`.
|
||||
// Returns whether `split_ch` was found. Afterwards, `*before_split`
|
||||
// contains the maximum prefix of the input `*orig` that doesn't
|
||||
// contain `split_ch`, and `*orig` contains everything after the
|
||||
// first `split_ch`.
|
||||
static bool SplitAt(char split_ch, StringPiece* orig,
|
||||
StringPiece* before_split) {
|
||||
auto pos = orig->find(split_ch);
|
||||
if (pos == StringPiece::npos) {
|
||||
*before_split = *orig;
|
||||
orig->clear();
|
||||
return false;
|
||||
} else {
|
||||
*before_split = orig->substr(0, pos);
|
||||
orig->remove_prefix(pos + 1);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Does this line start with "<spaces><field>:" where "<field>" is
|
||||
// in multi_line_fields?
|
||||
static bool StartsWithFieldName(StringPiece line,
                                const std::vector<string>& multi_line_fields) {
  // Everything before the first ':' is the candidate field name; a line
  // without a colon cannot be a field assignment.
  StringPiece candidate;
  if (!SplitAt(':', &line, &candidate)) return false;

  // Ignore indentation in front of the name.
  while (candidate.Consume(" ")) {
  }

  // The name must match one of the requested fields exactly.
  bool matched = false;
  for (size_t i = 0; i < multi_line_fields.size() && !matched; ++i) {
    matched = (candidate == multi_line_fields[i]);
  }
  return matched;
}
|
||||
|
||||
// Tries to rewrite a single pbtxt line of the form
//   <indent><field>: "<escaped>"<suffix>
// into the multi-line form
//   <indent><field>: <<END
//   <unescaped>
//   END<suffix>
// and appends the result (including a trailing newline) to `*ml`.
//
// Returns true if the line was converted. Returns false, leaving `*ml`
// untouched, when the field is not listed in `multi_line_fields`, when the
// value is not a double-quoted string, when the closing quote is missing, or
// when the string contents cannot be unescaped. The caller is expected to
// emit the line unchanged in that case.
static bool ConvertLine(StringPiece line,
                        const std::vector<string>& multi_line_fields,
                        string* ml) {
  // Is this a field we should convert?
  if (!StartsWithFieldName(line, multi_line_fields)) {
    return false;
  }
  // Has a matching field name, so look for "..." after the colon.
  // `up_to_colon` keeps the original indentation so it can be echoed as-is.
  StringPiece up_to_colon;
  StringPiece after_colon = line;
  SplitAt(':', &after_colon, &up_to_colon);
  while (after_colon.Consume(" "))
    ;  // Remove leading spaces.
  if (!after_colon.Consume("\"")) {
    // We only convert string fields, so don't convert this line.
    return false;
  }

  // Find the closing quote: the first '"' that is not part of a backslash
  // escape. Searching from the end of the line instead would swallow any
  // quote that appears in the suffix (for example a trailing comment such as
  // `# see "foo"`) into the string value.
  size_t closing_quote = StringPiece::npos;
  for (size_t i = 0; i < after_colon.size(); ++i) {
    const char c = after_colon[i];
    if (c == '\\') {
      ++i;  // Skip the escaped character, whatever it is.
    } else if (c == '"') {
      closing_quote = i;
      break;
    }
  }
  if (closing_quote == StringPiece::npos) {
    // Error: we don't see the expected matching quote, abort the conversion.
    return false;
  }
  StringPiece escaped = after_colon.substr(0, closing_quote);
  StringPiece suffix = after_colon.substr(closing_quote + 1);
  // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'

  string unescaped;
  if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
    // Error unescaping, abort the conversion.
    return false;
  }
  // No more errors possible at this point.

  // Find a string to mark the end that isn't in unescaped: try "END", then
  // "END0", "END1", ... until one does not occur anywhere in the text.
  string end = "END";
  for (int s = 0; unescaped.find(end) != string::npos; ++s) {
    end = strings::StrCat("END", s);
  }

  // Actually start writing the converted output.
  strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
  if (!suffix.empty()) {
    // Output suffix, in case there was a trailing comment in the source.
    strings::StrAppend(ml, suffix);
  }
  strings::StrAppend(ml, "\n");
  return true;
}
|
||||
|
||||
string PBTxtToMultiline(StringPiece pbtxt,
                        const std::vector<string>& multi_line_fields) {
  string result;
  // The output is expected to be roughly the size of the input; reserve a
  // little extra as a guess to limit reallocations.
  result.reserve(pbtxt.size() * (17. / 16));

  // Peel off one line at a time until the input is exhausted.
  for (StringPiece current; !pbtxt.empty();) {
    SplitAt('\n', &pbtxt, &current);
    const bool converted = ConvertLine(current, multi_line_fields, &result);
    if (converted) continue;
    // Not a convertible line: copy it through verbatim.
    strings::StrAppend(&result, current, "\n");
  }
  return result;
}
|
||||
|
||||
// Given a single line of text `line` with first : at `colon`, determine if
|
||||
// there is an "<<END" expression after the colon and if so return true and set
|
||||
// `*end` to everything after the "<<".
|
||||
static bool FindMultiline(StringPiece line, size_t colon, string* end) {
  // A line without a colon cannot introduce a multi-line value.
  if (colon == StringPiece::npos) return false;

  // Look at what follows the colon, ignoring any spaces.
  StringPiece rest = line;
  rest.remove_prefix(colon + 1);
  while (rest.Consume(" ")) {
  }

  if (!rest.Consume("<<")) return false;

  // Whatever comes after "<<" is the end marker.
  *end = rest.ToString();
  return true;
}
|
||||
|
||||
// Converts multi-line pbtxt back into ordinary pbtxt.
//
// Lines are copied through unchanged unless the text after the first ':' is
// "<<" followed by an end marker. In that case the following lines, up to
// (not including) the first line that starts with the end marker, are joined
// with '\n', C-escaped, and emitted as a quoted string on the field's line.
// Any text after the end marker on its line (e.g. a trailing comment) is
// appended after the closing quote.
string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
  string pbtxt;
  // Probably big enough, since the input and output are about the
  // same size, but just a guess.
  pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
  StringPiece line;
  while (!multiline_pbtxt.empty()) {
    // Split multiline_pbtxt into its first line and everything after.
    if (!SplitAt('\n', &multiline_pbtxt, &line)) {
      // Final line without a trailing newline: emit it as-is.
      strings::StrAppend(&pbtxt, line);
      break;
    }

    string end;
    auto colon = line.find(':');
    if (!FindMultiline(line, colon, &end)) {
      // Normal case: not a multi-line string, just output the line as-is.
      strings::StrAppend(&pbtxt, line, "\n");
      continue;
    }

    // Multi-line case:
    //     something: <<END
    // xx
    // yy
    // END
    // Should be converted to:
    //     something: "xx\nyy"

    // Output everything up to the colon ("    something:").
    strings::StrAppend(&pbtxt, line.substr(0, colon + 1));

    // Add every line to unescaped until we see the "END" string.
    string unescaped;
    // Text following the end marker on its line. It is kept separate from
    // `line` so that a block with no terminator (or no body at all) yields an
    // empty suffix rather than a leftover header or body line.
    StringPiece suffix;
    bool first = true;
    while (!multiline_pbtxt.empty()) {
      SplitAt('\n', &multiline_pbtxt, &line);
      if (line.Consume(end)) {
        suffix = line;
        break;
      }
      if (first) {
        first = false;
      } else {
        unescaped.push_back('\n');
      }
      strings::StrAppend(&unescaped, line);
    }

    // Escape what we extracted and then output it in quotes.
    strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"",
                       suffix, "\n");
  }
  return pbtxt;
}
|
||||
|
||||
OpGenOverrideMap::OpGenOverrideMap() {}
|
||||
OpGenOverrideMap::~OpGenOverrideMap() {}
|
||||
|
||||
|
@ -43,6 +43,11 @@ string WordWrap(StringPiece prefix, StringPiece str, int width);
|
||||
// returns false.
|
||||
bool ConsumeEquals(StringPiece* description);
|
||||
|
||||
// Convert text-serialized protobufs to/from multiline format.
|
||||
string PBTxtToMultiline(StringPiece pbtxt,
|
||||
const std::vector<string>& multi_line_fields);
|
||||
string PBTxtFromMultiline(StringPiece multiline_pbtxt);
|
||||
|
||||
// Takes a list of files with OpGenOverrides text protos, and allows you to
|
||||
// look up the specific override for any given op.
|
||||
class OpGenOverrideMap {
|
||||
|
131
tensorflow/core/framework/op_gen_lib_test.cc
Normal file
131
tensorflow/core/framework/op_gen_lib_test.cc
Normal file
@ -0,0 +1,131 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(OpGenLibTest, MultilinePBTxt) {
  // Baseline: ordinary pbtxt with every string on a single, escaped line.
  const string plain = R"(foo: "abc"
foo: ""
foo: "\n\n"
foo: "abc\nEND"
foo: "ghi\njkl\n"
bar: "quotes:\""
)";

  // Expected result when only "foo" is expanded; "bar" stays untouched.
  const string foo_expanded = R"(foo: <<END
abc
END
foo: <<END

END
foo: <<END



END
foo: <<END0
abc
END
END0
foo: <<END
ghi
jkl

END
bar: "quotes:\""
)";

  // Expected result when both "foo" and "bar" are expanded.
  const string foo_bar_expanded = R"(foo: <<END
abc
END
foo: <<END

END
foo: <<END



END
foo: <<END0
abc
END
END0
foo: <<END
ghi
jkl

END
bar: <<END
quotes:"
END
)";

  // Plain -> multi-line, for several field selections.
  EXPECT_EQ(foo_expanded, PBTxtToMultiline(plain, {"foo"}));
  EXPECT_EQ(plain, PBTxtToMultiline(plain, {"baz"}));
  EXPECT_EQ(foo_bar_expanded, PBTxtToMultiline(plain, {"foo", "bar"}));

  // Multi-line -> plain: every variant must collapse to the baseline.
  EXPECT_EQ(plain, PBTxtFromMultiline(plain));
  EXPECT_EQ(plain, PBTxtFromMultiline(foo_expanded));
  EXPECT_EQ(plain, PBTxtFromMultiline(foo_bar_expanded));
}
|
||||
|
||||
TEST(OpGenLibTest, PBTxtToMultilineErrorCases) {
  // Every case below asks for the single field "f" to be expanded.
  const std::vector<string> fields = {"f"};

  // Sanity check: a well-formed line is converted.
  EXPECT_EQ("f: <<END\n7\nEND\n", PBTxtToMultiline("f: \"7\"\n", fields));

  // Whenever PBTxtToMultiline cannot parse a line, it is expected to pass
  // that line through unchanged.

  // Missing colon.
  EXPECT_EQ("f \"7\"\n", PBTxtToMultiline("f \"7\"\n", fields));
  // Value is not a string.
  EXPECT_EQ("f: 7\n", PBTxtToMultiline("f: 7\n", fields));
  // Value does not start with a quote.
  EXPECT_EQ("f: 7\"\n", PBTxtToMultiline("f: 7\"\n", fields));
  // Opening quote without a closing one.
  EXPECT_EQ("f: \"7\n", PBTxtToMultiline("f: \"7\n", fields));
  // Escape sequence that cannot be unescaped.
  EXPECT_EQ("f: \"7\\\"\n", PBTxtToMultiline("f: \"7\\\"\n", fields));
}
|
||||
|
||||
TEST(OpGenLibTest, PBTxtToMultilineComments) {
  // Plain form: each string value is followed by a trailing comment.
  const string plain = R"(f: "bar"  # Comment 1
    f: "\n"  # Comment 2
)";
  // Multi-line form: the comment is expected to move after the end marker.
  const string expanded = R"(f: <<END
bar
END  # Comment 1
    f: <<END


END  # Comment 2
)";

  // The conversion must round-trip in both directions.
  EXPECT_EQ(expanded, PBTxtToMultiline(plain, {"f"}));
  EXPECT_EQ(plain, PBTxtFromMultiline(expanded));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
44
tensorflow/tools/mlpbtxt/BUILD
Normal file
44
tensorflow/tools/mlpbtxt/BUILD
Normal file
@ -0,0 +1,44 @@
|
||||
# Description:
|
||||
# This package provides binaries that convert between multi-line and standard
|
||||
# pbtxt (text-serialization of protocol message) files.
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
"placeholder.txt",
|
||||
])
|
||||
|
||||
cc_binary(
|
||||
name = "tomlpbtxt",
|
||||
srcs = ["tomlpbtxt.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "frommlpbtxt",
|
||||
srcs = ["frommlpbtxt.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
70
tensorflow/tools/mlpbtxt/frommlpbtxt.cc
Normal file
70
tensorflow/tools/mlpbtxt/frommlpbtxt.cc
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
int Run(int argc, char** argv) {
|
||||
string FLAGS_in = "";
|
||||
string FLAGS_out = "";
|
||||
|
||||
std::vector<Flag> flag_list = {
|
||||
Flag("in", &FLAGS_in, "Input multi-line proto text (.mlpbtxt) file name"),
|
||||
Flag("out", &FLAGS_out, "Output proto text (.pbtxt) file name")};
|
||||
|
||||
// Parse the command-line.
|
||||
const string usage = Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_ok = Flags::Parse(&argc, argv, flag_list);
|
||||
if (argc != 1 || !parse_ok) {
|
||||
printf("%s", usage.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
port::InitMain(argv[0], &argc, &argv);
|
||||
|
||||
// Read the input file --in.
|
||||
string in_contents;
|
||||
Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents);
|
||||
if (!s.ok()) {
|
||||
printf("Error reading file %s: %s\n", FLAGS_in.c_str(),
|
||||
s.ToString().c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Write the output file --out.
|
||||
const string out_contents = PBTxtFromMultiline(in_contents);
|
||||
s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents);
|
||||
if (!s.ok()) {
|
||||
printf("Error writing file %s: %s\n", FLAGS_out.c_str(),
|
||||
s.ToString().c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
  // All work happens in tensorflow::Run; its result is the exit code.
  const int exit_code = tensorflow::Run(argc, argv);
  return exit_code;
}
|
81
tensorflow/tools/mlpbtxt/tomlpbtxt.cc
Normal file
81
tensorflow/tools/mlpbtxt/tomlpbtxt.cc
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2017 The TensorFlow Authors All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
int Run(int argc, char** argv) {
|
||||
string FLAGS_in = "";
|
||||
string FLAGS_out = "";
|
||||
string FLAGS_fields = "description";
|
||||
|
||||
std::vector<Flag> flag_list = {
|
||||
Flag("in", &FLAGS_in, "Input proto text (.pbtxt) file name"),
|
||||
Flag("out", &FLAGS_out,
|
||||
"Output multi-line proto text (.mlpbtxt) file name"),
|
||||
Flag("fields", &FLAGS_fields, "Comma-separated list of field names")};
|
||||
|
||||
// Parse the command-line.
|
||||
const string usage = Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_ok = Flags::Parse(&argc, argv, flag_list);
|
||||
if (argc != 1 || !parse_ok) {
|
||||
printf("%s", usage.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Parse the --fields option.
|
||||
std::vector<string> fields =
|
||||
str_util::Split(FLAGS_fields, ',', str_util::SkipEmpty());
|
||||
if (fields.empty()) {
|
||||
printf("--fields must be non-empty.\n%s", usage.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
port::InitMain(argv[0], &argc, &argv);
|
||||
|
||||
// Read the input file --in.
|
||||
string in_contents;
|
||||
Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents);
|
||||
if (!s.ok()) {
|
||||
printf("Error reading file %s: %s\n", FLAGS_in.c_str(),
|
||||
s.ToString().c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Write the output file --out.
|
||||
const string out_contents = PBTxtToMultiline(in_contents, fields);
|
||||
s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents);
|
||||
if (!s.ok()) {
|
||||
printf("Error writing file %s: %s\n", FLAGS_out.c_str(),
|
||||
s.ToString().c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
  // All work happens in tensorflow::Run; its result is the exit code.
  const int exit_code = tensorflow::Run(argc, argv);
  return exit_code;
}
|
Loading…
Reference in New Issue
Block a user