op_def_registry.get now delegates to OpRegistry::Global()
This change removes a separate Python OpDef registry and therefore resolves a possible inconsistency between Python/C++. op_def_registry is now a thin wrapper around OpRegistry::Global() which is defined and updated in C++. PiperOrigin-RevId: 271530180
This commit is contained in:
parent
18342dab3f
commit
3bd5fa5b9e
@ -1040,12 +1040,24 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_op_def_registry",
|
||||
srcs = ["framework/op_def_registry.cc"],
|
||||
module_name = "_op_def_registry",
|
||||
deps = [
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "op_def_registry",
|
||||
srcs = ["framework/op_def_registry.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow",
|
||||
":_op_def_registry",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
)
|
||||
@ -5140,6 +5152,7 @@ genrule(
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||
"//tensorflow/core:framework_internal_impl", # op_def_registry
|
||||
],
|
||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||
cmd = select({
|
||||
|
||||
43
tensorflow/python/framework/op_def_registry.cc
Normal file
43
tensorflow/python/framework/op_def_registry.cc
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_op_def_registry, m) {
|
||||
m.def("get", [](const std::string& name) {
|
||||
const tensorflow::OpDef* op_def = nullptr;
|
||||
auto status = tensorflow::OpRegistry::Global()->LookUpOpDef(name, &op_def);
|
||||
if (!status.ok()) return py::reinterpret_borrow<py::object>(py::none());
|
||||
|
||||
tensorflow::OpDef stripped_op_def = *op_def;
|
||||
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef(&stripped_op_def);
|
||||
|
||||
tensorflow::MaybeRaiseFromStatus(status);
|
||||
std::string serialized_op_def;
|
||||
if (!stripped_op_def.SerializeToString(&serialized_op_def)) {
|
||||
throw std::runtime_error("Failed to serialize OpDef to string");
|
||||
}
|
||||
|
||||
// Explicitly convert to py::bytes because std::string is implicitly
|
||||
// convertable to py::str by default.
|
||||
return py::reinterpret_borrow<py::object>(py::bytes(serialized_op_def));
|
||||
});
|
||||
}
|
||||
@ -22,61 +22,39 @@ from __future__ import print_function
|
||||
import threading
|
||||
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python import _op_def_registry
|
||||
|
||||
|
||||
_registered_ops = {}
|
||||
_sync_lock = threading.Lock()
|
||||
|
||||
|
||||
def _remove_non_deprecated_descriptions(op_def):
|
||||
"""Remove docs from `op_def` but leave explanations of deprecations."""
|
||||
for input_arg in op_def.input_arg:
|
||||
input_arg.description = ""
|
||||
for output_arg in op_def.output_arg:
|
||||
output_arg.description = ""
|
||||
for attr in op_def.attr:
|
||||
attr.description = ""
|
||||
|
||||
op_def.summary = ""
|
||||
op_def.description = ""
|
||||
|
||||
|
||||
def register_op_list(op_list):
|
||||
"""Register all the ops in an op_def_pb2.OpList."""
|
||||
if not isinstance(op_list, op_def_pb2.OpList):
|
||||
raise TypeError("%s is %s, not an op_def_pb2.OpList" %
|
||||
(op_list, type(op_list)))
|
||||
for op_def in op_list.op:
|
||||
if op_def.name in _registered_ops:
|
||||
if _registered_ops[op_def.name] != op_def:
|
||||
raise ValueError(
|
||||
"Registered op_def for %s (%s) not equal to op_def to register (%s)"
|
||||
% (op_def.name, _registered_ops[op_def.name], op_def))
|
||||
else:
|
||||
_registered_ops[op_def.name] = op_def
|
||||
# The cache amortizes ProtoBuf serialization/deserialization overhead
|
||||
# on the language boundary. If an OpDef has been looked up, its Python
|
||||
# representation is cached.
|
||||
_cache = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def get(name):
|
||||
"""Returns an OpDef for a given `name` or None if the lookup fails."""
|
||||
with _sync_lock:
|
||||
return _registered_ops.get(name)
|
||||
try:
|
||||
return _cache[name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
with _cache_lock:
|
||||
try:
|
||||
# Return if another thread has already populated the cache.
|
||||
return _cache[name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
serialized_op_def = _op_def_registry.get(name)
|
||||
if serialized_op_def is None:
|
||||
return None
|
||||
|
||||
op_def = op_def_pb2.OpDef()
|
||||
op_def.ParseFromString(serialized_op_def)
|
||||
_cache[name] = op_def
|
||||
return op_def
|
||||
|
||||
|
||||
# TODO(b/141354889): Remove once there are no callers.
|
||||
def sync():
|
||||
"""Synchronize the contents of the Python registry with C++."""
|
||||
with _sync_lock:
|
||||
p_buffer = c_api.TF_GetAllOpList()
|
||||
cpp_op_list = op_def_pb2.OpList()
|
||||
cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
|
||||
for op_def in cpp_op_list.op:
|
||||
# If an OpList is registered from a gen_*_ops.py, it does not any
|
||||
# descriptions. Strip them here as well to satisfy validation in
|
||||
# register_op_list.
|
||||
_remove_non_deprecated_descriptions(op_def)
|
||||
_registered_ops[op_def.name] = op_def
|
||||
|
||||
|
||||
def get_registered_ops():
|
||||
"""Returns a dictionary mapping names to OpDefs."""
|
||||
return _registered_ops
|
||||
"""No-op. Used to synchronize the contents of the Python registry with C++."""
|
||||
|
||||
@ -1073,7 +1073,6 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
|
||||
op_list = _op_def_pb2.OpList()
|
||||
op_list.ParseFromString(op_list_proto_bytes)
|
||||
_op_def_registry.register_op_list(op_list)
|
||||
op_def_lib = _op_def_library.OpDefLibrary()
|
||||
op_def_lib.add_op_list(op_list)
|
||||
return op_def_lib
|
||||
|
||||
@ -54,3 +54,9 @@ tensorflow::EventsWriter::Close
|
||||
|
||||
[py_func_lib] # py_func
|
||||
tensorflow::InitializePyTrampoline
|
||||
|
||||
[framework_internal_impl] # op_def_registry
|
||||
tensorflow::OpRegistry::Global
|
||||
tensorflow::OpRegistry::LookUpOpDef
|
||||
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user