Removed shape function registration from python_op_gen
Shape inference is done in C++ and the registered Python functions are never invoked. PiperOrigin-RevId: 269564601
This commit is contained in:
parent
21a7e525a1
commit
35f0302e2f
@ -26,14 +26,12 @@ tf_gen_op_libs(
|
|||||||
tf_gen_op_wrapper_py(
|
tf_gen_op_wrapper_py(
|
||||||
name = "gen_bigquery_reader_ops",
|
name = "gen_bigquery_reader_ops",
|
||||||
out = "python/ops/gen_bigquery_reader_ops.py",
|
out = "python/ops/gen_bigquery_reader_ops.py",
|
||||||
require_shape_functions = True,
|
|
||||||
deps = [":bigquery_reader_ops_op_lib"],
|
deps = [":bigquery_reader_ops_op_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_gen_op_wrapper_py(
|
tf_gen_op_wrapper_py(
|
||||||
name = "gen_gcs_config_ops",
|
name = "gen_gcs_config_ops",
|
||||||
out = "python/ops/gen_gcs_config_ops.py",
|
out = "python/ops/gen_gcs_config_ops.py",
|
||||||
require_shape_functions = True,
|
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [":gcs_config_ops_op_lib"],
|
deps = [":gcs_config_ops_op_lib"],
|
||||||
)
|
)
|
||||||
|
@ -79,7 +79,6 @@ cc_library(
|
|||||||
|
|
||||||
tf_gen_op_wrapper_py(
|
tf_gen_op_wrapper_py(
|
||||||
name = "decode_audio_op_py",
|
name = "decode_audio_op_py",
|
||||||
require_shape_functions = True,
|
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
":decode_audio_op_cc",
|
":decode_audio_op_cc",
|
||||||
@ -88,7 +87,6 @@ tf_gen_op_wrapper_py(
|
|||||||
|
|
||||||
tf_gen_op_wrapper_py(
|
tf_gen_op_wrapper_py(
|
||||||
name = "encode_audio_op_py",
|
name = "encode_audio_op_py",
|
||||||
require_shape_functions = True,
|
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
":encode_audio_op_cc",
|
":encode_audio_op_cc",
|
||||||
@ -97,7 +95,6 @@ tf_gen_op_wrapper_py(
|
|||||||
|
|
||||||
tf_gen_op_wrapper_py(
|
tf_gen_op_wrapper_py(
|
||||||
name = "decode_video_op_py",
|
name = "decode_video_op_py",
|
||||||
require_shape_functions = True,
|
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
":decode_video_op_cc",
|
":decode_video_op_cc",
|
||||||
|
@ -2020,7 +2020,6 @@ tf_gen_op_wrapper_private_py(
|
|||||||
|
|
||||||
tf_gen_op_wrapper_private_py(
|
tf_gen_op_wrapper_private_py(
|
||||||
name = "bitwise_ops_gen",
|
name = "bitwise_ops_gen",
|
||||||
require_shape_functions = True,
|
|
||||||
visibility = [
|
visibility = [
|
||||||
"//learning/brain/python/ops:__pkg__",
|
"//learning/brain/python/ops:__pkg__",
|
||||||
"//tensorflow/compiler/tests:__pkg__",
|
"//tensorflow/compiler/tests:__pkg__",
|
||||||
@ -2053,7 +2052,6 @@ tf_gen_op_wrapper_private_py(
|
|||||||
|
|
||||||
tf_gen_op_wrapper_private_py(
|
tf_gen_op_wrapper_private_py(
|
||||||
name = "audio_ops_gen",
|
name = "audio_ops_gen",
|
||||||
require_shape_functions = True,
|
|
||||||
visibility = [
|
visibility = [
|
||||||
"//learning/brain/python/ops:__pkg__",
|
"//learning/brain/python/ops:__pkg__",
|
||||||
"//tensorflow/contrib/framework:__pkg__",
|
"//tensorflow/contrib/framework:__pkg__",
|
||||||
@ -2319,7 +2317,6 @@ tf_gen_op_wrapper_private_py(
|
|||||||
|
|
||||||
tf_gen_op_wrapper_private_py(
|
tf_gen_op_wrapper_private_py(
|
||||||
name = "user_ops_gen",
|
name = "user_ops_gen",
|
||||||
require_shape_functions = False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_gen_op_wrapper_private_py(
|
tf_gen_op_wrapper_private_py(
|
||||||
|
@ -14,7 +14,7 @@ def tf_gen_op_wrapper_private_py(
|
|||||||
name,
|
name,
|
||||||
out = None,
|
out = None,
|
||||||
deps = [],
|
deps = [],
|
||||||
require_shape_functions = True,
|
require_shape_functions = False,
|
||||||
visibility = []):
|
visibility = []):
|
||||||
if not name.endswith("_gen"):
|
if not name.endswith("_gen"):
|
||||||
fail("name must end in _gen")
|
fail("name must end in _gen")
|
||||||
|
@ -979,7 +979,7 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||||
const std::vector<string>& hidden_ops, bool require_shapes,
|
const std::vector<string>& hidden_ops,
|
||||||
const string& source_file_name = "") {
|
const string& source_file_name = "") {
|
||||||
string result;
|
string result;
|
||||||
// Header
|
// Header
|
||||||
@ -1008,8 +1008,6 @@ from tensorflow.python.eager import execute as _execute
|
|||||||
from tensorflow.python.framework import dtypes as _dtypes
|
from tensorflow.python.framework import dtypes as _dtypes
|
||||||
|
|
||||||
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
|
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
|
||||||
# Needed to trigger the call to _set_call_cpp_shape_fn.
|
|
||||||
from tensorflow.python.framework import common_shapes as _common_shapes
|
|
||||||
from tensorflow.python.framework import op_def_registry as _op_def_registry
|
from tensorflow.python.framework import op_def_registry as _op_def_registry
|
||||||
from tensorflow.python.framework import ops as _ops
|
from tensorflow.python.framework import ops as _ops
|
||||||
from tensorflow.python.framework import op_def_library as _op_def_library
|
from tensorflow.python.framework import op_def_library as _op_def_library
|
||||||
@ -1067,11 +1065,6 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
strings::StrAppend(&result,
|
strings::StrAppend(&result,
|
||||||
GetEagerPythonOp(op_def, *api_def, function_name));
|
GetEagerPythonOp(op_def, *api_def, function_name));
|
||||||
|
|
||||||
if (!require_shapes) {
|
|
||||||
strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
|
|
||||||
"\")(None)\n\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto added = out->Add();
|
auto added = out->Add();
|
||||||
*added = op_def;
|
*added = op_def;
|
||||||
RemoveNonDeprecationDescriptionsFromOpDef(added);
|
RemoveNonDeprecationDescriptionsFromOpDef(added);
|
||||||
@ -1094,11 +1087,10 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||||
const std::vector<string>& hidden_ops, bool require_shapes,
|
const std::vector<string>& hidden_ops,
|
||||||
const string& source_file_name) {
|
const string& source_file_name) {
|
||||||
printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes,
|
printf("%s",
|
||||||
source_file_name)
|
GetPythonOps(ops, api_defs, hidden_ops, source_file_name).c_str());
|
||||||
.c_str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||||
@ -1107,7 +1099,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
|||||||
ops.ParseFromString(op_list_str);
|
ops.ParseFromString(op_list_str);
|
||||||
|
|
||||||
ApiDefMap api_def_map(ops);
|
ApiDefMap api_def_map(ops);
|
||||||
return GetPythonOps(ops, api_def_map, {}, false);
|
return GetPythonOps(ops, api_def_map, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -28,8 +28,8 @@ namespace tensorflow {
|
|||||||
// Optional fourth argument is the name of the original C++ source file
|
// Optional fourth argument is the name of the original C++ source file
|
||||||
// where the ops' REGISTER_OP() calls reside.
|
// where the ops' REGISTER_OP() calls reside.
|
||||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||||
const std::vector<string>& hidden_ops, bool require_shapes,
|
const std::vector<string>& hidden_ops,
|
||||||
const string& source_file_name = "");
|
const string& source_file_name);
|
||||||
|
|
||||||
// Get the python wrappers for a list of ops in a OpList.
|
// Get the python wrappers for a list of ops in a OpList.
|
||||||
// `op_list_buf` should be a pointer to a buffer containing
|
// `op_list_buf` should be a pointer to a buffer containing
|
||||||
|
@ -107,7 +107,7 @@ string InferSourceFileName(const char* argv_zero) {
|
|||||||
|
|
||||||
void PrintAllPythonOps(const std::vector<string>& op_list,
|
void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||||
const std::vector<string>& api_def_dirs,
|
const std::vector<string>& api_def_dirs,
|
||||||
const string& source_file_name, bool require_shapes,
|
const string& source_file_name,
|
||||||
bool op_list_is_whitelist) {
|
bool op_list_is_whitelist) {
|
||||||
OpList ops;
|
OpList ops;
|
||||||
OpRegistry::Global()->Export(false, &ops);
|
OpRegistry::Global()->Export(false, &ops);
|
||||||
@ -133,10 +133,9 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
|
|||||||
*pruned_ops.mutable_op()->Add() = op_def;
|
*pruned_ops.mutable_op()->Add() = op_def;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PrintPythonOps(pruned_ops, api_def_map, {}, require_shapes,
|
PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name);
|
||||||
source_file_name);
|
|
||||||
} else {
|
} else {
|
||||||
PrintPythonOps(ops, api_def_map, op_list, require_shapes, source_file_name);
|
PrintPythonOps(ops, api_def_map, op_list, source_file_name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,29 +150,26 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
// Usage:
|
// Usage:
|
||||||
// gen_main api_def_dir1,api_def_dir2,...
|
// gen_main api_def_dir1,api_def_dir2,...
|
||||||
// [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
|
// [ @FILENAME | OpName[,OpName]* ] [0 | 1]
|
||||||
if (argc < 3) {
|
if (argc < 2) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
|
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
|
||||||
argv[1], ",", tensorflow::str_util::SkipEmpty());
|
argv[1], ",", tensorflow::str_util::SkipEmpty());
|
||||||
|
|
||||||
if (argc == 3) {
|
if (argc == 2) {
|
||||||
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
|
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
|
||||||
tensorflow::string(argv[2]) == "1",
|
|
||||||
false /* op_list_is_whitelist */);
|
false /* op_list_is_whitelist */);
|
||||||
} else if (argc == 4) {
|
} else if (argc == 3) {
|
||||||
std::vector<tensorflow::string> hidden_ops;
|
std::vector<tensorflow::string> hidden_ops;
|
||||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
|
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
|
||||||
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
|
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
|
||||||
tensorflow::string(argv[3]) == "1",
|
|
||||||
false /* op_list_is_whitelist */);
|
false /* op_list_is_whitelist */);
|
||||||
} else if (argc == 5) {
|
} else if (argc == 4) {
|
||||||
std::vector<tensorflow::string> op_list;
|
std::vector<tensorflow::string> op_list;
|
||||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
|
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
|
||||||
tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
|
tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
|
||||||
tensorflow::string(argv[3]) == "1",
|
tensorflow::string(argv[3]) == "1");
|
||||||
tensorflow::string(argv[4]) == "1");
|
|
||||||
} else {
|
} else {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
@ -839,7 +839,7 @@ def tf_gen_op_wrappers_cc(
|
|||||||
# deps: list of dependencies for the intermediate tool used to generate the
|
# deps: list of dependencies for the intermediate tool used to generate the
|
||||||
# python target. NOTE these `deps` are not applied to the final python
|
# python target. NOTE these `deps` are not applied to the final python
|
||||||
# library target itself.
|
# library target itself.
|
||||||
# require_shape_functions: leave this as False.
|
# require_shape_functions: Unused. Leave this as False.
|
||||||
# hidden_file: optional file that contains a list of op names to make private
|
# hidden_file: optional file that contains a list of op names to make private
|
||||||
# in the generated Python module. Each op name should be on a line by
|
# in the generated Python module. Each op name should be on a line by
|
||||||
# itself. Lines that start with characters that are invalid op name
|
# itself. Lines that start with characters that are invalid op name
|
||||||
@ -863,6 +863,8 @@ def tf_gen_op_wrapper_py(
|
|||||||
op_whitelist = [],
|
op_whitelist = [],
|
||||||
cc_linkopts = [],
|
cc_linkopts = [],
|
||||||
api_def_srcs = []):
|
api_def_srcs = []):
|
||||||
|
_ = require_shape_functions # Unused.
|
||||||
|
|
||||||
if (hidden or hidden_file) and op_whitelist:
|
if (hidden or hidden_file) and op_whitelist:
|
||||||
fail("Cannot pass specify both hidden and op_whitelist.")
|
fail("Cannot pass specify both hidden and op_whitelist.")
|
||||||
|
|
||||||
@ -920,8 +922,7 @@ def tf_gen_op_wrapper_py(
|
|||||||
srcs = api_def_srcs + [hidden_file],
|
srcs = api_def_srcs + [hidden_file],
|
||||||
tools = [tool_name] + tf_binary_additional_srcs(),
|
tools = [tool_name] + tf_binary_additional_srcs(),
|
||||||
cmd = ("$(location " + tool_name + ") " + api_def_args_str +
|
cmd = ("$(location " + tool_name + ") " + api_def_args_str +
|
||||||
" @$(location " + hidden_file + ") " +
|
" @$(location " + hidden_file + ") > $@"),
|
||||||
("1" if require_shape_functions else "0") + " > $@"),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
native.genrule(
|
native.genrule(
|
||||||
@ -931,7 +932,6 @@ def tf_gen_op_wrapper_py(
|
|||||||
tools = [tool_name] + tf_binary_additional_srcs(),
|
tools = [tool_name] + tf_binary_additional_srcs(),
|
||||||
cmd = ("$(location " + tool_name + ") " + api_def_args_str + " " +
|
cmd = ("$(location " + tool_name + ") " + api_def_args_str + " " +
|
||||||
op_list_arg + " " +
|
op_list_arg + " " +
|
||||||
("1" if require_shape_functions else "0") + " " +
|
|
||||||
("1" if op_list_is_whitelist else "0") + " > $@"),
|
("1" if op_list_is_whitelist else "0") + " > $@"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user