diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 59b257e6531..3195c49bfef 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -99,6 +99,7 @@ py_library( ], deps = [ ":_pywrap_events_writer", + ":_pywrap_py_exception_registry", ":_pywrap_stat_summarizer", ":_pywrap_tfprof", ":_pywrap_util_port", @@ -509,6 +510,31 @@ tf_python_pybind_extension( ], ) +filegroup( + name = "py_exception_registry_hdr", + srcs = [ + "lib/core/py_exception_registry.h", + ], + visibility = ["//visibility:public"], +) + +tf_python_pybind_extension( + name = "_pywrap_py_exception_registry", + srcs = ["lib/core/py_exception_registry_wrapper.cc"], + hdrs = [ + ":py_exception_registry_hdr", + "//tensorflow/c:headers", + "//tensorflow/c/eager:headers", + ], + module_name = "_pywrap_py_exception_registry", + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + cc_library( name = "cpp_python_util", srcs = ["util/util.cc"], @@ -838,6 +864,7 @@ py_library( deps = [ ":_pywrap_debug_events_writer", ":_pywrap_events_writer", + ":_pywrap_py_exception_registry", ":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed. ":_pywrap_stat_summarizer", ":_pywrap_tfprof", @@ -931,6 +958,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":_pywrap_py_exception_registry", ":c_api_util", ":error_interpolation", ":util", @@ -5086,7 +5114,6 @@ tf_py_wrap_cc( "grappler/model_analyzer.i", "grappler/tf_optimizer.i", "lib/core/bfloat16.i", - "lib/core/py_exception_registry.i", "lib/core/strings.i", "lib/io/file_io.i", "lib/io/py_record_reader.i", @@ -5203,6 +5230,7 @@ genrule( "//tensorflow/core:framework_internal_impl", # op_def_registry "//tensorflow/core:lib_internal_impl", # device_lib "//tensorflow/core:core_cpu_impl", # device_lib + ":py_exception_registry", # py_exception_registry ], outs = ["pybind_symbol_target_libs_file.txt"], cmd = select({ diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index f7ac57f37bb..fbd554e99bf 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -52,6 +52,7 @@ from tensorflow.python import _pywrap_tfprof from tensorflow.python import _pywrap_events_writer from tensorflow.python import _pywrap_util_port from tensorflow.python import _pywrap_stat_summarizer +from tensorflow.python import _pywrap_py_exception_registry # Protocol buffers from tensorflow.core.framework.graph_pb2 import * diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py index 233803bb242..209d0e2eb5b 100644 --- a/tensorflow/python/framework/errors_impl.py +++ b/tensorflow/python/framework/errors_impl.py @@ -22,6 +22,7 @@ import traceback import warnings from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.python import _pywrap_py_exception_registry from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.framework import c_api_util from tensorflow.python.framework import error_interpolation @@ -503,7 +504,7 @@ _CODE_TO_EXCEPTION_CLASS = { DATA_LOSS: DataLossError, } -c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS) +_pywrap_py_exception_registry.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS) _EXCEPTION_CLASS_TO_CODE = { class_: code for code, class_ in _CODE_TO_EXCEPTION_CLASS.items()} diff --git a/tensorflow/python/lib/core/py_exception_registry_wrapper.cc b/tensorflow/python/lib/core/py_exception_registry_wrapper.cc new file mode 100644 index 00000000000..2ae56c3f671 --- /dev/null +++ b/tensorflow/python/lib/core/py_exception_registry_wrapper.cc @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include + +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "tensorflow/python/lib/core/py_exception_registry.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_pywrap_py_exception_registry, m) { + m.def("PyExceptionRegistry_Init", [](py::object& code_to_exc_type_map) { + tensorflow::PyExceptionRegistry::Init(code_to_exc_type_map.ptr()); + }); + m.def("PyExceptionRegistry_Lookup", + [](TF_Code code) { tensorflow::PyExceptionRegistry::Lookup(code); }); +}; diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i index 42a0ff57568..25fffcfb2d2 100644 --- a/tensorflow/python/platform/base.i +++ b/tensorflow/python/platform/base.i @@ -20,6 +20,8 @@ limitations under the License. #include #include "tensorflow/c/tf_status.h" #include "tensorflow/core/platform/types.h" + #include "tensorflow/python/lib/core/py_exception_registry.h" + using tensorflow::uint64; using tensorflow::string; diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 525e95411f7..070323f36cd 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -21,7 +21,6 @@ limitations under the License. %include "tensorflow/python/pywrap_tfe.i" -%include "tensorflow/python/lib/core/py_exception_registry.i" %include "tensorflow/python/lib/io/py_record_reader.i" %include "tensorflow/python/lib/io/py_record_writer.i" diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index da67eec55ff..216e88c6eab 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -74,3 +74,7 @@ tensorflow::SessionOptions::SessionOptions tensorflow::ConfigProto::ConfigProto tensorflow::ConfigProto::ParseFromString tensorflow::DeviceAttributes::SerializeToString + +[py_exception_registry] # py_exception_registry +tensorflow::PyExceptionRegistry::Init +tensorflow::PyExceptionRegistry::Lookup