Actually use ApiDef when generating Python API.

PiperOrigin-RevId: 177851421
Anna R 2017-12-04 12:31:03 -08:00 committed by TensorFlower Gardener
parent a1c29139cc
commit 8f1e63d562
13 changed files with 302 additions and 187 deletions

View File

@@ -739,7 +739,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
# containing the wrappers.
add_custom_command(
OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION}
COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION}
COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION}
DEPENDS ${tf_python_op_lib_name}_gen_python
)

View File

@@ -3416,7 +3416,7 @@ filegroup(
filegroup(
name = "python_api_def",
data = glob(["api_def/python_api/*"]),
srcs = glob(["api_def/python_api/*"]),
)
tf_cc_test(

View File

@@ -629,14 +629,11 @@ Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
ApiDefs api_defs;
protobuf::TextFormat::ParseFromString(contents, &api_defs);
for (const auto& api_def : api_defs.op()) {
// Check if the op definition is already loaded.
// Check if the op definition is loaded. If the op definition is not
// loaded, we just skip this ApiDef.
if (map_.find(api_def.graph_op_name()) != map_.end()) {
// Overwrite current api def with data in api_def.
TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
} else {
return errors::FailedPrecondition(
"Unexpected ApiDef override: ", api_def.graph_op_name(),
" is not defined in base ApiDef.");
}
}
return Status::OK();
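
The behavioral change in this hunk: an ApiDef override whose op is not registered no longer fails with FailedPrecondition; known ops are merged, unknown ops are skipped. A minimal Python sketch of the new control flow, using plain dicts and dict.update() as a stand-in for MergeApiDefs (illustration only, not the C++ implementation):

def load_api_defs(api_def_map, parsed_api_defs):
  # api_def_map: dict of graph_op_name -> ApiDef-like dict for registered ops.
  for api_def in parsed_api_defs:
    name = api_def["graph_op_name"]
    if name in api_def_map:
      # Op is registered: overlay the override onto the existing base ApiDef.
      api_def_map[name].update(api_def)
    # else: the op is unknown in this binary, so the override is skipped
    # instead of returning an error.
  return api_def_map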

View File

@@ -410,8 +410,8 @@ op {
ApiDefMap api_map(op_list);
TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
auto status = api_map.LoadApiDef(api_def1);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
TF_CHECK_OK(api_map.LoadApiDef(api_def1));
ASSERT_EQ(nullptr, api_map.GetApiDef("different_testop"));
}
TEST(OpGenLibTest, ApiDefInvalidArgOrder) {

View File

@@ -27,4 +27,8 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[],
deps=deps,
require_shape_functions=require_shape_functions,
generated_target_name=name,
api_def_srcs = [
"//tensorflow/core:base_api_def",
"//tensorflow/core:python_api_def",
],
)

View File

@@ -10,7 +10,9 @@ def tfe_gen_op_wrapper_py(name,
out=None,
visibility=None,
deps=[],
generated_target_name=None):
generated_target_name=None,
# ApiDefs will be loaded in the order specified in this list.
api_def_srcs=[]):
"""Generate an eager-mode Python op wrapper for an op library."""
# Construct a cc_binary containing the specified ops.
tool_name = "gen_" + name + "_py_wrappers_cc"
@@ -30,11 +32,25 @@ def tfe_gen_op_wrapper_py(name,
if not out:
out = "gen_" + name + ".py"
if not api_def_srcs:
api_def_args_str = ","
else:
api_def_args = []
for api_def_src in api_def_srcs:
# Add directory of the first ApiDef source to args.
# We are assuming all ApiDefs in a single api_def_src are in the
# same directory.
api_def_args.append(
"$$(dirname $$(echo $(locations " + api_def_src +
") | cut -d\" \" -f1))")
api_def_args_str = ",".join(api_def_args)
native.genrule(
name=name + "_pygenrule",
outs=[out],
srcs=api_def_srcs,
tools=[tool_name] + tf_binary_additional_srcs(),
cmd=("$(location " + tool_name + ") > $@"))
cmd=("$(location " + tool_name + ") " + api_def_args_str + " > $@"))
# Make a py_library out of the generated python file.
if not generated_target_name:
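
For reference, the generator receives these directories as one comma-separated positional argument, so a bare "," is passed when there are no api_def_srcs: the argument stays present but contributes no directories, because empty entries are dropped on the other side. A small Python sketch of that parsing (assumed equivalent of str_util::Split with SkipEmpty):

def parse_api_def_dirs(arg):
  # Split on commas and drop empty entries, mirroring SkipEmpty.
  return [d for d in arg.split(",") if d]

assert parse_api_def_dirs(",") == []
assert parse_api_def_dirs(
    "tensorflow/core/api_def/base_api,tensorflow/core/api_def/python_api") == [
        "tensorflow/core/api_def/base_api",
        "tensorflow/core/api_def/python_api"]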

View File

@@ -99,6 +99,15 @@ string TensorPBString(const TensorProto& pb) {
return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}
const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
for (int i = 0; i < api_def.in_arg_size(); ++i) {
if (api_def.in_arg(i).name() == name) {
return &api_def.in_arg(i);
}
}
return nullptr;
}
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
public:
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
@@ -164,14 +173,14 @@ string GenEagerPythonOp::FlattenInputs(
} else if (inputs_state == WAS_LIST_INPUT) {
strings::StrAppend(&inputs, " + ");
}
strings::StrAppend(&inputs, "list(", param_names_[i], ")");
strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
inputs_state = WAS_LIST_INPUT;
if (output_sizes != nullptr) {
if (!arg.number_attr().empty()) {
output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
} else {
output_sizes->emplace_back(
strings::StrCat("len(", param_names_[i], ")"));
strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
}
}
} else {
@@ -182,7 +191,7 @@ string GenEagerPythonOp::FlattenInputs(
} else {
strings::StrAppend(&inputs, "[");
}
strings::StrAppend(&inputs, param_names_[i]);
strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
inputs_state = WAS_SOLO_INPUT;
if (output_sizes != nullptr) output_sizes->emplace_back();
}
@@ -195,15 +204,21 @@ string GenEagerPythonOp::FlattenInputs(
}
string GenEagerPythonOp::Code() {
if (api_def_.visibility() == ApiDef::SKIP) {
return "";
}
// This has all the input args followed by those attrs that don't have
// defaults.
std::vector<string> args_no_default;
std::vector<python_op_gen_internal::ParamNames> params_no_default;
// The parameters with defaults (these have to be listed after those without).
// No input args are included, just attrs.
std::vector<std::pair<string, string>> args_with_defaults;
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg(op_def_.input_arg(i));
args_no_default.push_back(arg.name());
std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
params_with_default;
for (int i = 0; i < api_def_.arg_order_size(); ++i) {
const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
if (!arg.type_attr().empty()) {
AddAttrForArg(arg.type_attr(), i);
} else if (!arg.type_list_attr().empty()) {
@@ -215,31 +230,39 @@ string GenEagerPythonOp::Code() {
}
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
const auto& api_def_attr(api_def_.attr(i));
// Do not add inferred attrs to the Python function signature.
if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
if (attr.has_default_value()) {
if (api_def_attr.has_default_value()) {
if (attr.type() == "tensor") {
args_with_defaults.emplace_back(
attr.name(),
strings::StrCat("_execute.make_tensor(",
TensorPBString(attr.default_value().tensor()),
", \"", attr.name(), "\")"));
params_with_default.emplace_back(
python_op_gen_internal::ParamNames(api_def_attr.name(),
api_def_attr.rename_to()),
strings::StrCat(
"_execute.make_tensor(",
TensorPBString(api_def_attr.default_value().tensor()), ", \"",
api_def_attr.rename_to(), "\")"));
} else if (attr.type() == "list(tensor)") {
std::vector<string> pbtxt;
for (const auto& pb : attr.default_value().list().tensor()) {
for (const auto& pb : api_def_attr.default_value().list().tensor()) {
pbtxt.emplace_back(TensorPBString(pb));
}
args_with_defaults.emplace_back(
attr.name(),
strings::StrCat("[_execute.make_tensor(_pb, \"", attr.name(),
"\") for _pb in ", VectorToTuple(pbtxt), "]"));
params_with_default.emplace_back(
python_op_gen_internal::ParamNames(api_def_attr.name(),
api_def_attr.rename_to()),
strings::StrCat("[_execute.make_tensor(_pb, \"",
api_def_attr.rename_to(), "\") for _pb in ",
VectorToTuple(pbtxt), "]"));
} else {
args_with_defaults.emplace_back(
attr.name(), python_op_gen_internal::AttrValueToPython(
attr.type(), attr.default_value(), "_dtypes."));
params_with_default.emplace_back(
python_op_gen_internal::ParamNames(api_def_attr.name(),
api_def_attr.rename_to()),
python_op_gen_internal::AttrValueToPython(
attr.type(), api_def_attr.default_value(), "_dtypes."));
}
} else {
args_no_default.push_back(attr.name());
params_no_default.emplace_back(api_def_attr.name(),
api_def_attr.rename_to());
}
}
}
@@ -247,34 +270,37 @@ string GenEagerPythonOp::Code() {
// Save the list of attr parameters (attrs that won't be inferred),
// those with defaults go at the end.
// Get the attrs in the order we want by taking the attrs without defaults
// from the end of args_no_default, and adding args_no_default.
attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
args_with_defaults.size());
attrs_.insert(attrs_.end(),
args_no_default.begin() + op_def_.input_arg_size(),
args_no_default.end());
for (const auto& a : args_with_defaults) {
attrs_.push_back(a.first);
// from the end of params_no_default, and then adding params_with_default.
attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
params_with_default.size());
for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
attrs_.push_back(params_no_default[i].GetName());
}
for (const auto& p : params_with_default) {
attrs_.push_back(p.first.GetName());
}
param_names_.reserve(args_no_default.size() + args_with_defaults.size());
string parameters;
for (const string& name : args_no_default) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
const string param = python_op_gen_internal::AvoidPythonReserved(name);
strings::StrAppend(&parameters, param);
param_names_.push_back(param);
param_names_.reserve(params_no_default.size() + params_with_default.size());
param_names_.insert(param_names_.begin(), params_no_default.begin(),
params_no_default.end());
for (const auto& param_and_default : params_with_default) {
param_names_.push_back(param_and_default.first);
}
for (const auto& name_default : args_with_defaults) {
string parameters;
for (const auto& param : params_no_default) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
const string param =
python_op_gen_internal::AvoidPythonReserved(name_default.first);
strings::StrAppend(&parameters, param, "=", name_default.second);
param_names_.push_back(param);
strings::StrAppend(&parameters, param.GetRenameTo());
}
for (const auto& param_and_default : params_with_default) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
param_and_default.second);
}
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
strings::StrAppend(&parameters, "name=None");
AddExport();
AddDefLine(parameters);
AddDocStringDescription();
AddDocStringArgs();
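
To make the renaming concrete: the generated signature is now assembled from each parameter's GetRenameTo() rather than its OpDef name. A minimal Python sketch of how the parameter string above is built, using the split_dim to axis rename that also drives the array_ops.py call-site updates further down (the helper and the default value are illustrative):

def build_signature(params_no_default, params_with_default):
  # params_no_default: list of (name, rename_to) pairs.
  # params_with_default: list of ((name, rename_to), default_expression) pairs.
  parts = [rename_to for _, rename_to in params_no_default]
  parts += ["%s=%s" % (rename_to, default)
            for (_, rename_to), default in params_with_default]
  parts.append("name=None")
  return ", ".join(parts)

print(build_signature([("split_dim", "axis"), ("value", "value")],
                      [(("num_split", "num_split"), "1")]))
# prints: axis, value, num_split=1, name=None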
@@ -297,25 +323,26 @@ string GenEagerPythonOp::Code() {
// inputs are lists and have the same length.
for (auto iter = arg_list->second.begin();
iter != arg_list->second.end(); ++iter) {
const string& arg_name = param_names_[*iter];
ExpectListArg(arg_name);
const string& arg_api_name = param_names_[*iter].GetRenameTo();
ExpectListArg(arg_api_name);
if (iter == arg_list->second.begin()) {
AddInferredAttr(attr.name(), strings::StrCat("len(", arg_name, ")"),
AddInferredAttr(attr.name(),
strings::StrCat("len(", arg_api_name, ")"),
&result_, &attr_expressions_);
} else {
const auto& attr_var = attr_expressions_[attr.name()];
strings::StrAppend(&result_, " if len(", arg_name,
strings::StrAppend(&result_, " if len(", arg_api_name,
") != ", attr_var,
":\n"
" raise ValueError(\n"
" \"List argument '",
arg_name, "' to '", op_name_,
arg_api_name, "' to '", op_name_,
"' Op with length %d \"\n"
" \"must match length %d of argument '",
inferred_attrs_[attr.name()],
"'.\" %\n"
" (len(",
arg_name, "), ", attr_var, "))\n");
arg_api_name, "), ", attr_var, "))\n");
}
}
}
@@ -325,65 +352,76 @@ string GenEagerPythonOp::Code() {
// Values for non-inferred attrs.
for (int i = 0; i < attrs_.size(); ++i) {
const string& attr_name = attrs_[i];
const string& param = param_names_[i + op_def_.input_arg_size()];
const auto& param = param_names_[i + op_def_.input_arg_size()];
const auto& attr = *FindAttr(attr_name, op_def_);
const string& attr_api_name = param.GetRenameTo();
StringPiece attr_type = attr.type();
attr_expressions_[attr_name] = param;
const int default_index = i - (attrs_.size() - args_with_defaults.size());
attr_expressions_[attr_name] = attr_api_name;
const int default_index = i - (attrs_.size() - params_with_default.size());
if (default_index >= 0) {
const string& default_value = args_with_defaults[default_index].second;
strings::StrAppend(&result_, " if ", param, " is None:\n");
strings::StrAppend(&result_, " ", param, " = ", default_value, "\n");
const string& default_value = params_with_default[default_index].second;
strings::StrAppend(&result_, " if ", attr_api_name, " is None:\n");
strings::StrAppend(&result_, " ", attr_api_name, " = ", default_value,
"\n");
}
if (attr_type.starts_with("list(")) {
ExpectListArg(param);
ExpectListArg(attr_api_name);
}
if (attr_type == "string") {
strings::StrAppend(&result_, " ", param, " = _execute.make_str(", param,
", \"", param, "\")\n");
strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_str(",
attr_api_name, ", \"", attr_api_name, "\")\n");
} else if (attr_type == "list(string)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_str(_s, \"",
param, "\") for _s in ", param, "]\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = [_execute.make_str(_s, \"", attr_api_name,
"\") for _s in ", attr_api_name, "]\n");
} else if (attr_type == "int") {
strings::StrAppend(&result_, " ", param, " = _execute.make_int(", param,
", \"", param, "\")\n");
strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_int(",
attr_api_name, ", \"", attr_api_name, "\")\n");
} else if (attr_type == "list(int)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_int(_i, \"",
param, "\") for _i in ", param, "]\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = [_execute.make_int(_i, \"", attr_api_name,
"\") for _i in ", attr_api_name, "]\n");
} else if (attr_type == "float") {
strings::StrAppend(&result_, " ", param, " = _execute.make_float(",
param, ", \"", param, "\")\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = _execute.make_float(", attr_api_name, ", \"",
attr_api_name, "\")\n");
} else if (attr_type == "list(float)") {
strings::StrAppend(&result_, " ", param,
" = [_execute.make_float(_f, \"", param,
"\") for _f in ", param, "]\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = [_execute.make_float(_f, \"", attr_api_name,
"\") for _f in ", attr_api_name, "]\n");
} else if (attr_type == "bool") {
strings::StrAppend(&result_, " ", param, " = _execute.make_bool(", param,
", \"", param, "\")\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = _execute.make_bool(", attr_api_name, ", \"",
attr_api_name, "\")\n");
} else if (attr_type == "list(bool)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_bool(_b, \"",
param, "\") for _b in ", param, "]\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = [_execute.make_bool(_b, \"", attr_api_name,
"\") for _b in ", attr_api_name, "]\n");
} else if (attr_type == "type") {
strings::StrAppend(&result_, " ", param, " = _execute.make_type(", param,
", \"", param, "\")\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = _execute.make_type(", attr_api_name, ", \"",
attr_api_name, "\")\n");
} else if (attr_type == "list(type)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_type(_t, \"",
param, "\") for _t in ", param, "]\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = [_execute.make_type(_t, \"", attr_api_name,
"\") for _t in ", attr_api_name, "]\n");
} else if (attr_type == "shape") {
strings::StrAppend(&result_, " ", param, " = _execute.make_shape(",
param, ", \"", param, "\")\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = _execute.make_shape(", attr_api_name, ", \"",
attr_api_name, "\")\n");
} else if (attr_type == "list(shape)") {
strings::StrAppend(&result_, " ", param,
" = [_execute.make_shape(_s, \"", param,
"\") for _s in ", param, "]\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = [_execute.make_shape(_s, \"", attr_api_name,
"\") for _s in ", attr_api_name, "]\n");
} else if (attr_type == "tensor") {
strings::StrAppend(&result_, " ", param, " = _execute.make_tensor(",
param, ", \"", param, "\")\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = _execute.make_tensor(", attr_api_name, ", \"",
attr_api_name, "\")\n");
} else if (attr_type == "list(tensor)") {
strings::StrAppend(&result_, " ", param,
" = [_execute.make_tensor(_t, \"", param,
"\") for _t in ", param, "]\n");
strings::StrAppend(&result_, " ", attr_api_name,
" = [_execute.make_tensor(_t, \"", attr_api_name,
"\") for _t in ", attr_api_name, "]\n");
} else if (attr_type != "func") {
return strings::StrCat("# No definition for ", function_name_,
" since we don't support attrs with type\n"
@@ -484,16 +522,20 @@ string GenEagerPythonOp::Code() {
bool eager_allowed = true;
string ref_arg;
for (const auto& arg : op_def_.input_arg()) {
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg = op_def_.input_arg(i);
if (arg.is_ref()) {
eager_allowed = false;
ref_arg = arg.name();
DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
ref_arg = api_def_.in_arg(i).rename_to();
}
}
for (const auto& arg : op_def_.output_arg()) {
for (int i = 0; i < op_def_.output_arg_size(); ++i) {
const auto& arg = op_def_.output_arg(i);
if (arg.is_ref()) {
eager_allowed = false;
ref_arg = arg.name();
DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
ref_arg = api_def_.out_arg(i).rename_to();
}
}
@@ -553,6 +595,7 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
// Figure out values for inferred attrs, and cast to eager tensors.
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
const auto& api_def_attr(api_def_.attr(i));
auto arg_list = attr_to_args_.find(attr.name());
if (arg_list != attr_to_args_.end()) {
if (attr.type() == "type") {
@@ -565,14 +608,15 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
strings::StrAppend(
&conversion, ", ",
python_op_gen_internal::AttrValueToPython(
attr.type(), attr.default_value(), "_dtypes."));
attr.type(), api_def_attr.default_value(), "_dtypes."));
}
strings::StrAppend(&conversion, ")");
const string var_name = AttrVarName(attr.name(), &attr_expressions_);
if (output_sizes.size() == 1) {
// Avoid creating a temporary variable in the case where
// we can easily assign to the right value directly.
const string inputs_var = param_names_[arg_list->second.front()];
const string inputs_var =
param_names_[arg_list->second.front()].GetRenameTo();
if (output_sizes.front().empty()) {
strings::StrAppend(&result_, " ", var_name, ", (", inputs_var,
",) = ", conversion, "\n");
@@ -589,7 +633,7 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
Unflatten(" ", output_sizes, inputs_var, &result_);
std::vector<string> p;
for (int j : arg_list->second) {
p.emplace_back(param_names_[j]);
p.emplace_back(param_names_[j].GetRenameTo());
}
strings::StrAppend(&result_, " ", VectorToTuple(p), " = ",
inputs_var, "\n");
@@ -608,14 +652,14 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
std::vector<string> lists;
for (auto iter = arg_list->second.begin();
iter != arg_list->second.end(); ++iter) {
lists.push_back(param_names_[*iter]);
lists.push_back(param_names_[*iter].GetRenameTo());
}
inputs_var = VectorToTuple(lists);
conversion = "_execute.args_to_mixed_eager_tensors";
} else {
// For one list(tensor) argument, we just convert every
// element of the list to an eager tensor.
inputs_var = param_names_[arg_list->second.front()];
inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
conversion = "_execute.convert_to_mixed_eager_tensors";
}
strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ",
@@ -630,7 +674,7 @@ void GenEagerPythonOp::AddEagerInputCasts() {
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg(op_def_.input_arg(i));
if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
const string& param = param_names_[i];
const string& param = param_names_[i].GetRenameTo();
const string fn = arg.number_attr().empty() ? "" : "n_";
const string dtype =
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");

View File

@@ -21,34 +21,32 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
namespace tensorflow {
namespace {
constexpr char kBaseApiDef[] =
"tensorflow/core/api_def/base_api/*.pbtxt";
constexpr char kPythonApiDef[] =
"tensorflow/core/api_def/python_api/*.pbtxt";
constexpr bool kUseApiDef = false;
void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
void PrintAllPythonOps(const std::vector<string>& hidden_ops,
const std::vector<string>& api_def_dirs) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
ApiDefMap api_def_map(ops);
if (kUseApiDef) {
if (!api_def_dirs.empty()) {
Env* env = Env::Default();
std::vector<string> base_api_files;
std::vector<string> python_api_files;
TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files));
TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files));
TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files));
TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files));
for (const auto& api_def_dir : api_def_dirs) {
std::vector<string> api_files;
TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
&api_files));
TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
}
api_def_map.UpdateDocs();
}
PrintEagerPythonOps(ops, api_def_map, hidden_ops, true /* require_shapes */);
}
@@ -58,8 +56,15 @@ void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
// Usage:
// python_eager_op_gen_main api_def_dir1,api_def_dir2,...
if (argc == 1) {
tensorflow::PrintAllPythonOps({});
tensorflow::PrintAllPythonOps({}, {});
} else if (argc == 2) {
const std::vector<tensorflow::string> api_def_dirs =
tensorflow::str_util::Split(argv[1], ",",
tensorflow::str_util::SkipEmpty());
tensorflow::PrintAllPythonOps({}, api_def_dirs);
} else {
return -1;
}

View File

@@ -18,6 +18,7 @@ limitations under the License.
#include <stdio.h>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
@@ -480,15 +481,15 @@ string GenPythonOp::Code() {
}
// This has all the input args followed by those attrs that don't have
// defaults.
std::vector<string> args_no_default;
std::vector<ParamNames> params_no_default;
// The parameters with defaults (these have to be listed after those without).
// No input args are included, just attrs.
std::vector<string> args_with_defaults;
std::vector<ParamNames> params_with_default;
for (int i = 0; i < api_def_.arg_order_size(); ++i) {
const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
args_no_default.push_back(api_def_arg.rename_to());
params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
if (!arg.type_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
} else if (!arg.type_list_attr().empty()) {
@@ -504,9 +505,9 @@ string GenPythonOp::Code() {
// Do not add inferred attrs to the Python function signature.
if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
if (attr.has_default_value()) {
args_with_defaults.push_back(attr.rename_to());
params_with_default.emplace_back(attr.name(), attr.rename_to());
} else {
args_no_default.push_back(attr.rename_to());
params_no_default.emplace_back(attr.name(), attr.rename_to());
}
}
}
@@ -515,27 +516,30 @@ string GenPythonOp::Code() {
// those with defaults go at the end.
// Get the attrs in the order we want by taking the attrs without defaults
// from the end of args_no_default, and adding args_no_default.
attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
args_with_defaults.size());
attrs_.insert(attrs_.end(),
args_no_default.begin() + op_def_.input_arg_size(),
args_no_default.end());
attrs_.insert(attrs_.end(), args_with_defaults.begin(),
args_with_defaults.end());
attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
params_with_default.size());
for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
attrs_.push_back(params_no_default[i].GetName());
}
for (int i = 0; i < params_with_default.size(); ++i) {
attrs_.push_back(params_with_default[i].GetName());
}
param_names_.reserve(args_no_default.size() + args_with_defaults.size());
string parameters;
for (const string& name : args_no_default) {
AddDelimiter(&parameters, ", ");
const string param = AvoidPythonReserved(name);
strings::StrAppend(&parameters, param);
param_names_.reserve(params_no_default.size() + params_with_default.size());
param_names_.insert(param_names_.begin(), params_no_default.begin(),
params_no_default.end());
for (const auto& param : params_with_default) {
param_names_.push_back(param);
}
for (const string& name : args_with_defaults) {
string parameters;
for (const auto& param : params_no_default) {
AddDelimiter(&parameters, ", ");
const string param = AvoidPythonReserved(name);
strings::StrAppend(&parameters, param, "=None");
param_names_.push_back(param);
strings::StrAppend(&parameters, param.GetRenameTo());
}
for (const auto& param_and_default : params_with_default) {
AddDelimiter(&parameters, ", ");
strings::StrAppend(&parameters, param_and_default.GetRenameTo(), "=None");
}
AddDelimiter(&parameters, ", ");
strings::StrAppend(&parameters, "name=None");
@@ -557,10 +561,11 @@ string GenPythonOp::Code() {
}
void GenPythonOp::AddExport() {
if (api_def_.visibility() != api_def_.VISIBLE) {
if (api_def_.visibility() != ApiDef::VISIBLE) {
return;
}
strings::StrAppend(&result_, "tf_export(");
strings::StrAppend(&result_, "@tf_export(");
// Add all endpoint names to tf_export.
bool first_endpoint = true;
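
The generated export is now emitted as a decorator placed directly above the generated function rather than as a bare call. A rough sketch of the emitted text, mirroring how the generator appends to result_ (the endpoint and function names are illustrative):

result = []
result.append('@tf_export("bitwise.bitwise_and")\n')  # previously a plain tf_export(...) call
result.append('def bitwise_and(x, y, name=None):\n')
result.append('  r"""..."""\n')
print("".join(result))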
@@ -603,9 +608,9 @@ void GenPythonOp::AddDocStringInputs() {
StringPiece description = api_def_arg.description();
string desc;
if (ConsumeEquals(&description)) { // Skip the generated type info.
desc = strings::StrCat(param_names_[i], ": ");
desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
} else {
desc = strings::StrCat(param_names_[i], ": ",
desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
ArgTypeName(op_def_, arg, inferred_attrs_, false));
}
if (!description.empty()) {
@@ -750,7 +755,8 @@ void GenPythonOp::AddBody(const string& prefix) {
void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
string args = strings::StrCat("\"", op_def_.name(), "\", ");
for (size_t i = 0; i < param_names_.size(); ++i) {
strings::StrAppend(&args, param_names_[i], "=", param_names_[i], ", ");
strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
"=", param_names_[i].GetRenameTo(), ", ");
}
strings::StrAppend(&args, "name=name)");

View File

@@ -41,6 +41,28 @@ void GenerateLowerCaseOpName(const string& str, string* result);
string DataTypeToPython(DataType dtype, const string& dtype_module);
// Names that correspond to a single input parameter.
class ParamNames {
public:
// Create param based on Arg.
ParamNames(const string& name, const string& rename_to) : name_(name) {
rename_to_ = AvoidPythonReserved(rename_to);
}
// Get original parameter name.
string GetName() const { return name_; }
// Get the name to rename the parameter to. Note that AvoidPythonReserved
// has already been applied.
string GetRenameTo() const { return rename_to_; }
private:
// Original parameter name.
string name_;
// API name for this parameter.
string rename_to_;
};
class GenPythonOp {
public:
GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
@@ -84,7 +106,7 @@ class GenPythonOp {
// All parameters, including inputs & non-inferred attrs, required and those
// with defaults, except "name"
std::vector<string> param_names_;
std::vector<ParamNames> param_names_;
};
} // namespace python_op_gen_internal
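
A minimal Python analog of the new ParamNames class, assuming AvoidPythonReserved does nothing more than make a colliding name safe, approximated here by appending an underscore to Python keywords:

import keyword

class ParamNames(object):
  """Pairs an op argument's original name with its Python API name."""

  def __init__(self, name, rename_to):
    self._name = name  # Original OpDef argument name.
    # API-facing name, with keyword collisions avoided up front.
    self._rename_to = (rename_to + "_" if keyword.iskeyword(rename_to)
                       else rename_to)

  def GetName(self):
    return self._name

  def GetRenameTo(self):
    return self._rename_to

print(ParamNames("split_dim", "axis").GetRenameTo())  # axis
print(ParamNames("lambda", "lambda").GetRenameTo())   # lambda_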

View File

@@ -34,12 +34,6 @@ limitations under the License.
namespace tensorflow {
namespace {
constexpr char kBaseApiDef[] =
"tensorflow/core/api_def/base_api/*.pbtxt";
constexpr char kPythonApiDef[] =
"tensorflow/core/api_def/python_api/*.pbtxt";
constexpr bool kUseApiDef = false;
Status ReadOpListFromFile(const string& filename,
std::vector<string>* op_list) {
std::unique_ptr<RandomAccessFile> file;
@@ -110,22 +104,23 @@ string InferSourceFileName(const char* argv_zero) {
}
void PrintAllPythonOps(const std::vector<string>& op_list,
const std::vector<string>& api_def_dirs,
const string& source_file_name, bool require_shapes,
bool op_list_is_whitelist) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
ApiDefMap api_def_map(ops);
if (kUseApiDef) {
if (!api_def_dirs.empty()) {
Env* env = Env::Default();
std::vector<string> base_api_files;
std::vector<string> python_api_files;
TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files));
TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files));
TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files));
TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files));
for (const auto& api_def_dir : api_def_dirs) {
std::vector<string> api_files;
TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
&api_files));
TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
}
api_def_map.UpdateDocs();
}
if (op_list_is_whitelist) {
@@ -154,23 +149,30 @@ int main(int argc, char* argv[]) {
tensorflow::InferSourceFileName(argv[0]);
// Usage:
// gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
if (argc == 2) {
tensorflow::PrintAllPythonOps({}, source_file_name,
tensorflow::string(argv[1]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 3) {
std::vector<tensorflow::string> hidden_ops;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops));
tensorflow::PrintAllPythonOps(hidden_ops, source_file_name,
// gen_main api_def_dir1,api_def_dir2,...
// [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
if (argc < 3) {
return -1;
}
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
argv[1], ",", tensorflow::str_util::SkipEmpty());
if (argc == 3) {
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
tensorflow::string(argv[2]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 4) {
std::vector<tensorflow::string> hidden_ops;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
tensorflow::string(argv[3]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 5) {
std::vector<tensorflow::string> op_list;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list));
tensorflow::PrintAllPythonOps(op_list, source_file_name,
tensorflow::string(argv[2]) == "1",
tensorflow::string(argv[3]) == "1");
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
tensorflow::string(argv[3]) == "1",
tensorflow::string(argv[4]) == "1");
} else {
return -1;
}

View File

@@ -1306,7 +1306,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
size_splits = ops.convert_to_tensor(num_or_size_splits)
if size_splits._rank() == 0 and size_splits.dtype.is_integer:
return gen_array_ops._split(
split_dim=axis, num_split=num_or_size_splits, value=value, name=name)
axis=axis, num_split=num_or_size_splits, value=value, name=name)
if num is None:
num = size_splits._shape_tuple()[0]
@@ -1316,7 +1316,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
return gen_array_ops._split_v(
value=value,
size_splits=size_splits,
split_dim=axis,
axis=axis,
num_split=num,
name=name)
@@ -2538,9 +2538,9 @@ def where(condition, x=None, y=None, name=None):
with ops.name_scope(name, "Where", [condition]) as name:
condition = ops.convert_to_tensor(
condition, preferred_dtype=dtypes.bool, name="condition")
return gen_array_ops.where(input=condition, name=name)
return gen_array_ops.where(condition=condition, name=name)
elif x is not None and y is not None:
return gen_math_ops._select(condition=condition, t=x, e=y, name=name)
return gen_math_ops._select(condition=condition, x=x, y=y, name=name)
else:
raise ValueError("x and y must both be non-None or both be None.")

View File

@@ -334,6 +334,7 @@ def tf_gen_op_wrapper_cc(name,
" $$(dirname $$(echo $(locations " + api_def_src +
") | cut -d\" \" -f1))")
api_def_args_str = ",".join(api_def_args)
native.genrule(
name=name + "_genrule",
outs=[
@@ -469,7 +470,8 @@ def tf_gen_op_wrapper_py(name,
hidden_file=None,
generated_target_name=None,
op_whitelist=[],
cc_linkopts=[]):
cc_linkopts=[],
api_def_srcs=[]):
if (hidden or hidden_file) and op_whitelist:
fail('Cannot specify both hidden and op_whitelist.')
@@ -502,22 +504,39 @@ def tf_gen_op_wrapper_py(name,
op_list_arg = "''"
op_list_is_whitelist = False
# Prepare ApiDef directories to pass to the genrule.
if not api_def_srcs:
api_def_args_str = ","
else:
api_def_args = []
for api_def_src in api_def_srcs:
# Add directory of the first ApiDef source to args.
# We are assuming all ApiDefs in a single api_def_src are in the
# same directory.
api_def_args.append(
"$$(dirname $$(echo $(locations " + api_def_src +
") | cut -d\" \" -f1))")
api_def_args_str = ",".join(api_def_args)
if hidden_file:
# `hidden_file` is a file containing a list of op names to be hidden in the
# generated module.
native.genrule(
name=name + "_pygenrule",
outs=[out],
srcs=[hidden_file],
srcs=api_def_srcs + [hidden_file],
tools=[tool_name] + tf_binary_additional_srcs(),
cmd=("$(location " + tool_name + ") @$(location " + hidden_file + ") " +
cmd=("$(location " + tool_name + ") " + api_def_args_str +
" @$(location " + hidden_file + ") " +
("1" if require_shape_functions else "0") + " > $@"))
else:
native.genrule(
name=name + "_pygenrule",
outs=[out],
srcs=api_def_srcs,
tools=[tool_name] + tf_binary_additional_srcs(),
cmd=("$(location " + tool_name + ") " + op_list_arg + " " +
cmd=("$(location " + tool_name + ") " + api_def_args_str + " " +
op_list_arg + " " +
("1" if require_shape_functions else "0") + " " +
("1" if op_list_is_whitelist else "0") + " > $@"))