Add name of C++ source file to generated Python files for ops. (#11663)

* Embed orig source file name in generated Python

* Change pass-by-value --> by-ref to address review comment.

* Reformatting source code with clang-format.

* Change to different test for empty string per review comment.

* Moved code from python_op_gen to python_eager_op_gen.
This commit is contained in:
Fred Reiss 2017-09-05 09:34:23 -07:00 committed by Martin Wicke
parent 2306697c8a
commit 674db81731
3 changed files with 61 additions and 12 deletions

View File

@@ -659,14 +659,26 @@ void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
string GetEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
bool require_shapes) {
bool require_shapes,
const string& source_file_name = "") {
string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops.
strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
)");
// Mention the original source file so someone tracing back through generated
// Python code will know where to look next.
if (!source_file_name.empty()) {
strings::StrAppend(&result, "Original C++ source file: ");
strings::StrAppend(&result, source_file_name);
strings::StrAppend(&result, "\n");
}
strings::StrAppend(&result, R"("""
import collections as _collections
@@ -747,8 +759,11 @@ from tensorflow.python.framework import op_def_library as _op_def_library
void PrintEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
bool require_shapes) {
printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str());
bool require_shapes,
const string& source_file_name)
{
printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes,
source_file_name).c_str());
}
string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {

View File

@@ -24,9 +24,12 @@ namespace tensorflow {
// hidden_ops should be a list of Op names that should get a leading _
// in the output. Prints the output to stdout.
// Optional fourth argument is the name of the original C++ source file
// where the ops' REGISTER_OP() calls reside.
void PrintEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
bool require_shapes);
bool require_shapes,
const string& source_file_name = "");
// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@@ -80,7 +81,31 @@ Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
return Status::OK();
}
void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
// Use the name of the current executable to infer the C++ source file
// where the REGISTER_OP() call for the operator can be found.
// Returns the name of the file.
// Returns an empty string if the current executable's name does not
// follow a known pattern.
// Infer the C++ source file that holds the REGISTER_OP() calls for the
// ops being wrapped, based on the generator executable's own name.
// Bazel names the per-op-type generator binary
// "gen_<op type>_ops_py_wrappers_cc", which corresponds to the source
// file "<op type>_ops.cc". Returns that file name, or the empty string
// when the executable name does not follow the expected pattern.
string InferSourceFileName(const char* argv_zero) {
  StringPiece exec_name = io::Basename(argv_zero);
  const char* kExecPrefix = "gen_";
  const char* kExecSuffix = "_py_wrappers_cc";
  // Guard clause: bail out with "" unless the name matches
  // gen_<...>_py_wrappers_cc. Note Consume() strips the prefix in place
  // when it matches, exactly as the suffix handling below expects.
  if (!exec_name.Consume(kExecPrefix) || !exec_name.ends_with(kExecSuffix)) {
    return string("");
  }
  exec_name.remove_suffix(strlen(kExecSuffix));
  // What remains is "<op type>_ops"; append ".cc" to form the file name.
  return strings::StrCat(exec_name, ".cc");
}
void PrintAllPythonOps(const std::vector<string>& op_list,
const string& source_file_name,
bool require_shapes,
bool op_list_is_whitelist) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
@@ -93,9 +118,9 @@ void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
*pruned_ops.mutable_op()->Add() = op_def;
}
}
PrintEagerPythonOps(pruned_ops, {}, require_shapes);
PrintEagerPythonOps(pruned_ops, {}, require_shapes, source_file_name);
} else {
PrintEagerPythonOps(ops, op_list, require_shapes);
PrintEagerPythonOps(ops, op_list, require_shapes, source_file_name);
}
}
@@ -105,20 +130,26 @@ void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
tensorflow::string source_file_name =
tensorflow::InferSourceFileName(argv[0]);
// Usage:
// gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
if (argc == 2) {
tensorflow::PrintAllPythonOps({}, {}, tensorflow::string(argv[1]) == "1");
tensorflow::PrintAllPythonOps({}, source_file_name,
tensorflow::string(argv[1]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 3) {
std::vector<tensorflow::string> hidden_ops;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops));
tensorflow::PrintAllPythonOps(hidden_ops,
tensorflow::PrintAllPythonOps(hidden_ops, source_file_name,
tensorflow::string(argv[2]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 4) {
std::vector<tensorflow::string> op_list;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list));
tensorflow::PrintAllPythonOps(op_list, tensorflow::string(argv[2]) == "1",
tensorflow::PrintAllPythonOps(op_list, source_file_name,
tensorflow::string(argv[2]) == "1",
tensorflow::string(argv[3]) == "1");
} else {
return -1;