Removed shape function registration from python_op_gen
Shape inference is done in C++ and the registered Python functions are never invoked. PiperOrigin-RevId: 269564601
This commit is contained in:
parent
21a7e525a1
commit
35f0302e2f
@ -26,14 +26,12 @@ tf_gen_op_libs(
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_bigquery_reader_ops",
|
||||
out = "python/ops/gen_bigquery_reader_ops.py",
|
||||
require_shape_functions = True,
|
||||
deps = [":bigquery_reader_ops_op_lib"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_gcs_config_ops",
|
||||
out = "python/ops/gen_gcs_config_ops.py",
|
||||
require_shape_functions = True,
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [":gcs_config_ops_op_lib"],
|
||||
)
|
||||
|
@ -79,7 +79,6 @@ cc_library(
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "decode_audio_op_py",
|
||||
require_shape_functions = True,
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":decode_audio_op_cc",
|
||||
@ -88,7 +87,6 @@ tf_gen_op_wrapper_py(
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "encode_audio_op_py",
|
||||
require_shape_functions = True,
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":encode_audio_op_cc",
|
||||
@ -97,7 +95,6 @@ tf_gen_op_wrapper_py(
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "decode_video_op_py",
|
||||
require_shape_functions = True,
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":decode_video_op_cc",
|
||||
|
@ -2020,7 +2020,6 @@ tf_gen_op_wrapper_private_py(
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "bitwise_ops_gen",
|
||||
require_shape_functions = True,
|
||||
visibility = [
|
||||
"//learning/brain/python/ops:__pkg__",
|
||||
"//tensorflow/compiler/tests:__pkg__",
|
||||
@ -2053,7 +2052,6 @@ tf_gen_op_wrapper_private_py(
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "audio_ops_gen",
|
||||
require_shape_functions = True,
|
||||
visibility = [
|
||||
"//learning/brain/python/ops:__pkg__",
|
||||
"//tensorflow/contrib/framework:__pkg__",
|
||||
@ -2319,7 +2317,6 @@ tf_gen_op_wrapper_private_py(
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "user_ops_gen",
|
||||
require_shape_functions = False,
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
|
@ -14,7 +14,7 @@ def tf_gen_op_wrapper_private_py(
|
||||
name,
|
||||
out = None,
|
||||
deps = [],
|
||||
require_shape_functions = True,
|
||||
require_shape_functions = False,
|
||||
visibility = []):
|
||||
if not name.endswith("_gen"):
|
||||
fail("name must end in _gen")
|
||||
|
@ -979,7 +979,7 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
|
||||
}
|
||||
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops, bool require_shapes,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name = "") {
|
||||
string result;
|
||||
// Header
|
||||
@ -1008,8 +1008,6 @@ from tensorflow.python.eager import execute as _execute
|
||||
from tensorflow.python.framework import dtypes as _dtypes
|
||||
|
||||
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
|
||||
# Needed to trigger the call to _set_call_cpp_shape_fn.
|
||||
from tensorflow.python.framework import common_shapes as _common_shapes
|
||||
from tensorflow.python.framework import op_def_registry as _op_def_registry
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework import op_def_library as _op_def_library
|
||||
@ -1067,11 +1065,6 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
strings::StrAppend(&result,
|
||||
GetEagerPythonOp(op_def, *api_def, function_name));
|
||||
|
||||
if (!require_shapes) {
|
||||
strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
|
||||
"\")(None)\n\n");
|
||||
}
|
||||
|
||||
auto added = out->Add();
|
||||
*added = op_def;
|
||||
RemoveNonDeprecationDescriptionsFromOpDef(added);
|
||||
@ -1094,11 +1087,10 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
} // namespace
|
||||
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops, bool require_shapes,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name) {
|
||||
printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes,
|
||||
source_file_name)
|
||||
.c_str());
|
||||
printf("%s",
|
||||
GetPythonOps(ops, api_defs, hidden_ops, source_file_name).c_str());
|
||||
}
|
||||
|
||||
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
@ -1107,7 +1099,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
ops.ParseFromString(op_list_str);
|
||||
|
||||
ApiDefMap api_def_map(ops);
|
||||
return GetPythonOps(ops, api_def_map, {}, false);
|
||||
return GetPythonOps(ops, api_def_map, {});
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -28,8 +28,8 @@ namespace tensorflow {
|
||||
// Optional fourth argument is the name of the original C++ source file
|
||||
// where the ops' REGISTER_OP() calls reside.
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops, bool require_shapes,
|
||||
const string& source_file_name = "");
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name);
|
||||
|
||||
// Get the python wrappers for a list of ops in a OpList.
|
||||
// `op_list_buf` should be a pointer to a buffer containing
|
||||
|
@ -107,7 +107,7 @@ string InferSourceFileName(const char* argv_zero) {
|
||||
|
||||
void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
const std::vector<string>& api_def_dirs,
|
||||
const string& source_file_name, bool require_shapes,
|
||||
const string& source_file_name,
|
||||
bool op_list_is_whitelist) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(false, &ops);
|
||||
@ -133,10 +133,9 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
*pruned_ops.mutable_op()->Add() = op_def;
|
||||
}
|
||||
}
|
||||
PrintPythonOps(pruned_ops, api_def_map, {}, require_shapes,
|
||||
source_file_name);
|
||||
PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name);
|
||||
} else {
|
||||
PrintPythonOps(ops, api_def_map, op_list, require_shapes, source_file_name);
|
||||
PrintPythonOps(ops, api_def_map, op_list, source_file_name);
|
||||
}
|
||||
}
|
||||
|
||||
@ -151,29 +150,26 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
// Usage:
|
||||
// gen_main api_def_dir1,api_def_dir2,...
|
||||
// [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
|
||||
if (argc < 3) {
|
||||
// [ @FILENAME | OpName[,OpName]* ] [0 | 1]
|
||||
if (argc < 2) {
|
||||
return -1;
|
||||
}
|
||||
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
|
||||
argv[1], ",", tensorflow::str_util::SkipEmpty());
|
||||
|
||||
if (argc == 3) {
|
||||
if (argc == 2) {
|
||||
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
|
||||
tensorflow::string(argv[2]) == "1",
|
||||
false /* op_list_is_whitelist */);
|
||||
} else if (argc == 4) {
|
||||
} else if (argc == 3) {
|
||||
std::vector<tensorflow::string> hidden_ops;
|
||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
|
||||
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
|
||||
tensorflow::string(argv[3]) == "1",
|
||||
false /* op_list_is_whitelist */);
|
||||
} else if (argc == 5) {
|
||||
} else if (argc == 4) {
|
||||
std::vector<tensorflow::string> op_list;
|
||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
|
||||
tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
|
||||
tensorflow::string(argv[3]) == "1",
|
||||
tensorflow::string(argv[4]) == "1");
|
||||
tensorflow::string(argv[3]) == "1");
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
|
@ -839,7 +839,7 @@ def tf_gen_op_wrappers_cc(
|
||||
# deps: list of dependencies for the intermediate tool used to generate the
|
||||
# python target. NOTE these `deps` are not applied to the final python
|
||||
# library target itself.
|
||||
# require_shape_functions: leave this as False.
|
||||
# require_shape_functions: Unused. Leave this as False.
|
||||
# hidden_file: optional file that contains a list of op names to make private
|
||||
# in the generated Python module. Each op name should be on a line by
|
||||
# itself. Lines that start with characters that are invalid op name
|
||||
@ -863,6 +863,8 @@ def tf_gen_op_wrapper_py(
|
||||
op_whitelist = [],
|
||||
cc_linkopts = [],
|
||||
api_def_srcs = []):
|
||||
_ = require_shape_functions # Unused.
|
||||
|
||||
if (hidden or hidden_file) and op_whitelist:
|
||||
fail("Cannot pass specify both hidden and op_whitelist.")
|
||||
|
||||
@ -920,8 +922,7 @@ def tf_gen_op_wrapper_py(
|
||||
srcs = api_def_srcs + [hidden_file],
|
||||
tools = [tool_name] + tf_binary_additional_srcs(),
|
||||
cmd = ("$(location " + tool_name + ") " + api_def_args_str +
|
||||
" @$(location " + hidden_file + ") " +
|
||||
("1" if require_shape_functions else "0") + " > $@"),
|
||||
" @$(location " + hidden_file + ") > $@"),
|
||||
)
|
||||
else:
|
||||
native.genrule(
|
||||
@ -931,7 +932,6 @@ def tf_gen_op_wrapper_py(
|
||||
tools = [tool_name] + tf_binary_additional_srcs(),
|
||||
cmd = ("$(location " + tool_name + ") " + api_def_args_str + " " +
|
||||
op_list_arg + " " +
|
||||
("1" if require_shape_functions else "0") + " " +
|
||||
("1" if op_list_is_whitelist else "0") + " > $@"),
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user