diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index c46a3d8db37..62579bd23ae 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -659,14 +659,26 @@ void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { string GetEagerPythonOps(const OpList& ops, const std::vector& hidden_ops, - bool require_shapes) { + bool require_shapes, + const string& source_file_name = "") { + string result; // Header // TODO(josh11b): Mention the library for which wrappers are being generated. - strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops. + strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit. -""" +)"); + + // Mention the original source file so someone tracing back through generated + // Python code will know where to look next. + if (!source_file_name.empty()) { + strings::StrAppend(&result, "Original C++ source file: "); + strings::StrAppend(&result, source_file_name); + strings::StrAppend(&result, "\n"); + } + + strings::StrAppend(&result, R"(""" import collections as _collections @@ -747,8 +759,11 @@ from tensorflow.python.framework import op_def_library as _op_def_library void PrintEagerPythonOps(const OpList& ops, const std::vector& hidden_ops, - bool require_shapes) { - printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str()); + bool require_shapes, + const string& source_file_name) +{ + printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes, + source_file_name).c_str()); } string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) { diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h index 9a7ed28cf94..250623850f2 100644 --- a/tensorflow/python/eager/python_eager_op_gen.h +++ b/tensorflow/python/eager/python_eager_op_gen.h @@ -24,9 +24,12 @@ namespace tensorflow { // hidden_ops should be a list of Op names that should get a leading _ // in the output. Prints the output to stdout. +// Optional fourth argument is the name of the original C++ source file +// where the ops' REGISTER_OP() calls reside. void PrintEagerPythonOps(const OpList& ops, const std::vector& hidden_ops, - bool require_shapes); + bool require_shapes, + const string& source_file_name = ""); // Get the python wrappers for a list of ops in a OpList. // `op_list_buf` should be a pointer to a buffer containing diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index 83665422885..3cf56330e0d 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -80,7 +81,31 @@ Status ParseOpListCommandLine(const char* arg, std::vector* op_list) { return Status::OK(); } -void PrintAllPythonOps(const std::vector& op_list, bool require_shapes, + +// Use the name of the current executable to infer the C++ source file +// where the REGISTER_OP() call for the operator can be found. +// Returns the name of the file. +// Returns an empty string if the current executable's name does not +// follow a known pattern. +string InferSourceFileName(const char* argv_zero) { + StringPiece command_str = io::Basename(argv_zero); + + // For built-in ops, the Bazel build creates a separate executable + // with the name gen__ops_py_wrappers_cc containing the + // operators defined in _ops.cc + const char* kExecPrefix = "gen_"; + const char* kExecSuffix = "_py_wrappers_cc"; + if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) { + command_str.remove_suffix(strlen(kExecSuffix)); + return strings::StrCat(command_str, ".cc"); + } else { + return string(""); + } +} + +void PrintAllPythonOps(const std::vector& op_list, + const string& source_file_name, + bool require_shapes, bool op_list_is_whitelist) { OpList ops; OpRegistry::Global()->Export(false, &ops); @@ -93,9 +118,9 @@ void PrintAllPythonOps(const std::vector& op_list, bool require_shapes, *pruned_ops.mutable_op()->Add() = op_def; } } - PrintEagerPythonOps(pruned_ops, {}, require_shapes); + PrintEagerPythonOps(pruned_ops, {}, require_shapes, source_file_name); } else { - PrintEagerPythonOps(ops, op_list, require_shapes); + PrintEagerPythonOps(ops, op_list, require_shapes, source_file_name); } } @@ -105,20 +130,26 @@ void PrintAllPythonOps(const std::vector& op_list, bool require_shapes, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::string source_file_name = + tensorflow::InferSourceFileName(argv[0]); + // Usage: // gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1] if (argc == 2) { - tensorflow::PrintAllPythonOps({}, {}, tensorflow::string(argv[1]) == "1"); + tensorflow::PrintAllPythonOps({}, source_file_name, + tensorflow::string(argv[1]) == "1", + false /* op_list_is_whitelist */); } else if (argc == 3) { std::vector hidden_ops; TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops)); - tensorflow::PrintAllPythonOps(hidden_ops, + tensorflow::PrintAllPythonOps(hidden_ops, source_file_name, tensorflow::string(argv[2]) == "1", false /* op_list_is_whitelist */); } else if (argc == 4) { std::vector op_list; TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list)); - tensorflow::PrintAllPythonOps(op_list, tensorflow::string(argv[2]) == "1", + tensorflow::PrintAllPythonOps(op_list, source_file_name, + tensorflow::string(argv[2]) == "1", tensorflow::string(argv[3]) == "1"); } else { return -1;