diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index b32f574628c..fe85f8ee0ed 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1496,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieves the type of the device at the given index. // @@ -1506,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieve the amount of memory associated with a given device. // // If index is out of bounds, an error code will be set in the status object, // and -1 will be returned. TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( - const TF_DeviceList* list, int index, TF_Status*); + const TF_DeviceList* list, int index, TF_Status* status); // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index b7763079242..fae45ead5ca 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -474,6 +474,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index aa0acd243c9..c502a3a42b1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -283,6 +283,17 @@ cc_library( ], ) +cc_library( + name = "py_exception_registry", + srcs = ["lib/core/py_exception_registry.cc"], + hdrs = ["lib/core/py_exception_registry.h"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/core:lib", + "//util/python:python_headers", + ], +) + cc_library( name = "kernel_registry", srcs = ["util/kernel_registry.cc"], @@ -3313,6 +3324,7 @@ tf_py_wrap_cc( "grappler/model_analyzer.i", "grappler/tf_optimizer.i", "lib/core/bfloat16.i", + "lib/core/py_exception_registry.i", "lib/core/py_func.i", "lib/core/strings.i", "lib/io/file_io.i", @@ -3344,6 +3356,7 @@ tf_py_wrap_cc( ":kernel_registry", ":numpy_lib", ":safe_ptr", + ":py_exception_registry", ":py_func_lib", ":py_record_reader_lib", ":py_record_writer_lib", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 5c9ed9ccafd..4c84d78f2e1 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -27,7 +27,6 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as tf_session -from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -629,14 +628,12 @@ class BaseSession(SessionInterface): self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - # pylint: disable=protected-access - self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, - status) - # pylint: enable=protected-access - else: - self._session = tf_session.TF_NewDeprecatedSession(opts, status) + if self._created_with_new_api: + # pylint: disable=protected-access + self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) + # pylint: enable=protected-access + else: + self._session = tf_session.TF_NewDeprecatedSession(opts) finally: tf_session.TF_DeleteSessionOptions(opts) @@ -663,22 +660,20 @@ class BaseSession(SessionInterface): Returns: A list of devices in the session. """ - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - raw_device_list = tf_session.TF_SessionListDevices( - self._session, status) - else: - raw_device_list = tf_session.TF_DeprecatedSessionListDevices( - self._session, status) - device_list = [] - size = tf_session.TF_DeviceListCount(raw_device_list) - for i in range(size): - name = tf_session.TF_DeviceListName(raw_device_list, i, status) - device_type = tf_session.TF_DeviceListType(raw_device_list, i, status) - memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status) - device_list.append(_DeviceAttributes(name, device_type, memory)) - tf_session.TF_DeleteDeviceList(raw_device_list) - return device_list + if self._created_with_new_api: + raw_device_list = tf_session.TF_SessionListDevices(self._session) + else: + raw_device_list = tf_session.TF_DeprecatedSessionListDevices( + self._session) + device_list = [] + size = tf_session.TF_DeviceListCount(raw_device_list) + for i in range(size): + name = tf_session.TF_DeviceListName(raw_device_list, i) + device_type = tf_session.TF_DeviceListType(raw_device_list, i) + memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i) + device_list.append(_DeviceAttributes(name, device_type, memory)) + tf_session.TF_DeleteDeviceList(raw_device_list) + return device_list def close(self): """Closes this session. @@ -692,15 +687,13 @@ class BaseSession(SessionInterface): if self._created_with_new_api: if self._session and not self._closed: self._closed = True - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_CloseSession(self._session, status) + tf_session.TF_CloseSession(self._session) else: with self._extend_lock: if self._opened and not self._closed: self._closed = True - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_CloseDeprecatedSession(self._session, status) + tf_session.TF_CloseDeprecatedSession(self._session) def __del__(self): # cleanly ignore all exceptions @@ -710,11 +703,10 @@ class BaseSession(SessionInterface): pass if self._session is not None: try: - status = c_api_util.ScopedTFStatus() if self._created_with_new_api: - tf_session.TF_DeleteSession(self._session, status) + tf_session.TF_DeleteSession(self._session) else: - tf_session.TF_DeleteDeprecatedSession(self._session, status) + tf_session.TF_DeleteDeprecatedSession(self._session) except AttributeError: # At shutdown, `c_api_util` or `tf_session` may have been garbage # collected, causing the above method calls to fail. In this case, @@ -1031,11 +1023,11 @@ class BaseSession(SessionInterface): # Set up a graph with feeds and fetches for partial run. def _setup_fn(session, feed_list, fetch_list, target_list): self._extend_graph() - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionPRunSetup_wrapper( - session, feed_list, fetch_list, target_list, status) - else: + if self._created_with_new_api: + return tf_session.TF_SessionPRunSetup_wrapper( + session, feed_list, fetch_list, target_list) + else: + with errors.raise_exception_on_not_ok_status() as status: return tf_session.TF_PRunSetup(session, feed_list, fetch_list, target_list, status) @@ -1345,8 +1337,7 @@ class BaseSession(SessionInterface): def _extend_graph(self): if self._created_with_new_api: with self._graph._lock: # pylint: disable=protected-access - with errors.raise_exception_on_not_ok_status() as status: - tf_session.ExtendSession(self._session, status) + tf_session.ExtendSession(self._session) else: # Ensure any changes to the graph are reflected in the runtime. with self._extend_lock: @@ -1412,22 +1403,22 @@ class BaseSession(SessionInterface): def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata): - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionRun_wrapper( - self._session, options, feed_dict, fetch_list, target_list, - run_metadata, status) - else: + if self._created_with_new_api: + return tf_session.TF_SessionRun_wrapper( + self._session, options, feed_dict, fetch_list, target_list, + run_metadata) + else: + with errors.raise_exception_on_not_ok_status() as status: return tf_session.TF_Run( self._session, options, feed_dict, fetch_list, target_list, status, run_metadata) def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionPRun_wrapper( - self._session, handle, feed_dict, fetch_list, status) - else: + if self._created_with_new_api: + return tf_session.TF_SessionPRun_wrapper( + self._session, handle, feed_dict, fetch_list) + else: + with errors.raise_exception_on_not_ok_status() as status: return tf_session.TF_PRun( self._session, handle, feed_dict, fetch_list, status) diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py index 5a7413c12e9..38a3acb2dc3 100644 --- a/tensorflow/python/client/session_list_devices_test.py +++ b/tensorflow/python/client/session_list_devices_test.py @@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python.client import session -from tensorflow.python.framework import c_api_util from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -42,21 +41,13 @@ class SessionListDevicesTestMethods(object): def testInvalidDeviceNumber(self): opts = tf_session.TF_NewSessionOptions() - with errors.raise_exception_on_not_ok_status() as status: - c_session = tf_session.TF_NewSession( - ops.get_default_graph()._c_graph, opts, status) - raw_device_list = tf_session.TF_SessionListDevices( - c_session, status) + c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts) + raw_device_list = tf_session.TF_SessionListDevices(c_session) size = tf_session.TF_DeviceListCount(raw_device_list) - # Test that invalid device numbers return -1 rather than a Swig-wrapped - # pointer. - status_no_exception = c_api_util.ScopedTFStatus() - memory = tf_session.TF_DeviceListMemoryBytes( - raw_device_list, size, status_no_exception) - self.assertEqual(memory, -1) + with self.assertRaises(errors.InvalidArgumentError): + tf_session.TF_DeviceListMemoryBytes(raw_device_list, size) tf_session.TF_DeleteDeviceList(raw_device_list) - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_CloseSession(c_session, status) + tf_session.TF_CloseSession(c_session) def testListDevicesGrpcSession(self): server = server_lib.Server.create_local_server() diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 77ce9195eef..5dcd0c192e4 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -18,11 +18,12 @@ limitations under the License. %{ #include "tensorflow/c/python_api.h" -#include "tensorflow/python/client/tf_session_helper.h" #include "tensorflow/core/framework/session_state.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/python/client/tf_session_helper.h" +#include "tensorflow/python/lib/core/py_exception_registry.h" // Helper function to convert a Python list of Tensors to a C++ vector of // TF_Outputs. @@ -352,6 +353,27 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ reinterpret_cast($1.data), $1.length); } +// Typemaps to automatically raise a Python exception from bad output TF_Status. +// TODO(b/77295559): expand this to all TF_Status* output params and deprecate +// raise_exception_on_not_ok_status (currently it only affects the C API). +%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) { + status = TF_NewStatus(); + $1 = status; +} + +%typemap(argout) TF_Status* status { + TF_Code code = TF_GetCode($1); + if (code != TF_OK) { + PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code); + // Arguments to OpError. + PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1)); + TF_DeleteStatus($1); + SWIG_SetErrorObj(exc, exc_args); + SWIG_fail; + } + TF_DeleteStatus($1); +} + // Converts input Python list of wrapped TF_Outputs into a single array %typemap(in) (const TF_Output* inputs, int num_inputs) (std::vector inputs) { @@ -499,9 +521,8 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ _TF_SetTarget(opts, target) if config is not None: from tensorflow.python.framework import errors - with errors.raise_exception_on_not_ok_status() as status: - config_str = config.SerializeToString() - _TF_SetConfig(opts, config_str, status) + config_str = config.SerializeToString() + _TF_SetConfig(opts, config_str) return opts %} @@ -758,3 +779,7 @@ def TF_Reset(target, containers=None, config=None): %include "tensorflow/python/client/tf_session_helper.h" %unignoreall + +// Clear "TF_Status* status" typemap so it doesn't affect other modules and +// unexpectedly remove the TF_Status* argument from wrappers. +%clear TF_Status* status; diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 603d03e3152..5416d41376f 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -136,8 +136,7 @@ string EqualAttrValueWrapper(const string& actual, const string& expected); // // If shape is unknown, sets unknown_shape to true. tensorflow::gtl::InlinedVector TF_GraphGetTensorShapeHelper( - TF_Graph* graph, TF_Output output, TF_Status* out_status, - bool* unknown_shape); + TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape); // Runs the graph associated with the session starting with the supplied inputs. // On success, `py_outputs` is populated with a numpy ndarray for each output @@ -149,7 +148,7 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options, const std::vector& input_ndarrays, const std::vector& outputs, const std::vector& targets, - TF_Buffer* run_metadata, TF_Status* out_status, + TF_Buffer* run_metadata, TF_Status* status, std::vector* py_outputs); // Set up the graph with the intended feeds (inputs) and fetches (output) for @@ -165,8 +164,7 @@ void TF_SessionPRunSetup_wrapper(TF_Session* session, const std::vector& inputs, const std::vector& outputs, const std::vector& targets, - const char** out_handle, - TF_Status* out_status); + const char** out_handle, TF_Status* status); // Continue to run the graph with additional feeds and fetches. The // execution state is uniquely identified by the handle. @@ -182,7 +180,7 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, const std::vector& inputs, const std::vector& input_ndarrays, const std::vector& outputs, - TF_Status* out_status, + TF_Status* status, std::vector* py_outputs); // Retrieves the inputs of this operation. @@ -204,7 +202,7 @@ TF_Function* TF_GraphToFunction_wrapper( const std::vector* opers, const std::vector& inputs, const std::vector& outputs, const NameVector& output_names, const TF_FunctionOptions* opts, - const char* description, TF_Status* out_status); + const char* description, TF_Status* status); // Set the shapes and types for the output's handle. // diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 8c1bb06bc31..6ad9e0d88fb 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -244,13 +244,9 @@ class Context(object): try: self._num_gpus = 0 for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): - with errors.raise_exception_on_not_ok_status() as status: - dev_name = pywrap_tensorflow.TF_DeviceListName( - device_list, i, status) + dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) self._context_devices.append(pydev.canonical_name(dev_name)) - with errors.raise_exception_on_not_ok_status() as status: - dev_type = pywrap_tensorflow.TF_DeviceListType( - device_list, i, status) + dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) if dev_type == "GPU": self._num_gpus += 1 diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 343012e5525..711eddcec1d 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -34,7 +34,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.framework import c_api_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -79,14 +78,10 @@ def capture_value(tensor_map, value, dtype, name): ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] shapes = [[d.size for d in s.dim] if not s.unknown_rank else None for s in shapes] - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - captured_value._op._graph._c_graph, # pylint: disable=protected-access - captured_value._as_tf_output(), # pylint: disable=protected-access - shapes, - ranks, - types, - status) + pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( + captured_value._op._graph._c_graph, # pylint: disable=protected-access + captured_value._as_tf_output(), # pylint: disable=protected-access + shapes, ranks, types) tensor_map[ops.tensor_id(value)] = (value, captured_value) else: @@ -275,23 +270,20 @@ class _EagerDefinedFunction(object): inputs: the tensors in the graph to be used as inputs to the function outputs: the tensors in the graph which will be outputs to the function """ - with errors.raise_exception_on_not_ok_status() as status: - fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( - graph._c_graph, # pylint: disable=protected-access - compat.as_str(name), - False, - [o._c_op for o in operations], # pylint: disable=protected-access - [t._as_tf_output() for t in inputs], # pylint: disable=protected-access - [t._as_tf_output() for t in outputs], # pylint: disable=protected-access - [], - None, - compat.as_str(""), - status) + fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( + graph._c_graph, # pylint: disable=protected-access + compat.as_str(name), + False, + [o._c_op for o in operations], # pylint: disable=protected-access + [t._as_tf_output() for t in inputs], # pylint: disable=protected-access + [t._as_tf_output() for t in outputs], # pylint: disable=protected-access + [], + None, + compat.as_str("")) # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. with c_api_util.tf_buffer() as buffer_: - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status) + pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_) proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) function_def = function_pb2.FunctionDef() function_def.ParseFromString(compat.as_bytes(proto_data)) diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py index 2a40316d51c..84106c32c67 100644 --- a/tensorflow/python/framework/errors_impl.py +++ b/tensorflow/python/framework/errors_impl.py @@ -473,6 +473,8 @@ _CODE_TO_EXCEPTION_CLASS = { DATA_LOSS: DataLossError, } +c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS) + _EXCEPTION_CLASS_TO_CODE = dict(( (class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items())) @@ -499,6 +501,7 @@ def _make_specific_exception(node_def, op, message, error_code): # Named like a function for backwards compatibility with the # @tf_contextlib.contextmanager version, which was switched to a class to avoid # some object creation overhead. +# TODO(b/77295559): expand use of TF_Status* SWIG typemap and deprecate this. @tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name class raise_exception_on_not_ok_status(object): """Context manager to check for C API status.""" diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 82dd2a3356f..c5caf9ebc06 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -30,7 +30,6 @@ from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.eager import context from tensorflow.python.framework import c_api_util from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -275,8 +274,7 @@ class _DefinedFunction(object): self._create_definition_if_needed() if self._c_func: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_FunctionToFunctionDef(self._c_func, buf, status) + c_api.TF_FunctionToFunctionDef(self._c_func, buf) fdef = function_pb2.FunctionDef() proto_data = c_api.TF_GetBuffer(buf) fdef.ParseFromString(compat.as_bytes(proto_data)) @@ -399,18 +397,16 @@ class _DefinedFunction(object): if self._out_names else []) description = self._func.__doc__ or None # pylint: disable=protected-access - with errors.raise_exception_on_not_ok_status() as status: - self._c_func = c_api.TF_GraphToFunction_wrapper( - temp_graph._c_graph, - base_func_name, - self._func_name is None, # append_hash_to_fn_name - None, # opers - [t._as_tf_output() for t in inputs], - [t._as_tf_output() for t in outputs], - output_names, - None, # opts - description, - status) + self._c_func = c_api.TF_GraphToFunction_wrapper( + temp_graph._c_graph, + base_func_name, + self._func_name is None, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in inputs], + [t._as_tf_output() for t in outputs], + output_names, + None, # opts + description) # pylint: enable=protected-access self._set_c_attrs(kwargs_attr) @@ -433,9 +429,8 @@ class _DefinedFunction(object): serialized = attr_value.SerializeToString() # TODO(skyewm): this creates and deletes a new TF_Status for every attr. # It might be worth creating a convenient way to re-use the same status. - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name), - serialized, status) + c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name), + serialized) def _create_hash_str(self, input_arg, output_arg, node_def): """Creates an 8-character string unique to this input. @@ -830,8 +825,7 @@ def _from_definition(fdef, grad_func=None): # pylint: disable=protected-access if ops._USE_C_API: serialized = fdef.SerializeToString() - with errors.raise_exception_on_not_ok_status() as status: - result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status) + result._c_func = c_api.TF_FunctionImportFunctionDef(serialized) result._extra_inputs = [] else: result._definition = fdef diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 4ea34d7bb28..23f529b9885 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -485,9 +485,8 @@ def import_graph_def(graph_def, with graph._lock: # pylint: disable=protected-access with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: try: - with errors.raise_exception_on_not_ok_status() as status: - results = c_api.TF_GraphImportGraphDefWithResults( - graph._c_graph, serialized, options, status) # pylint: disable=protected-access + results = c_api.TF_GraphImportGraphDefWithResults( + graph._c_graph, serialized, options) # pylint: disable=protected-access except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. raise ValueError(str(e)) diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index 1f2aa264c11..535c6017f5f 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -26,7 +26,6 @@ import threading # pylint: disable=unused-import from tensorflow.core.framework import op_def_pb2 from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import from tensorflow.python import pywrap_tensorflow as py_tf -from tensorflow.python.framework import errors_impl from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -54,8 +53,7 @@ def load_op_library(library_filename): Raises: RuntimeError: when unable to load the library or get the python wrappers. """ - with errors_impl.raise_exception_on_not_ok_status() as status: - lib_handle = py_tf.TF_LoadLibrary(library_filename, status) + lib_handle = py_tf.TF_LoadLibrary(library_filename) op_list_str = py_tf.TF_GetOpList(lib_handle) op_list = op_def_pb2.OpList() @@ -99,5 +97,4 @@ def load_file_system_library(library_filename): Raises: RuntimeError: when unable to load the library. """ - with errors_impl.raise_exception_on_not_ok_status() as status: - lib_handle = py_tf.TF_LoadLibrary(library_filename, status) + py_tf.TF_LoadLibrary(library_filename) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 6930737a0c3..7ca0b836dd4 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -373,15 +373,12 @@ class Tensor(_TensorLike): """ graph = self._op._graph._c_graph # pylint: disable=protected-access if graph and _USE_C_SHAPES: - with errors.raise_exception_on_not_ok_status() as status: - num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(), - status) + num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output()) if num_dims == -1: dim_list = None else: - with errors.raise_exception_on_not_ok_status() as status: - dim_list = c_api.TF_GraphGetTensorShape_wrapper( - graph, self._as_tf_output(), num_dims, status) + dim_list = c_api.TF_GraphGetTensorShape_wrapper( + graph, self._as_tf_output(), num_dims) dim_list = [None if i == -1 else i for i in dim_list] return tensor_shape.TensorShape(dim_list) return self._shape_val @@ -489,13 +486,11 @@ class Tensor(_TensorLike): else: dim_list.append(dim.value) try: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_GraphSetTensorShape_wrapper( - self._op._graph._c_graph, # pylint: disable=protected-access - self._as_tf_output(), - dim_list, - unknown_shape, - status) + c_api.TF_GraphSetTensorShape_wrapper( + self._op._graph._c_graph, # pylint: disable=protected-access + self._as_tf_output(), + dim_list, + unknown_shape) except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. raise ValueError(str(e)) @@ -1514,13 +1509,10 @@ def _create_c_op(graph, node_def, inputs, control_inputs): serialized = attr_value.SerializeToString() # TODO(skyewm): this creates and deletes a new TF_Status for every attr. # It might be worth creating a convenient way to re-use the same status. - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_SetAttrValueProto(op_desc, - compat.as_str(name), serialized, status) + c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized) try: - with errors.raise_exception_on_not_ok_status() as status: - c_op = c_api.TF_FinishOperation(op_desc, status) + c_op = c_api.TF_FinishOperation(op_desc) except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. raise ValueError(str(e)) @@ -1943,12 +1935,10 @@ class Operation(object): if self._c_op: # Reset cached inputs. self._inputs_val = None - with errors.raise_exception_on_not_ok_status() as status: - c_api.UpdateEdge( - self._graph._c_graph, # pylint: disable=protected-access - tensor._as_tf_output(), # pylint: disable=protected-access - self._tf_input(index), - status) + c_api.UpdateEdge( + self._graph._c_graph, # pylint: disable=protected-access + tensor._as_tf_output(), # pylint: disable=protected-access + self._tf_input(index)) else: self._inputs_val[index].consumers().remove(self) self._inputs_val[index] = tensor @@ -2169,8 +2159,7 @@ class Operation(object): # pylint: enable=line-too-long if self._c_op: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_OperationToNodeDef(self._c_op, buf, status) + c_api.TF_OperationToNodeDef(self._c_op, buf) data = c_api.TF_GetBuffer(buf) node_def = node_def_pb2.NodeDef() node_def.ParseFromString(compat.as_bytes(data)) @@ -2228,11 +2217,9 @@ class Operation(object): buf = c_api.TF_NewBufferFromString( compat.as_bytes(attr_value.SerializeToString())) try: - with errors.raise_exception_on_not_ok_status() as status: - # pylint: disable=protected-access - c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, - status) - # pylint: enable=protected-access + # pylint: disable=protected-access + c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) + # pylint: enable=protected-access finally: c_api.TF_DeleteBuffer(buf) else: @@ -2254,8 +2241,7 @@ class Operation(object): if self._c_op: try: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status) + c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) data = c_api.TF_GetBuffer(buf) except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. @@ -2469,11 +2455,10 @@ def _set_shapes_for_outputs_c_api(op): # The C API computes the shapes when the TF_Operation is created. Fetch the # output shapes from the C object. for output in op.outputs: - with errors.raise_exception_on_not_ok_status() as status: - # pylint: disable=protected-access - shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper( - op._graph._c_graph, output._as_tf_output(), status) - # pylint: enable=protected-access + # pylint: disable=protected-access + shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper( + op._graph._c_graph, output._as_tf_output()) + # pylint: enable=protected-access if unknown_shape: output.set_shape(tensor_shape.unknown_shape()) elif not shape_vector: @@ -2994,8 +2979,7 @@ class Graph(object): # pylint: enable=line-too-long if self._c_graph: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_GraphVersions(self._c_graph, buf, status) + c_api.TF_GraphVersions(self._c_graph, buf) data = c_api.TF_GetBuffer(buf) version_def = versions_pb2.VersionDef() version_def.ParseFromString(compat.as_bytes(data)) @@ -3098,8 +3082,7 @@ class Graph(object): if self._c_graph: with self._lock: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_GraphToGraphDef(self._c_graph, buf, status) + c_api.TF_GraphToGraphDef(self._c_graph, buf) data = c_api.TF_GetBuffer(buf) graph = graph_pb2.GraphDef() graph.ParseFromString(compat.as_bytes(data)) @@ -3208,14 +3191,10 @@ class Graph(object): # remove this when all functions are generated using the C API by default # as this will be unnecessary. if not function._c_func: - with errors.raise_exception_on_not_ok_status() as status: - serialized = function.definition.SerializeToString() - function._c_func = c_api.TF_FunctionImportFunctionDef( - serialized, status) - with errors.raise_exception_on_not_ok_status() as status: - gradient = function._grad_func._c_func if function._grad_func else None - c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient, - status) + serialized = function.definition.SerializeToString() + function._c_func = c_api.TF_FunctionImportFunctionDef(serialized) + gradient = function._grad_func._c_func if function._grad_func else None + c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient) else: # If there is already a function with the same name, raise an error # if bodies are different. Else, do nothing. The C API version above @@ -3732,11 +3711,9 @@ class Graph(object): """Returns the `OpDef` proto for `type`. `type` is a string.""" if self._c_graph: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - # pylint: disable=protected-access - c_api.TF_GraphGetOpDef(self._c_graph, - compat.as_bytes(type), buf, status) - # pylint: enable=protected-access + # pylint: disable=protected-access + c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) + # pylint: enable=protected-access data = c_api.TF_GetBuffer(buf) op_def = op_def_pb2.OpDef() op_def.ParseFromString(compat.as_bytes(data)) diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py index c7ff23e4ff8..48a834392b4 100644 --- a/tensorflow/python/framework/smart_cond.py +++ b/tensorflow/python/framework/smart_cond.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow as c_api -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops @@ -83,9 +82,8 @@ def smart_constant_value(pred): # wanted to limit the change hidden behind _USE_C_API). # pylint: disable=protected-access if pred_value is None and ops._USE_C_API: - with errors.raise_exception_on_not_ok_status() as status: - pred_value = c_api.TF_TryEvaluateConstant_wrapper( - pred.graph._c_graph, pred._as_tf_output(), status) + pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph, + pred._as_tf_output()) # pylint: enable=protected-access else: diff --git a/tensorflow/python/lib/core/py_exception_registry.cc b/tensorflow/python/lib/core/py_exception_registry.cc new file mode 100644 index 00000000000..6637de632b4 --- /dev/null +++ b/tensorflow/python/lib/core/py_exception_registry.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/python/lib/core/py_exception_registry.h" + +#include + +namespace tensorflow { + +PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr; + +void PyExceptionRegistry::Init(PyObject* code_to_exc_type_map) { + DCHECK(singleton_ == nullptr) << "PyExceptionRegistry::Init() already called"; + singleton_ = new PyExceptionRegistry; + + DCHECK(PyDict_Check(code_to_exc_type_map)); + PyObject* key; + PyObject* value; + Py_ssize_t pos = 0; + while (PyDict_Next(code_to_exc_type_map, &pos, &key, &value)) { + TF_Code code = static_cast(PyLong_AsLong(key)); + singleton_->exc_types_[code] = value; + // The exception classes should also have the lifetime of the process, but + // incref just in case. + Py_INCREF(value); + } +} + +PyObject* PyExceptionRegistry::Lookup(TF_Code code) { + DCHECK(singleton_ != nullptr) << "Must call PyExceptionRegistry::Init() " + "before PyExceptionRegistry::Lookup()"; + DCHECK_NE(code, TF_OK); + DCHECK(singleton_->exc_types_.find(code) != singleton_->exc_types_.end()) + << "Unknown error code passed to PyExceptionRegistry::Lookup: " << code; + return singleton_->exc_types_[code]; +} + +} // namespace tensorflow diff --git a/tensorflow/python/lib/core/py_exception_registry.h b/tensorflow/python/lib/core/py_exception_registry.h new file mode 100644 index 00000000000..2b0f23b548c --- /dev/null +++ b/tensorflow/python/lib/core/py_exception_registry.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/platform/logging.h" + +#ifndef PyObject_HEAD +struct _object; +typedef _object PyObject; +#endif + +namespace tensorflow { + +// Global registry mapping C API error codes to the corresponding custom Python +// exception type. This is used to expose the exception types to C extension +// code (i.e. so we can raise custom exceptions via SWIG). +// +// Init() must be called exactly once at the beginning of the process before +// Lookup() can be used. +// +// Example usage: +// TF_Status* status = TF_NewStatus(); +// TF_Foo(..., status); +// +// if (TF_GetCode(status) != TF_OK) { +// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status)); +// // Arguments to OpError base class. Set `node_def` and `op` to None. +// PyObject* args = +// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status)); +// PyErr_SetObject(exc_type, args); +// Py_DECREF(args); +// TF_DeleteStatus(status); +// return NULL; +// } +class PyExceptionRegistry { + public: + // Initializes the process-wide registry. Should be called exactly once near + // the beginning of the process. The arguments are the various Python + // exception types (e.g. `cancelled_exc` corresponds to + // errors.CancelledError). + static void Init(PyObject* code_to_exc_type_map); + + // Returns the Python exception type corresponding to `code`. Init() must be + // called before using this function. `code` should not be TF_OK. + static PyObject* Lookup(TF_Code code); + + private: + static PyExceptionRegistry* singleton_; + PyExceptionRegistry() = default; + + // Maps error codes to the corresponding Python exception type. + std::map exc_types_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ diff --git a/tensorflow/python/lib/core/py_exception_registry.i b/tensorflow/python/lib/core/py_exception_registry.i new file mode 100644 index 00000000000..e872b74985e --- /dev/null +++ b/tensorflow/python/lib/core/py_exception_registry.i @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/python/lib/core/py_exception_registry.h" +%} + +%ignoreall + +%unignore tensorflow::PyExceptionRegistry; +%unignore tensorflow::PyExceptionRegistry::Init; + +%include "tensorflow/python/lib/core/py_exception_registry.h" +%unignoreall diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 82b908ac0e9..26e8acd8977 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -25,6 +25,7 @@ limitations under the License. %include "tensorflow/python/util/tfprof.i" %include "tensorflow/python/lib/core/py_func.i" +%include "tensorflow/python/lib/core/py_exception_registry.i" %include "tensorflow/python/lib/io/py_record_reader.i" %include "tensorflow/python/lib/io/py_record_writer.i" @@ -54,4 +55,3 @@ limitations under the License. %include "tensorflow/python/grappler/tf_optimizer.i" %include "tensorflow/python/grappler/cost_analyzer.i" %include "tensorflow/python/grappler/model_analyzer.i" -