Add name of C++ source file to generated Python files for ops. (#11663)
* Embed orig source file name in generated Python * Change pass-by-value --> by-ref to address review comment. * Reformatting source code with clang-format. * Change to different test for empty string per review comment. * Moved code from python_op_gen to python_eager_op_gen.
This commit is contained in:
parent
2306697c8a
commit
674db81731
@ -659,14 +659,26 @@ void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
|
||||
|
||||
string GetEagerPythonOps(const OpList& ops,
|
||||
const std::vector<string>& hidden_ops,
|
||||
bool require_shapes) {
|
||||
bool require_shapes,
|
||||
const string& source_file_name = "") {
|
||||
|
||||
string result;
|
||||
// Header
|
||||
// TODO(josh11b): Mention the library for which wrappers are being generated.
|
||||
strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops.
|
||||
strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
|
||||
|
||||
This file is MACHINE GENERATED! Do not edit.
|
||||
"""
|
||||
)");
|
||||
|
||||
// Mention the original source file so someone tracing back through generated
|
||||
// Python code will know where to look next.
|
||||
if (!source_file_name.empty()) {
|
||||
strings::StrAppend(&result, "Original C++ source file: ");
|
||||
strings::StrAppend(&result, source_file_name);
|
||||
strings::StrAppend(&result, "\n");
|
||||
}
|
||||
|
||||
strings::StrAppend(&result, R"("""
|
||||
|
||||
import collections as _collections
|
||||
|
||||
@ -747,8 +759,11 @@ from tensorflow.python.framework import op_def_library as _op_def_library
|
||||
|
||||
void PrintEagerPythonOps(const OpList& ops,
|
||||
const std::vector<string>& hidden_ops,
|
||||
bool require_shapes) {
|
||||
printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str());
|
||||
bool require_shapes,
|
||||
const string& source_file_name)
|
||||
{
|
||||
printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes,
|
||||
source_file_name).c_str());
|
||||
}
|
||||
|
||||
string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
|
@ -24,9 +24,12 @@ namespace tensorflow {
|
||||
|
||||
// hidden_ops should be a list of Op names that should get a leading _
|
||||
// in the output. Prints the output to stdout.
|
||||
// Optional fourth argument is the name of the original C++ source file
|
||||
// where the ops' REGISTER_OP() calls reside.
|
||||
void PrintEagerPythonOps(const OpList& ops,
|
||||
const std::vector<string>& hidden_ops,
|
||||
bool require_shapes);
|
||||
bool require_shapes,
|
||||
const string& source_file_name = "");
|
||||
|
||||
// Get the python wrappers for a list of ops in a OpList.
|
||||
// `op_list_buf` should be a pointer to a buffer containing
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/inputbuffer.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
@ -80,7 +81,31 @@ Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
|
||||
|
||||
// Use the name of the current executable to infer the C++ source file
|
||||
// where the REGISTER_OP() call for the operator can be found.
|
||||
// Returns the name of the file.
|
||||
// Returns an empty string if the current executable's name does not
|
||||
// follow a known pattern.
|
||||
string InferSourceFileName(const char* argv_zero) {
|
||||
StringPiece command_str = io::Basename(argv_zero);
|
||||
|
||||
// For built-in ops, the Bazel build creates a separate executable
|
||||
// with the name gen_<op type>_ops_py_wrappers_cc containing the
|
||||
// operators defined in <op type>_ops.cc
|
||||
const char* kExecPrefix = "gen_";
|
||||
const char* kExecSuffix = "_py_wrappers_cc";
|
||||
if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) {
|
||||
command_str.remove_suffix(strlen(kExecSuffix));
|
||||
return strings::StrCat(command_str, ".cc");
|
||||
} else {
|
||||
return string("");
|
||||
}
|
||||
}
|
||||
|
||||
void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
const string& source_file_name,
|
||||
bool require_shapes,
|
||||
bool op_list_is_whitelist) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(false, &ops);
|
||||
@ -93,9 +118,9 @@ void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
|
||||
*pruned_ops.mutable_op()->Add() = op_def;
|
||||
}
|
||||
}
|
||||
PrintEagerPythonOps(pruned_ops, {}, require_shapes);
|
||||
PrintEagerPythonOps(pruned_ops, {}, require_shapes, source_file_name);
|
||||
} else {
|
||||
PrintEagerPythonOps(ops, op_list, require_shapes);
|
||||
PrintEagerPythonOps(ops, op_list, require_shapes, source_file_name);
|
||||
}
|
||||
}
|
||||
|
||||
@ -105,20 +130,26 @@ void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
|
||||
int main(int argc, char* argv[]) {
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
|
||||
tensorflow::string source_file_name =
|
||||
tensorflow::InferSourceFileName(argv[0]);
|
||||
|
||||
// Usage:
|
||||
// gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
|
||||
if (argc == 2) {
|
||||
tensorflow::PrintAllPythonOps({}, {}, tensorflow::string(argv[1]) == "1");
|
||||
tensorflow::PrintAllPythonOps({}, source_file_name,
|
||||
tensorflow::string(argv[1]) == "1",
|
||||
false /* op_list_is_whitelist */);
|
||||
} else if (argc == 3) {
|
||||
std::vector<tensorflow::string> hidden_ops;
|
||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops));
|
||||
tensorflow::PrintAllPythonOps(hidden_ops,
|
||||
tensorflow::PrintAllPythonOps(hidden_ops, source_file_name,
|
||||
tensorflow::string(argv[2]) == "1",
|
||||
false /* op_list_is_whitelist */);
|
||||
} else if (argc == 4) {
|
||||
std::vector<tensorflow::string> op_list;
|
||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list));
|
||||
tensorflow::PrintAllPythonOps(op_list, tensorflow::string(argv[2]) == "1",
|
||||
tensorflow::PrintAllPythonOps(op_list, source_file_name,
|
||||
tensorflow::string(argv[2]) == "1",
|
||||
tensorflow::string(argv[3]) == "1");
|
||||
} else {
|
||||
return -1;
|
||||
|
Loading…
Reference in New Issue
Block a user