Refactoring Python op code generation.

PiperOrigin-RevId: 157675126
This commit is contained in:
A. Unique TensorFlower 2017-05-31 21:55:19 -07:00 committed by TensorFlower Gardener
parent d9620cab82
commit 6db400bbcf
5 changed files with 248 additions and 124 deletions

View File

@ -530,6 +530,7 @@ set(tf_python_op_gen_main_srcs
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
)
add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs})

View File

@ -318,7 +318,10 @@ py_test(
cc_library(
name = "python_op_gen",
srcs = ["framework/python_op_gen.cc"],
hdrs = ["framework/python_op_gen.h"],
hdrs = [
"framework/python_op_gen.h",
"framework/python_op_gen_internal.h",
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",

View File

@ -36,9 +36,10 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"
namespace tensorflow {
namespace {
namespace python_op_gen_internal {
const int kRightMargin = 78;
@ -67,15 +68,11 @@ bool IsPythonReserved(const string& s) {
"UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
"UnicodeWarning", "UserWarning", "ValueError", "Warning",
"ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
"__package__",
// Imports and symbols used in the generated code:
"_text_format", "_op_def_pb2", "_common_shapes", "_op_def_registry",
"_ops", "_op_def_library"});
"__package__"});
return kPythonReserved->count(s) > 0;
}
// Add a _ to the end of s if necessary to avoid a Python keyword or built-in.
string AvoidPythonReserved(const string& s) {
if (IsPythonReserved(s)) return strings::StrCat(s, "_");
return s;
@ -323,8 +320,8 @@ string StringToPython(const string& str) {
return strings::StrCat("\"", str_util::CEscape(str), "\"");
}
string DataTypeToPython(DataType dtype) {
return strings::StrCat("tf.", PythonDataTypeString(dtype));
string DataTypeToPython(DataType dtype, const string& dtype_module) {
return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
}
string ShapeToPython(const TensorShapeProto& shape) {
@ -346,7 +343,8 @@ string TensorToPython(const TensorProto& proto) {
return ProtoShortDebugString(proto);
}
string AttrListToPython(const AttrValue& value) {
string AttrListToPython(const AttrValue& value,
const string& dtype_module = "tf.") {
string ret;
if (value.list().s_size() > 0) {
for (int i = 0; i < value.list().s_size(); ++i) {
@ -371,7 +369,8 @@ string AttrListToPython(const AttrValue& value) {
} else if (value.list().type_size() > 0) {
for (int i = 0; i < value.list().type_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret, DataTypeToPython(value.list().type(i)));
strings::StrAppend(&ret,
DataTypeToPython(value.list().type(i), dtype_module));
}
} else if (value.list().shape_size() > 0) {
for (int i = 0; i < value.list().shape_size(); ++i) {
@ -392,7 +391,8 @@ string AttrListToPython(const AttrValue& value) {
return ret;
}
string AttrValueToPython(const string& type, const AttrValue& value) {
string AttrValueToPython(const string& type, const AttrValue& value,
const string& dtype_module) {
if (type == "string") {
return StringToPython(value.s());
} else if (type == "int") {
@ -402,7 +402,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) {
} else if (type == "bool") {
return value.b() ? "True" : "False";
} else if (type == "type") {
return DataTypeToPython(value.type());
return DataTypeToPython(value.type(), dtype_module);
} else if (type == "shape") {
return ShapeToPython(value.shape());
} else if (type == "tensor") {
@ -410,7 +410,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) {
} else if (type == "func") {
return StringToPython(value.func().name());
} else if (StringPiece(type).starts_with("list(")) {
return strings::StrCat("[", AttrListToPython(value), "]");
return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
} else {
return "?";
}
@ -432,35 +432,41 @@ void GenerateLowerCaseOpName(const string& str, string* result) {
}
}
} // namespace
static void AddDelimiter(string* append_to, const string& delim) {
if (!append_to->empty()) strings::StrAppend(append_to, delim);
}
string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) {
string result;
// Map from attr name to the first input arg it is inferred from.
std::unordered_map<string, string> inferred_attrs;
GenPythonOp::GenPythonOp(const OpDef& op_def, const string& function_name)
: op_def_(op_def),
function_name_(function_name),
num_outs_(op_def.output_arg_size()) {}
GenPythonOp::~GenPythonOp() {}
string GenPythonOp::Code() {
// This has all the input args followed by those attrs that don't have
// defaults.
std::vector<string> args_no_default;
// The parameters with defaults (these have to be listed after those without).
// No input args are included, just attrs.
std::vector<string> args_with_defaults;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg(op_def_.input_arg(i));
args_no_default.push_back(arg.name());
if (!arg.type_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
} else if (!arg.type_list_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
arg.name());
}
if (!arg.number_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
}
}
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
// Do not add inferred attrs to the Python function signature.
if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
if (attr.has_default_value()) {
args_with_defaults.push_back(attr.name());
} else {
@ -471,110 +477,92 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) {
// Save the list of attr parameters (attrs that won't be inferred),
// those with defaults go at the end.
std::vector<string> attrs;
// Get the attrs in the order we want by taking the attrs without defaults
// from the end of args_no_default, and adding args_no_default.
attrs.reserve(args_no_default.size() - op_def.input_arg_size() +
args_with_defaults.size());
attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(),
args_no_default.end());
attrs.insert(attrs.end(), args_with_defaults.begin(),
args_with_defaults.end());
attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
args_with_defaults.size());
attrs_.insert(attrs_.end(),
args_no_default.begin() + op_def_.input_arg_size(),
args_no_default.end());
attrs_.insert(attrs_.end(), args_with_defaults.begin(),
args_with_defaults.end());
std::vector<string> param_names;
param_names.reserve(args_no_default.size() + args_with_defaults.size());
param_names_.reserve(args_no_default.size() + args_with_defaults.size());
string parameters;
for (const string& name : args_no_default) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
AddDelimiter(&parameters, ", ");
const string param = AvoidPythonReserved(name);
strings::StrAppend(&parameters, param);
param_names.push_back(param);
param_names_.push_back(param);
}
for (const string& name : args_with_defaults) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
AddDelimiter(&parameters, ", ");
const string param = AvoidPythonReserved(name);
strings::StrAppend(&parameters, param, "=None");
param_names.push_back(param);
param_names_.push_back(param);
}
AddDelimiter(&parameters, ", ");
strings::StrAppend(&parameters, "name=None");
const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name);
AddDefLine(parameters);
AddDocStringDescription();
AddDocStringArgs();
AddDocStringInputs();
AddDocStringAttrs();
AddDocStringNameArg();
AddOutputGlobals();
AddDocStringOutputs();
strings::StrAppend(&result_, " \"\"\"\n");
AddBody(" ");
strings::StrAppend(&result_, "\n\n");
const int num_outs = op_def.output_arg_size();
// Prepare a NamedTuple type to hold the outputs, if there are multiple
if (num_outs > 1) {
// Prepare the list of output names
std::vector<string> out_names(num_outs);
for (int i = 0; i < num_outs; ++i) {
if (!op_def.output_arg(i).name().empty()) {
out_names[i] = op_def.output_arg(i).name();
} else {
out_names[i] = strings::StrCat("output", i);
}
}
string out_names_list =
strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
return prelude_ + result_;
}
// Provide the output names as a Python list
string lower_op_name_outputs =
strings::StrCat("_", lower_op_name, "_outputs");
const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
strings::StrAppend(&result, "\n",
WordWrap(outputs_prefix, out_names_list, kRightMargin),
"\n");
void GenPythonOp::AddDefLine(const string& parameters) {
const string def_prefix = strings::StrCat("def ", function_name_, "(");
strings::StrAppend(
&result_, WordWrap(def_prefix, parameters + "):", kRightMargin), "\n");
}
strings::StrAppend(&result, "_", op_def.name(),
"Output = _collections.namedtuple(\n");
const string tuple_type_prefix = " ";
const string tuple_type_suffix = strings::StrCat(
"\"", op_def.name(), "\", ", lower_op_name_outputs, ")");
strings::StrAppend(
&result, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
"\n\n");
}
strings::StrAppend(&result, "\n");
// Print: def Function(parameters):
const string def_prefix = strings::StrCat("def ", lower_op_name, "(");
const bool has_args = args_no_default.size() + args_with_defaults.size() > 0;
const string def_suffix =
strings::StrCat(parameters, has_args ? ", " : "", "name=None):");
strings::StrAppend(&result, WordWrap(def_prefix, def_suffix, kRightMargin),
"\n");
// Format the Op's descriptions so that it can be a Python docstring.
void GenPythonOp::AddDocStringDescription() {
string comment;
if (op_def.summary().empty()) {
if (op_def_.summary().empty()) {
comment = "TODO: add doc.\n";
} else {
comment = strings::StrCat(op_def.summary(), "\n");
if (!op_def.description().empty()) {
strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description()));
comment = strings::StrCat(op_def_.summary(), "\n");
if (!op_def_.description().empty()) {
strings::StrAppend(&comment, "\n", Indent(2, 2, op_def_.description()));
}
}
strings::StrAppend(&result_, " r\"\"\"", comment, "\n");
}
strings::StrAppend(&result, " r\"\"\"", comment, "\n Args:\n");
void GenPythonOp::AddDocStringArgs() {
strings::StrAppend(&result_, " Args:\n");
}
// Inputs
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
StringPiece description = op_def.input_arg(i).description();
void GenPythonOp::AddDocStringInputs() {
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg(op_def_.input_arg(i));
StringPiece description = op_def_.input_arg(i).description();
string desc;
if (ConsumeEquals(&description)) { // Skip the generated type info.
desc = strings::StrCat(param_names[i], ": ");
desc = strings::StrCat(param_names_[i], ": ");
} else {
desc = strings::StrCat(param_names[i], ": ",
ArgTypeName(op_def, arg, inferred_attrs, false));
desc = strings::StrCat(param_names_[i], ": ",
ArgTypeName(op_def_, arg, inferred_attrs_, false));
}
if (!description.empty()) {
AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
}
strings::StrAppend(&result, Indent(4, 6, desc));
strings::StrAppend(&result_, Indent(4, 6, desc));
}
}
// Attrs
for (const string& name : attrs) {
const auto& attr = *FindAttr(name, op_def);
void GenPythonOp::AddDocStringAttrs() {
for (const string& name : attrs_) {
const auto& attr = *FindAttr(name, op_def_);
string desc = strings::StrCat(AvoidPythonReserved(name), ": ");
static const char* const kAttrTypeName[][2] = {
@ -638,40 +626,86 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) {
AppendWithinWidth(&desc, attr.description(),
kRightMargin - 4 /* indent */);
}
strings::StrAppend(&result, Indent(4, 6, desc));
strings::StrAppend(&result_, Indent(4, 6, desc));
}
}
strings::StrAppend(&result,
void GenPythonOp::AddDocStringNameArg() {
strings::StrAppend(&result_,
" name: A name for the operation (optional).\n");
}
std::vector<string> output_type_string;
output_type_string.reserve(num_outs);
for (int i = 0; i < num_outs; ++i) {
output_type_string.push_back(
ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true));
void GenPythonOp::AddOutputGlobals() {
// Prepare a NamedTuple type to hold the outputs, if there are multiple
if (num_outs_ > 1) {
// Prepare the list of output names
std::vector<string> out_names(num_outs_);
for (int i = 0; i < num_outs_; ++i) {
if (!op_def_.output_arg(i).name().empty()) {
out_names[i] = op_def_.output_arg(i).name();
} else {
out_names[i] = strings::StrCat("output", i);
}
}
string out_names_list =
strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
// Provide the output names as a Python list
string lower_op_name_outputs =
strings::StrCat("_", function_name_, "_outputs");
const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
strings::StrAppend(&prelude_, "\n",
WordWrap(outputs_prefix, out_names_list, kRightMargin),
"\n");
strings::StrAppend(&prelude_, "_", op_def_.name(),
"Output = _collections.namedtuple(\n");
const string tuple_type_prefix = " ";
const string tuple_type_suffix = strings::StrCat(
"\"", op_def_.name(), "\", ", lower_op_name_outputs, ")");
strings::StrAppend(
&prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
"\n\n");
}
strings::StrAppend(&result, GetReturns(op_def, output_type_string));
strings::StrAppend(&prelude_, "\n");
}
string return_prefix = strings::StrCat(" result = _op_def_lib.apply_op(");
string return_args = strings::StrCat("\"", op_def.name(), "\", ");
for (size_t i = 0; i < param_names.size(); ++i) {
strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", ");
void GenPythonOp::AddDocStringOutputs() {
std::vector<string> output_type_string;
output_type_string.reserve(num_outs_);
for (int i = 0; i < num_outs_; ++i) {
output_type_string.push_back(
ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
}
strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
}
void GenPythonOp::AddBody(const string& prefix) {
string return_prefix =
strings::StrCat(prefix, "result = _op_def_lib.apply_op(");
string return_args = strings::StrCat("\"", op_def_.name(), "\", ");
for (size_t i = 0; i < param_names_.size(); ++i) {
strings::StrAppend(&return_args, param_names_[i], "=", param_names_[i],
", ");
}
strings::StrAppend(&return_args, "name=name)");
strings::StrAppend(&result, " \"\"\"\n",
strings::StrAppend(&result_,
// Wrap the arguments, and indent to the (.
WordWrap(return_prefix, return_args, kRightMargin), "\n");
if (num_outs <= 1) {
strings::StrAppend(&result, " return result\n");
if (num_outs_ <= 1) {
strings::StrAppend(&result_, prefix, "return result\n");
} else {
strings::StrAppend(&result, " return _", op_def.name(),
strings::StrAppend(&result_, prefix, "return _", op_def_.name(),
"Output._make(result)\n");
}
strings::StrAppend(&result, "\n\n");
}
return result;
} // namespace python_op_gen_internal
string GetPythonOp(const OpDef& op_def, const string& function_name) {
return python_op_gen_internal::GenPythonOp(op_def, function_name).Code();
}
string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
@ -711,20 +745,20 @@ from tensorflow.python.framework import op_def_library as _op_def_library
}
}
// PrintPythonOp(op_def, is_hidden, op_def.name());
string lower_case_name;
GenerateLowerCaseOpName(op_def.name(), &lower_case_name);
string function_name;
python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
&function_name);
if (is_hidden) function_name = strings::StrCat("_", function_name);
// When users create custom python wrappers, they may link in the
// default op registry by accident, and because they can't
// enumerate all 'hidden' symbols, this guard is to prevent
// instantiating a python reserved word in their wrapper.
if (!is_hidden && IsPythonReserved(lower_case_name)) {
if (python_op_gen_internal::IsPythonReserved(function_name)) {
continue;
}
strings::StrAppend(&result,
GetPythonOp(op_def, is_hidden, lower_case_name));
strings::StrAppend(&result, GetPythonOp(op_def, function_name));
if (!require_shapes) {
strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),

View File

@ -31,7 +31,7 @@ void PrintPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
bool require_shapes);
string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
bool require_shapes);
string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name);
string GetPythonOp(const OpDef& op_def, const string& function_name);
// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing

View File

@ -0,0 +1,86 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace python_op_gen_internal {
// Returns true if s is a Python keyword or built-in.
bool IsPythonReserved(const string& s);
// Add a _ to the end of s if necessary to avoid a Python keyword or built-in.
string AvoidPythonReserved(const string& s);
// Convert an AttrValue with type `type` to the Python representation for
// that value.
string AttrValueToPython(const string& type, const AttrValue& value,
const string& dtype_module = "tf.");
void GenerateLowerCaseOpName(const string& str, string* result);
class GenPythonOp {
public:
GenPythonOp(const OpDef& op_def, const string& function_name);
virtual ~GenPythonOp();
virtual string Code();
protected:
// Print: def Function(parameters):
void AddDefLine(const string& parameters);
// Format the Op's descriptions so that it can be a Python docstring.
void AddDocStringDescription();
void AddDocStringArgs();
void AddDocStringInputs();
void AddDocStringAttrs();
void AddDocStringNameArg();
void AddOutputGlobals();
void AddDocStringOutputs();
void AddBody(const string& prefix);
// From constructor arguments
const OpDef& op_def_;
const string& function_name_;
const int num_outs_;
// Return value from Code() is prelude_ + result_.
string prelude_; // Code before function definition
string result_; // Function definition
// Map from attr name to the first input arg it is inferred from
std::unordered_map<string, string> inferred_attrs_;
// The names of the non-inferred attrs, in parameter order
std::vector<string> attrs_;
// All parameters, including inputs & non-inferred attrs, required and those
// with defaults, except "name"
std::vector<string> param_names_;
};
} // namespace python_op_gen_internal
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_