diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 16431d35616..a5d2301d706 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1040,12 +1040,24 @@ py_library( ], ) +tf_python_pybind_extension( + name = "_op_def_registry", + srcs = ["framework/op_def_registry.cc"], + module_name = "_op_def_registry", + deps = [ + ":pybind11_status", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:protos_all_cc", + "@pybind11", + ], +) + py_library( name = "op_def_registry", srcs = ["framework/op_def_registry.py"], srcs_version = "PY2AND3", deps = [ - ":pywrap_tensorflow", + ":_op_def_registry", "//tensorflow/core:protos_all_py", ], ) @@ -5140,6 +5152,7 @@ genrule( "//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer "//tensorflow/core/profiler/internal:print_model_analysis", # tfprof + "//tensorflow/core:framework_internal_impl", # op_def_registry ], outs = ["pybind_symbol_target_libs_file.txt"], cmd = select({ diff --git a/tensorflow/python/framework/op_def_registry.cc b/tensorflow/python/framework/op_def_registry.cc new file mode 100644 index 00000000000..0de2ce01b96 --- /dev/null +++ b/tensorflow/python/framework/op_def_registry.cc @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/pybind11/pybind11.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_op_def_registry, m) { + m.def("get", [](const std::string& name) { + const tensorflow::OpDef* op_def = nullptr; + auto status = tensorflow::OpRegistry::Global()->LookUpOpDef(name, &op_def); + if (!status.ok()) return py::reinterpret_borrow(py::none()); + + tensorflow::OpDef stripped_op_def = *op_def; + tensorflow::RemoveNonDeprecationDescriptionsFromOpDef(&stripped_op_def); + + tensorflow::MaybeRaiseFromStatus(status); + std::string serialized_op_def; + if (!stripped_op_def.SerializeToString(&serialized_op_def)) { + throw std::runtime_error("Failed to serialize OpDef to string"); + } + + // Explicitly convert to py::bytes because std::string is implicitly + // convertable to py::str by default. + return py::reinterpret_borrow(py::bytes(serialized_op_def)); + }); +} diff --git a/tensorflow/python/framework/op_def_registry.py b/tensorflow/python/framework/op_def_registry.py index 2ef386879ad..5949db537cf 100644 --- a/tensorflow/python/framework/op_def_registry.py +++ b/tensorflow/python/framework/op_def_registry.py @@ -22,61 +22,39 @@ from __future__ import print_function import threading from tensorflow.core.framework import op_def_pb2 -from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python import _op_def_registry - -_registered_ops = {} -_sync_lock = threading.Lock() - - -def _remove_non_deprecated_descriptions(op_def): - """Remove docs from `op_def` but leave explanations of deprecations.""" - for input_arg in op_def.input_arg: - input_arg.description = "" - for output_arg in op_def.output_arg: - output_arg.description = "" - for attr in op_def.attr: - attr.description = "" - - op_def.summary = "" - op_def.description = "" - - -def register_op_list(op_list): - """Register all the ops in an op_def_pb2.OpList.""" - if not isinstance(op_list, op_def_pb2.OpList): - raise TypeError("%s is %s, not an op_def_pb2.OpList" % - (op_list, type(op_list))) - for op_def in op_list.op: - if op_def.name in _registered_ops: - if _registered_ops[op_def.name] != op_def: - raise ValueError( - "Registered op_def for %s (%s) not equal to op_def to register (%s)" - % (op_def.name, _registered_ops[op_def.name], op_def)) - else: - _registered_ops[op_def.name] = op_def +# The cache amortizes ProtoBuf serialization/deserialization overhead +# on the language boundary. If an OpDef has been looked up, its Python +# representation is cached. +_cache = {} +_cache_lock = threading.Lock() def get(name): """Returns an OpDef for a given `name` or None if the lookup fails.""" - with _sync_lock: - return _registered_ops.get(name) + try: + return _cache[name] + except KeyError: + pass + + with _cache_lock: + try: + # Return if another thread has already populated the cache. + return _cache[name] + except KeyError: + pass + + serialized_op_def = _op_def_registry.get(name) + if serialized_op_def is None: + return None + + op_def = op_def_pb2.OpDef() + op_def.ParseFromString(serialized_op_def) + _cache[name] = op_def + return op_def +# TODO(b/141354889): Remove once there are no callers. def sync(): - """Synchronize the contents of the Python registry with C++.""" - with _sync_lock: - p_buffer = c_api.TF_GetAllOpList() - cpp_op_list = op_def_pb2.OpList() - cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer)) - for op_def in cpp_op_list.op: - # If an OpList is registered from a gen_*_ops.py, it does not any - # descriptions. Strip them here as well to satisfy validation in - # register_op_list. - _remove_non_deprecated_descriptions(op_def) - _registered_ops[op_def.name] = op_def - - -def get_registered_ops(): - """Returns a dictionary mapping names to OpDefs.""" - return _registered_ops + """No-op. Used to synchronize the contents of the Python registry with C++.""" diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index cccadadb6d4..7db4a0c5133 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -1073,7 +1073,6 @@ from tensorflow.python.util.tf_export import tf_export result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes): op_list = _op_def_pb2.OpList() op_list.ParseFromString(op_list_proto_bytes) - _op_def_registry.register_op_list(op_list) op_def_lib = _op_def_library.OpDefLibrary() op_def_lib.add_op_list(op_list) return op_def_lib diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 0c173b2f96e..72ff2b2d052 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -54,3 +54,9 @@ tensorflow::EventsWriter::Close [py_func_lib] # py_func tensorflow::InitializePyTrampoline + +[framework_internal_impl] # op_def_registry +tensorflow::OpRegistry::Global +tensorflow::OpRegistry::LookUpOpDef +tensorflow::RemoveNonDeprecationDescriptionsFromOpDef +