Refactoring Python op code generation.
PiperOrigin-RevId: 157675126
This commit is contained in:
parent
d9620cab82
commit
6db400bbcf
@ -530,6 +530,7 @@ set(tf_python_op_gen_main_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
|
||||
)
|
||||
|
||||
add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs})
|
||||
|
@ -318,7 +318,10 @@ py_test(
|
||||
cc_library(
|
||||
name = "python_op_gen",
|
||||
srcs = ["framework/python_op_gen.cc"],
|
||||
hdrs = ["framework/python_op_gen.h"],
|
||||
hdrs = [
|
||||
"framework/python_op_gen.h",
|
||||
"framework/python_op_gen_internal.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -36,9 +36,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/python/framework/python_op_gen_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
namespace python_op_gen_internal {
|
||||
|
||||
const int kRightMargin = 78;
|
||||
|
||||
@ -67,15 +68,11 @@ bool IsPythonReserved(const string& s) {
|
||||
"UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
|
||||
"UnicodeWarning", "UserWarning", "ValueError", "Warning",
|
||||
"ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
|
||||
"__package__",
|
||||
// Imports and symbols used in the generated code:
|
||||
"_text_format", "_op_def_pb2", "_common_shapes", "_op_def_registry",
|
||||
"_ops", "_op_def_library"});
|
||||
"__package__"});
|
||||
|
||||
return kPythonReserved->count(s) > 0;
|
||||
}
|
||||
|
||||
// Add a _ to the end of s if necessary to avoid a Python keyword or built-in.
|
||||
string AvoidPythonReserved(const string& s) {
|
||||
if (IsPythonReserved(s)) return strings::StrCat(s, "_");
|
||||
return s;
|
||||
@ -323,8 +320,8 @@ string StringToPython(const string& str) {
|
||||
return strings::StrCat("\"", str_util::CEscape(str), "\"");
|
||||
}
|
||||
|
||||
string DataTypeToPython(DataType dtype) {
|
||||
return strings::StrCat("tf.", PythonDataTypeString(dtype));
|
||||
string DataTypeToPython(DataType dtype, const string& dtype_module) {
|
||||
return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
|
||||
}
|
||||
|
||||
string ShapeToPython(const TensorShapeProto& shape) {
|
||||
@ -346,7 +343,8 @@ string TensorToPython(const TensorProto& proto) {
|
||||
return ProtoShortDebugString(proto);
|
||||
}
|
||||
|
||||
string AttrListToPython(const AttrValue& value) {
|
||||
string AttrListToPython(const AttrValue& value,
|
||||
const string& dtype_module = "tf.") {
|
||||
string ret;
|
||||
if (value.list().s_size() > 0) {
|
||||
for (int i = 0; i < value.list().s_size(); ++i) {
|
||||
@ -371,7 +369,8 @@ string AttrListToPython(const AttrValue& value) {
|
||||
} else if (value.list().type_size() > 0) {
|
||||
for (int i = 0; i < value.list().type_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret, DataTypeToPython(value.list().type(i)));
|
||||
strings::StrAppend(&ret,
|
||||
DataTypeToPython(value.list().type(i), dtype_module));
|
||||
}
|
||||
} else if (value.list().shape_size() > 0) {
|
||||
for (int i = 0; i < value.list().shape_size(); ++i) {
|
||||
@ -392,7 +391,8 @@ string AttrListToPython(const AttrValue& value) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
string AttrValueToPython(const string& type, const AttrValue& value) {
|
||||
string AttrValueToPython(const string& type, const AttrValue& value,
|
||||
const string& dtype_module) {
|
||||
if (type == "string") {
|
||||
return StringToPython(value.s());
|
||||
} else if (type == "int") {
|
||||
@ -402,7 +402,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) {
|
||||
} else if (type == "bool") {
|
||||
return value.b() ? "True" : "False";
|
||||
} else if (type == "type") {
|
||||
return DataTypeToPython(value.type());
|
||||
return DataTypeToPython(value.type(), dtype_module);
|
||||
} else if (type == "shape") {
|
||||
return ShapeToPython(value.shape());
|
||||
} else if (type == "tensor") {
|
||||
@ -410,7 +410,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) {
|
||||
} else if (type == "func") {
|
||||
return StringToPython(value.func().name());
|
||||
} else if (StringPiece(type).starts_with("list(")) {
|
||||
return strings::StrCat("[", AttrListToPython(value), "]");
|
||||
return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
|
||||
} else {
|
||||
return "?";
|
||||
}
|
||||
@ -432,35 +432,41 @@ void GenerateLowerCaseOpName(const string& str, string* result) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
static void AddDelimiter(string* append_to, const string& delim) {
|
||||
if (!append_to->empty()) strings::StrAppend(append_to, delim);
|
||||
}
|
||||
|
||||
string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) {
|
||||
string result;
|
||||
// Map from attr name to the first input arg it is inferred from.
|
||||
std::unordered_map<string, string> inferred_attrs;
|
||||
GenPythonOp::GenPythonOp(const OpDef& op_def, const string& function_name)
|
||||
: op_def_(op_def),
|
||||
function_name_(function_name),
|
||||
num_outs_(op_def.output_arg_size()) {}
|
||||
|
||||
GenPythonOp::~GenPythonOp() {}
|
||||
|
||||
string GenPythonOp::Code() {
|
||||
// This has all the input args followed by those attrs that don't have
|
||||
// defaults.
|
||||
std::vector<string> args_no_default;
|
||||
// The parameters with defaults (these have to be listed after those without).
|
||||
// No input args are included, just attrs.
|
||||
std::vector<string> args_with_defaults;
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def_.input_arg(i));
|
||||
args_no_default.push_back(arg.name());
|
||||
if (!arg.type_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
|
||||
gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
|
||||
} else if (!arg.type_list_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
|
||||
gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
|
||||
arg.name());
|
||||
}
|
||||
if (!arg.number_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
|
||||
gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
for (int i = 0; i < op_def_.attr_size(); ++i) {
|
||||
const auto& attr(op_def_.attr(i));
|
||||
// Do not add inferred attrs to the Python function signature.
|
||||
if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
|
||||
if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
|
||||
if (attr.has_default_value()) {
|
||||
args_with_defaults.push_back(attr.name());
|
||||
} else {
|
||||
@ -471,110 +477,92 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) {
|
||||
|
||||
// Save the list of attr parameters (attrs that won't be inferred),
|
||||
// those with defaults go at the end.
|
||||
std::vector<string> attrs;
|
||||
// Get the attrs in the order we want by taking the attrs without defaults
|
||||
// from the end of args_no_default, and adding args_no_default.
|
||||
attrs.reserve(args_no_default.size() - op_def.input_arg_size() +
|
||||
args_with_defaults.size());
|
||||
attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(),
|
||||
args_no_default.end());
|
||||
attrs.insert(attrs.end(), args_with_defaults.begin(),
|
||||
args_with_defaults.end());
|
||||
attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
|
||||
args_with_defaults.size());
|
||||
attrs_.insert(attrs_.end(),
|
||||
args_no_default.begin() + op_def_.input_arg_size(),
|
||||
args_no_default.end());
|
||||
attrs_.insert(attrs_.end(), args_with_defaults.begin(),
|
||||
args_with_defaults.end());
|
||||
|
||||
std::vector<string> param_names;
|
||||
param_names.reserve(args_no_default.size() + args_with_defaults.size());
|
||||
param_names_.reserve(args_no_default.size() + args_with_defaults.size());
|
||||
string parameters;
|
||||
for (const string& name : args_no_default) {
|
||||
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
|
||||
AddDelimiter(¶meters, ", ");
|
||||
const string param = AvoidPythonReserved(name);
|
||||
strings::StrAppend(¶meters, param);
|
||||
param_names.push_back(param);
|
||||
param_names_.push_back(param);
|
||||
}
|
||||
for (const string& name : args_with_defaults) {
|
||||
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
|
||||
AddDelimiter(¶meters, ", ");
|
||||
const string param = AvoidPythonReserved(name);
|
||||
strings::StrAppend(¶meters, param, "=None");
|
||||
param_names.push_back(param);
|
||||
param_names_.push_back(param);
|
||||
}
|
||||
AddDelimiter(¶meters, ", ");
|
||||
strings::StrAppend(¶meters, "name=None");
|
||||
|
||||
const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name);
|
||||
AddDefLine(parameters);
|
||||
AddDocStringDescription();
|
||||
AddDocStringArgs();
|
||||
AddDocStringInputs();
|
||||
AddDocStringAttrs();
|
||||
AddDocStringNameArg();
|
||||
AddOutputGlobals();
|
||||
AddDocStringOutputs();
|
||||
strings::StrAppend(&result_, " \"\"\"\n");
|
||||
AddBody(" ");
|
||||
strings::StrAppend(&result_, "\n\n");
|
||||
|
||||
const int num_outs = op_def.output_arg_size();
|
||||
// Prepare a NamedTuple type to hold the outputs, if there are multiple
|
||||
if (num_outs > 1) {
|
||||
// Prepare the list of output names
|
||||
std::vector<string> out_names(num_outs);
|
||||
for (int i = 0; i < num_outs; ++i) {
|
||||
if (!op_def.output_arg(i).name().empty()) {
|
||||
out_names[i] = op_def.output_arg(i).name();
|
||||
} else {
|
||||
out_names[i] = strings::StrCat("output", i);
|
||||
}
|
||||
}
|
||||
string out_names_list =
|
||||
strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
|
||||
return prelude_ + result_;
|
||||
}
|
||||
|
||||
// Provide the output names as a Python list
|
||||
string lower_op_name_outputs =
|
||||
strings::StrCat("_", lower_op_name, "_outputs");
|
||||
const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
|
||||
strings::StrAppend(&result, "\n",
|
||||
WordWrap(outputs_prefix, out_names_list, kRightMargin),
|
||||
"\n");
|
||||
void GenPythonOp::AddDefLine(const string& parameters) {
|
||||
const string def_prefix = strings::StrCat("def ", function_name_, "(");
|
||||
strings::StrAppend(
|
||||
&result_, WordWrap(def_prefix, parameters + "):", kRightMargin), "\n");
|
||||
}
|
||||
|
||||
strings::StrAppend(&result, "_", op_def.name(),
|
||||
"Output = _collections.namedtuple(\n");
|
||||
const string tuple_type_prefix = " ";
|
||||
const string tuple_type_suffix = strings::StrCat(
|
||||
"\"", op_def.name(), "\", ", lower_op_name_outputs, ")");
|
||||
strings::StrAppend(
|
||||
&result, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
|
||||
"\n\n");
|
||||
}
|
||||
strings::StrAppend(&result, "\n");
|
||||
|
||||
// Print: def Function(parameters):
|
||||
const string def_prefix = strings::StrCat("def ", lower_op_name, "(");
|
||||
const bool has_args = args_no_default.size() + args_with_defaults.size() > 0;
|
||||
const string def_suffix =
|
||||
strings::StrCat(parameters, has_args ? ", " : "", "name=None):");
|
||||
|
||||
strings::StrAppend(&result, WordWrap(def_prefix, def_suffix, kRightMargin),
|
||||
"\n");
|
||||
|
||||
// Format the Op's descriptions so that it can be a Python docstring.
|
||||
void GenPythonOp::AddDocStringDescription() {
|
||||
string comment;
|
||||
if (op_def.summary().empty()) {
|
||||
if (op_def_.summary().empty()) {
|
||||
comment = "TODO: add doc.\n";
|
||||
} else {
|
||||
comment = strings::StrCat(op_def.summary(), "\n");
|
||||
if (!op_def.description().empty()) {
|
||||
strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description()));
|
||||
comment = strings::StrCat(op_def_.summary(), "\n");
|
||||
if (!op_def_.description().empty()) {
|
||||
strings::StrAppend(&comment, "\n", Indent(2, 2, op_def_.description()));
|
||||
}
|
||||
}
|
||||
strings::StrAppend(&result_, " r\"\"\"", comment, "\n");
|
||||
}
|
||||
|
||||
strings::StrAppend(&result, " r\"\"\"", comment, "\n Args:\n");
|
||||
void GenPythonOp::AddDocStringArgs() {
|
||||
strings::StrAppend(&result_, " Args:\n");
|
||||
}
|
||||
|
||||
// Inputs
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
StringPiece description = op_def.input_arg(i).description();
|
||||
void GenPythonOp::AddDocStringInputs() {
|
||||
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def_.input_arg(i));
|
||||
StringPiece description = op_def_.input_arg(i).description();
|
||||
string desc;
|
||||
if (ConsumeEquals(&description)) { // Skip the generated type info.
|
||||
desc = strings::StrCat(param_names[i], ": ");
|
||||
desc = strings::StrCat(param_names_[i], ": ");
|
||||
} else {
|
||||
desc = strings::StrCat(param_names[i], ": ",
|
||||
ArgTypeName(op_def, arg, inferred_attrs, false));
|
||||
desc = strings::StrCat(param_names_[i], ": ",
|
||||
ArgTypeName(op_def_, arg, inferred_attrs_, false));
|
||||
}
|
||||
if (!description.empty()) {
|
||||
AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
|
||||
}
|
||||
strings::StrAppend(&result, Indent(4, 6, desc));
|
||||
strings::StrAppend(&result_, Indent(4, 6, desc));
|
||||
}
|
||||
}
|
||||
|
||||
// Attrs
|
||||
for (const string& name : attrs) {
|
||||
const auto& attr = *FindAttr(name, op_def);
|
||||
void GenPythonOp::AddDocStringAttrs() {
|
||||
for (const string& name : attrs_) {
|
||||
const auto& attr = *FindAttr(name, op_def_);
|
||||
string desc = strings::StrCat(AvoidPythonReserved(name), ": ");
|
||||
|
||||
static const char* const kAttrTypeName[][2] = {
|
||||
@ -638,40 +626,86 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) {
|
||||
AppendWithinWidth(&desc, attr.description(),
|
||||
kRightMargin - 4 /* indent */);
|
||||
}
|
||||
strings::StrAppend(&result, Indent(4, 6, desc));
|
||||
strings::StrAppend(&result_, Indent(4, 6, desc));
|
||||
}
|
||||
}
|
||||
|
||||
strings::StrAppend(&result,
|
||||
void GenPythonOp::AddDocStringNameArg() {
|
||||
strings::StrAppend(&result_,
|
||||
" name: A name for the operation (optional).\n");
|
||||
}
|
||||
|
||||
std::vector<string> output_type_string;
|
||||
output_type_string.reserve(num_outs);
|
||||
for (int i = 0; i < num_outs; ++i) {
|
||||
output_type_string.push_back(
|
||||
ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true));
|
||||
void GenPythonOp::AddOutputGlobals() {
|
||||
// Prepare a NamedTuple type to hold the outputs, if there are multiple
|
||||
if (num_outs_ > 1) {
|
||||
// Prepare the list of output names
|
||||
std::vector<string> out_names(num_outs_);
|
||||
for (int i = 0; i < num_outs_; ++i) {
|
||||
if (!op_def_.output_arg(i).name().empty()) {
|
||||
out_names[i] = op_def_.output_arg(i).name();
|
||||
} else {
|
||||
out_names[i] = strings::StrCat("output", i);
|
||||
}
|
||||
}
|
||||
string out_names_list =
|
||||
strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
|
||||
|
||||
// Provide the output names as a Python list
|
||||
string lower_op_name_outputs =
|
||||
strings::StrCat("_", function_name_, "_outputs");
|
||||
const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
|
||||
strings::StrAppend(&prelude_, "\n",
|
||||
WordWrap(outputs_prefix, out_names_list, kRightMargin),
|
||||
"\n");
|
||||
|
||||
strings::StrAppend(&prelude_, "_", op_def_.name(),
|
||||
"Output = _collections.namedtuple(\n");
|
||||
const string tuple_type_prefix = " ";
|
||||
const string tuple_type_suffix = strings::StrCat(
|
||||
"\"", op_def_.name(), "\", ", lower_op_name_outputs, ")");
|
||||
strings::StrAppend(
|
||||
&prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
|
||||
"\n\n");
|
||||
}
|
||||
strings::StrAppend(&result, GetReturns(op_def, output_type_string));
|
||||
strings::StrAppend(&prelude_, "\n");
|
||||
}
|
||||
|
||||
string return_prefix = strings::StrCat(" result = _op_def_lib.apply_op(");
|
||||
string return_args = strings::StrCat("\"", op_def.name(), "\", ");
|
||||
for (size_t i = 0; i < param_names.size(); ++i) {
|
||||
strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", ");
|
||||
void GenPythonOp::AddDocStringOutputs() {
|
||||
std::vector<string> output_type_string;
|
||||
output_type_string.reserve(num_outs_);
|
||||
for (int i = 0; i < num_outs_; ++i) {
|
||||
output_type_string.push_back(
|
||||
ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
|
||||
}
|
||||
strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
|
||||
}
|
||||
|
||||
void GenPythonOp::AddBody(const string& prefix) {
|
||||
string return_prefix =
|
||||
strings::StrCat(prefix, "result = _op_def_lib.apply_op(");
|
||||
string return_args = strings::StrCat("\"", op_def_.name(), "\", ");
|
||||
for (size_t i = 0; i < param_names_.size(); ++i) {
|
||||
strings::StrAppend(&return_args, param_names_[i], "=", param_names_[i],
|
||||
", ");
|
||||
}
|
||||
strings::StrAppend(&return_args, "name=name)");
|
||||
|
||||
strings::StrAppend(&result, " \"\"\"\n",
|
||||
strings::StrAppend(&result_,
|
||||
// Wrap the arguments, and indent to the (.
|
||||
WordWrap(return_prefix, return_args, kRightMargin), "\n");
|
||||
|
||||
if (num_outs <= 1) {
|
||||
strings::StrAppend(&result, " return result\n");
|
||||
if (num_outs_ <= 1) {
|
||||
strings::StrAppend(&result_, prefix, "return result\n");
|
||||
} else {
|
||||
strings::StrAppend(&result, " return _", op_def.name(),
|
||||
strings::StrAppend(&result_, prefix, "return _", op_def_.name(),
|
||||
"Output._make(result)\n");
|
||||
}
|
||||
strings::StrAppend(&result, "\n\n");
|
||||
}
|
||||
|
||||
return result;
|
||||
} // namespace python_op_gen_internal
|
||||
|
||||
string GetPythonOp(const OpDef& op_def, const string& function_name) {
|
||||
return python_op_gen_internal::GenPythonOp(op_def, function_name).Code();
|
||||
}
|
||||
|
||||
string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
|
||||
@ -711,20 +745,20 @@ from tensorflow.python.framework import op_def_library as _op_def_library
|
||||
}
|
||||
}
|
||||
|
||||
// PrintPythonOp(op_def, is_hidden, op_def.name());
|
||||
string lower_case_name;
|
||||
GenerateLowerCaseOpName(op_def.name(), &lower_case_name);
|
||||
string function_name;
|
||||
python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
|
||||
&function_name);
|
||||
if (is_hidden) function_name = strings::StrCat("_", function_name);
|
||||
|
||||
// When users create custom python wrappers, they may link in the
|
||||
// default op registry by accident, and because they can't
|
||||
// enumerate all 'hidden' symbols, this guard is to prevent
|
||||
// instantiating a python reserved word in their wrapper.
|
||||
if (!is_hidden && IsPythonReserved(lower_case_name)) {
|
||||
if (python_op_gen_internal::IsPythonReserved(function_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
strings::StrAppend(&result,
|
||||
GetPythonOp(op_def, is_hidden, lower_case_name));
|
||||
strings::StrAppend(&result, GetPythonOp(op_def, function_name));
|
||||
|
||||
if (!require_shapes) {
|
||||
strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
|
||||
|
@ -31,7 +31,7 @@ void PrintPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
|
||||
bool require_shapes);
|
||||
string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
|
||||
bool require_shapes);
|
||||
string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name);
|
||||
string GetPythonOp(const OpDef& op_def, const string& function_name);
|
||||
|
||||
// Get the python wrappers for a list of ops in a OpList.
|
||||
// `op_list_buf` should be a pointer to a buffer containing
|
||||
|
86
tensorflow/python/framework/python_op_gen_internal.h
Normal file
86
tensorflow/python/framework/python_op_gen_internal.h
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace python_op_gen_internal {
|
||||
|
||||
// Returns true if s is a Python keyword or built-in.
|
||||
bool IsPythonReserved(const string& s);
|
||||
|
||||
// Add a _ to the end of s if necessary to avoid a Python keyword or built-in.
|
||||
string AvoidPythonReserved(const string& s);
|
||||
|
||||
// Convert an AttrValue with type `type` to the Python representation for
|
||||
// that value.
|
||||
string AttrValueToPython(const string& type, const AttrValue& value,
|
||||
const string& dtype_module = "tf.");
|
||||
|
||||
void GenerateLowerCaseOpName(const string& str, string* result);
|
||||
|
||||
class GenPythonOp {
|
||||
public:
|
||||
GenPythonOp(const OpDef& op_def, const string& function_name);
|
||||
virtual ~GenPythonOp();
|
||||
|
||||
virtual string Code();
|
||||
|
||||
protected:
|
||||
// Print: def Function(parameters):
|
||||
void AddDefLine(const string& parameters);
|
||||
|
||||
// Format the Op's descriptions so that it can be a Python docstring.
|
||||
void AddDocStringDescription();
|
||||
|
||||
void AddDocStringArgs();
|
||||
void AddDocStringInputs();
|
||||
void AddDocStringAttrs();
|
||||
void AddDocStringNameArg();
|
||||
void AddOutputGlobals();
|
||||
void AddDocStringOutputs();
|
||||
void AddBody(const string& prefix);
|
||||
|
||||
// From constructor arguments
|
||||
const OpDef& op_def_;
|
||||
const string& function_name_;
|
||||
const int num_outs_;
|
||||
|
||||
// Return value from Code() is prelude_ + result_.
|
||||
string prelude_; // Code before function definition
|
||||
string result_; // Function definition
|
||||
|
||||
// Map from attr name to the first input arg it is inferred from
|
||||
std::unordered_map<string, string> inferred_attrs_;
|
||||
|
||||
// The names of the non-inferred attrs, in parameter order
|
||||
std::vector<string> attrs_;
|
||||
|
||||
// All parameters, including inputs & non-inferred attrs, required and those
|
||||
// with defaults, except "name"
|
||||
std::vector<string> param_names_;
|
||||
};
|
||||
|
||||
} // namespace python_op_gen_internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
|
Loading…
Reference in New Issue
Block a user