op_def_registry.get now delegates to OpRegistry::Global()

This change removes a separate Python OpDef registry and therefore resolves
a possible inconsistency between Python/C++. op_def_registry is now a thin
wrapper around OpRegistry::Global() which is defined and updated in C++.

PiperOrigin-RevId: 271530180
This commit is contained in:
Sergei Lebedev 2019-09-27 02:18:07 -07:00 committed by TensorFlower Gardener
parent 18342dab3f
commit 3bd5fa5b9e
5 changed files with 91 additions and 52 deletions

View File

@ -1040,12 +1040,24 @@ py_library(
],
)
tf_python_pybind_extension(
name = "_op_def_registry",
srcs = ["framework/op_def_registry.cc"],
module_name = "_op_def_registry",
deps = [
":pybind11_status",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
"@pybind11",
],
)
py_library(
name = "op_def_registry",
srcs = ["framework/op_def_registry.py"],
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
":_op_def_registry",
"//tensorflow/core:protos_all_py",
],
)
@ -5140,6 +5152,7 @@ genrule(
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
"//tensorflow/core:framework_internal_impl", # op_def_registry
],
outs = ["pybind_symbol_target_libs_file.txt"],
cmd = select({

View File

@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
PYBIND11_MODULE(_op_def_registry, m) {
m.def("get", [](const std::string& name) {
const tensorflow::OpDef* op_def = nullptr;
auto status = tensorflow::OpRegistry::Global()->LookUpOpDef(name, &op_def);
if (!status.ok()) return py::reinterpret_borrow<py::object>(py::none());
tensorflow::OpDef stripped_op_def = *op_def;
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef(&stripped_op_def);
tensorflow::MaybeRaiseFromStatus(status);
std::string serialized_op_def;
if (!stripped_op_def.SerializeToString(&serialized_op_def)) {
throw std::runtime_error("Failed to serialize OpDef to string");
}
// Explicitly convert to py::bytes because std::string is implicitly
// convertable to py::str by default.
return py::reinterpret_borrow<py::object>(py::bytes(serialized_op_def));
});
}

View File

@ -22,61 +22,39 @@ from __future__ import print_function
import threading
from tensorflow.core.framework import op_def_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python import _op_def_registry
_registered_ops = {}
_sync_lock = threading.Lock()
def _remove_non_deprecated_descriptions(op_def):
"""Remove docs from `op_def` but leave explanations of deprecations."""
for input_arg in op_def.input_arg:
input_arg.description = ""
for output_arg in op_def.output_arg:
output_arg.description = ""
for attr in op_def.attr:
attr.description = ""
op_def.summary = ""
op_def.description = ""
def register_op_list(op_list):
"""Register all the ops in an op_def_pb2.OpList."""
if not isinstance(op_list, op_def_pb2.OpList):
raise TypeError("%s is %s, not an op_def_pb2.OpList" %
(op_list, type(op_list)))
for op_def in op_list.op:
if op_def.name in _registered_ops:
if _registered_ops[op_def.name] != op_def:
raise ValueError(
"Registered op_def for %s (%s) not equal to op_def to register (%s)"
% (op_def.name, _registered_ops[op_def.name], op_def))
else:
_registered_ops[op_def.name] = op_def
# The cache amortizes ProtoBuf serialization/deserialization overhead
# on the language boundary. If an OpDef has been looked up, its Python
# representation is cached.
_cache = {}
_cache_lock = threading.Lock()
def get(name):
"""Returns an OpDef for a given `name` or None if the lookup fails."""
with _sync_lock:
return _registered_ops.get(name)
try:
return _cache[name]
except KeyError:
pass
with _cache_lock:
try:
# Return if another thread has already populated the cache.
return _cache[name]
except KeyError:
pass
serialized_op_def = _op_def_registry.get(name)
if serialized_op_def is None:
return None
op_def = op_def_pb2.OpDef()
op_def.ParseFromString(serialized_op_def)
_cache[name] = op_def
return op_def
# TODO(b/141354889): Remove once there are no callers.
def sync():
"""Synchronize the contents of the Python registry with C++."""
with _sync_lock:
p_buffer = c_api.TF_GetAllOpList()
cpp_op_list = op_def_pb2.OpList()
cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
for op_def in cpp_op_list.op:
# If an OpList is registered from a gen_*_ops.py, it does not any
# descriptions. Strip them here as well to satisfy validation in
# register_op_list.
_remove_non_deprecated_descriptions(op_def)
_registered_ops[op_def.name] = op_def
def get_registered_ops():
"""Returns a dictionary mapping names to OpDefs."""
return _registered_ops
"""No-op. Used to synchronize the contents of the Python registry with C++."""

View File

@ -1073,7 +1073,6 @@ from tensorflow.python.util.tf_export import tf_export
result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
op_list = _op_def_pb2.OpList()
op_list.ParseFromString(op_list_proto_bytes)
_op_def_registry.register_op_list(op_list)
op_def_lib = _op_def_library.OpDefLibrary()
op_def_lib.add_op_list(op_list)
return op_def_lib

View File

@ -54,3 +54,9 @@ tensorflow::EventsWriter::Close
[py_func_lib] # py_func
tensorflow::InitializePyTrampoline
[framework_internal_impl] # op_def_registry
tensorflow::OpRegistry::Global
tensorflow::OpRegistry::LookUpOpDef
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef