diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7562cbe939b..56bf3a48080 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -154,7 +154,7 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg=""): # tf_gen_op_wrappers_cc("tf_ops_lib", [ "array_ops", "math_ops" ]) # # -# This will ultimately generate ops/* files and a library like: +#This will ultimately generate ops/* files and a library like: # # cc_library(name = "tf_ops_lib", # srcs = [ "ops/array_ops.cc", @@ -667,7 +667,7 @@ def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs): out_srcs = [p.replace(".proto", ".pb_text.cc") for p in srcs] native.genrule( name = name, - srcs = srcs, + srcs = srcs + ["//tensorflow/tools/proto_text:placeholder.txt"], outs = out_hdrs + out_srcs, cmd = "$(location //tensorflow/tools/proto_text:gen_proto_text_functions) " + "$(@D) " + srcs_relative_dir + " $(SRCS)", diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD index 5c185e54e53..19f14748c54 100644 --- a/tensorflow/tools/proto_text/BUILD +++ b/tensorflow/tools/proto_text/BUILD @@ -12,7 +12,10 @@ package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 -exports_files(["LICENSE"]) +exports_files([ + "LICENSE", + "placeholder.txt", +]) load( "//tensorflow:tensorflow.bzl", diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions.cc b/tensorflow/tools/proto_text/gen_proto_text_functions.cc index ce8e2bcd3e5..23257eaa3ef 100644 --- a/tensorflow/tools/proto_text/gen_proto_text_functions.cc +++ b/tensorflow/tools/proto_text/gen_proto_text_functions.cc @@ -36,15 +36,33 @@ class CrashOnErrorCollector << column << " - " << message; } }; -} // namespace static const char kTensorflowHeaderPrefix[] = ""; +static const char kPlaceholderFile[] = + "tensorflow/tools/proto_text/placeholder.txt"; + +bool IsPlaceholderFile(const char* s) { + string ph(kPlaceholderFile); + string str(s); + return str.size() >= strlen(kPlaceholderFile) && + ph == str.substr(str.size() - ph.size()); +} + +} // namespace + // Main program to take input protos and write output pb_text source files that // contain generated proto text input and output functions. // -// Main expects the first argument to give the output path. This is followed by -// pairs of arguments: . +// Main expects: +// - First argument is output path +// - Second argument is the relative path of the protos to the root. E.g., +// for protos built by a rule in tensorflow/core, this will be +// tensorflow/core. +// - Then any number of source proto file names, plus one source name must be +// placeholder.txt from this gen tool's package. placeholder.txt is +// ignored for proto resolution, but is used to determine the root at which +// the build tool has placed the source proto files. // // Note that this code doesn't use tensorflow's command line parsing, because of // circular dependencies between libraries if that were done. @@ -59,26 +77,41 @@ int MainImpl(int argc, char** argv) { } const string output_root = argv[1]; - const string relative_path = kTensorflowHeaderPrefix + string(argv[2]); + const string output_relative_path = kTensorflowHeaderPrefix + string(argv[2]); + + string src_relative_path; + bool has_placeholder = false; + for (int i = 3; i < argc; ++i) { + if (IsPlaceholderFile(argv[i])) { + const string s(argv[i]); + src_relative_path = s.substr(0, s.size() - strlen(kPlaceholderFile)); + has_placeholder = true; + } + } + if (!has_placeholder) { + LOG(ERROR) << kPlaceholderFile << " must be passed"; + return -1; + } tensorflow::protobuf::compiler::DiskSourceTree source_tree; - // This requires all protos to be relative to the directory from which the - // genrule is invoked. If protos are generated in some other directory, - // then they may not be found. - source_tree.MapPath("", "."); + source_tree.MapPath("", src_relative_path.empty() ? "." : src_relative_path); CrashOnErrorCollector crash_on_error; tensorflow::protobuf::compiler::Importer importer(&source_tree, &crash_on_error); for (int i = 3; i < argc; i++) { - const string proto_path = argv[i]; + if (IsPlaceholderFile(argv[i])) continue; + const string proto_path = string(argv[i]).substr(src_relative_path.size()); + const tensorflow::protobuf::FileDescriptor* fd = importer.Import(proto_path); const int index = proto_path.find_last_of("."); string proto_path_no_suffix = proto_path.substr(0, index); - proto_path_no_suffix = proto_path_no_suffix.substr(relative_path.size()); + + proto_path_no_suffix = + proto_path_no_suffix.substr(output_relative_path.size()); const auto code = tensorflow::GetProtoTextFunctionCode(*fd, kTensorflowHeaderPrefix); diff --git a/tensorflow/tools/proto_text/placeholder.txt b/tensorflow/tools/proto_text/placeholder.txt new file mode 100644 index 00000000000..062066af639 --- /dev/null +++ b/tensorflow/tools/proto_text/placeholder.txt @@ -0,0 +1 @@ +Contents are unused. See gen_proto_functions.cc for details.