Removed shape function registration from python_op_gen

Shape inference is done in C++ and the registered Python functions are
never invoked.

PiperOrigin-RevId: 269564601
This commit is contained in:
Sergei Lebedev 2019-09-17 07:30:06 -07:00 committed by TensorFlower Gardener
parent 21a7e525a1
commit 35f0302e2f
8 changed files with 21 additions and 41 deletions

View File

@ -26,14 +26,12 @@ tf_gen_op_libs(
tf_gen_op_wrapper_py(
name = "gen_bigquery_reader_ops",
out = "python/ops/gen_bigquery_reader_ops.py",
require_shape_functions = True,
deps = [":bigquery_reader_ops_op_lib"],
)
tf_gen_op_wrapper_py(
name = "gen_gcs_config_ops",
out = "python/ops/gen_gcs_config_ops.py",
require_shape_functions = True,
visibility = ["//tensorflow:internal"],
deps = [":gcs_config_ops_op_lib"],
)

View File

@ -79,7 +79,6 @@ cc_library(
tf_gen_op_wrapper_py(
name = "decode_audio_op_py",
require_shape_functions = True,
visibility = ["//visibility:private"],
deps = [
":decode_audio_op_cc",
@ -88,7 +87,6 @@ tf_gen_op_wrapper_py(
tf_gen_op_wrapper_py(
name = "encode_audio_op_py",
require_shape_functions = True,
visibility = ["//visibility:private"],
deps = [
":encode_audio_op_cc",
@ -97,7 +95,6 @@ tf_gen_op_wrapper_py(
tf_gen_op_wrapper_py(
name = "decode_video_op_py",
require_shape_functions = True,
visibility = ["//visibility:private"],
deps = [
":decode_video_op_cc",

View File

@ -2020,7 +2020,6 @@ tf_gen_op_wrapper_private_py(
tf_gen_op_wrapper_private_py(
name = "bitwise_ops_gen",
require_shape_functions = True,
visibility = [
"//learning/brain/python/ops:__pkg__",
"//tensorflow/compiler/tests:__pkg__",
@ -2053,7 +2052,6 @@ tf_gen_op_wrapper_private_py(
tf_gen_op_wrapper_private_py(
name = "audio_ops_gen",
require_shape_functions = True,
visibility = [
"//learning/brain/python/ops:__pkg__",
"//tensorflow/contrib/framework:__pkg__",
@ -2319,7 +2317,6 @@ tf_gen_op_wrapper_private_py(
tf_gen_op_wrapper_private_py(
name = "user_ops_gen",
require_shape_functions = False,
)
tf_gen_op_wrapper_private_py(

View File

@ -14,7 +14,7 @@ def tf_gen_op_wrapper_private_py(
name,
out = None,
deps = [],
require_shape_functions = True,
require_shape_functions = False,
visibility = []):
if not name.endswith("_gen"):
fail("name must end in _gen")

View File

@ -979,7 +979,7 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
}
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops, bool require_shapes,
const std::vector<string>& hidden_ops,
const string& source_file_name = "") {
string result;
// Header
@ -1008,8 +1008,6 @@ from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
@ -1067,11 +1065,6 @@ from tensorflow.python.util.tf_export import tf_export
strings::StrAppend(&result,
GetEagerPythonOp(op_def, *api_def, function_name));
if (!require_shapes) {
strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
"\")(None)\n\n");
}
auto added = out->Add();
*added = op_def;
RemoveNonDeprecationDescriptionsFromOpDef(added);
@ -1094,11 +1087,10 @@ from tensorflow.python.util.tf_export import tf_export
} // namespace
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops, bool require_shapes,
const std::vector<string>& hidden_ops,
const string& source_file_name) {
printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes,
source_file_name)
.c_str());
printf("%s",
GetPythonOps(ops, api_defs, hidden_ops, source_file_name).c_str());
}
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
@ -1107,7 +1099,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
ops.ParseFromString(op_list_str);
ApiDefMap api_def_map(ops);
return GetPythonOps(ops, api_def_map, {}, false);
return GetPythonOps(ops, api_def_map, {});
}
} // namespace tensorflow

View File

@ -28,8 +28,8 @@ namespace tensorflow {
// Optional fourth argument is the name of the original C++ source file
// where the ops' REGISTER_OP() calls reside.
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops, bool require_shapes,
const string& source_file_name = "");
const std::vector<string>& hidden_ops,
const string& source_file_name);
// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing

View File

@ -107,7 +107,7 @@ string InferSourceFileName(const char* argv_zero) {
void PrintAllPythonOps(const std::vector<string>& op_list,
const std::vector<string>& api_def_dirs,
const string& source_file_name, bool require_shapes,
const string& source_file_name,
bool op_list_is_whitelist) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
@ -133,10 +133,9 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
*pruned_ops.mutable_op()->Add() = op_def;
}
}
PrintPythonOps(pruned_ops, api_def_map, {}, require_shapes,
source_file_name);
PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name);
} else {
PrintPythonOps(ops, api_def_map, op_list, require_shapes, source_file_name);
PrintPythonOps(ops, api_def_map, op_list, source_file_name);
}
}
@ -151,29 +150,26 @@ int main(int argc, char* argv[]) {
// Usage:
// gen_main api_def_dir1,api_def_dir2,...
// [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
if (argc < 3) {
// [ @FILENAME | OpName[,OpName]* ] [0 | 1]
if (argc < 2) {
return -1;
}
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
argv[1], ",", tensorflow::str_util::SkipEmpty());
if (argc == 3) {
if (argc == 2) {
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
tensorflow::string(argv[2]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 4) {
} else if (argc == 3) {
std::vector<tensorflow::string> hidden_ops;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
tensorflow::string(argv[3]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 5) {
} else if (argc == 4) {
std::vector<tensorflow::string> op_list;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
tensorflow::string(argv[3]) == "1",
tensorflow::string(argv[4]) == "1");
tensorflow::string(argv[3]) == "1");
} else {
return -1;
}

View File

@ -839,7 +839,7 @@ def tf_gen_op_wrappers_cc(
# deps: list of dependencies for the intermediate tool used to generate the
# python target. NOTE these `deps` are not applied to the final python
# library target itself.
# require_shape_functions: leave this as False.
# require_shape_functions: Unused. Leave this as False.
# hidden_file: optional file that contains a list of op names to make private
# in the generated Python module. Each op name should be on a line by
# itself. Lines that start with characters that are invalid op name
@ -863,6 +863,8 @@ def tf_gen_op_wrapper_py(
op_whitelist = [],
cc_linkopts = [],
api_def_srcs = []):
_ = require_shape_functions # Unused.
if (hidden or hidden_file) and op_whitelist:
fail("Cannot pass specify both hidden and op_whitelist.")
@ -920,8 +922,7 @@ def tf_gen_op_wrapper_py(
srcs = api_def_srcs + [hidden_file],
tools = [tool_name] + tf_binary_additional_srcs(),
cmd = ("$(location " + tool_name + ") " + api_def_args_str +
" @$(location " + hidden_file + ") " +
("1" if require_shape_functions else "0") + " > $@"),
" @$(location " + hidden_file + ") > $@"),
)
else:
native.genrule(
@ -931,7 +932,6 @@ def tf_gen_op_wrapper_py(
tools = [tool_name] + tf_binary_additional_srcs(),
cmd = ("$(location " + tool_name + ") " + api_def_args_str + " " +
op_list_arg + " " +
("1" if require_shape_functions else "0") + " " +
("1" if op_list_is_whitelist else "0") + " > $@"),
)