Raise exception in SWIG on bad TF_Status from C API.

This change provides an alternative mechanism to
tf.raise_exception_on_not_ok_status(), which is inefficient and
error-prone (people often use the status multiple times in the with
block, but it's only checked when the context manager exits). Instead,
it uses SWIG to automatically raise an exception when a C API method
fails. Note that this removes the status argument from affected
methods.

For now, I've only applied this typemap to C API methods. It would be
good to expand this to all uses of raise_exception_on_not_ok_status.

PiperOrigin-RevId: 191121016
This commit is contained in:
Skye Wanderman-Milne 2018-03-30 14:56:08 -07:00 committed by TensorFlower Gardener
parent 15c10899c9
commit 97731cb122
19 changed files with 324 additions and 195 deletions

View File

@ -1496,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list);
// If index is out of bounds, an error code will be set in the status object, // If index is out of bounds, an error code will be set in the status object,
// and a null pointer will be returned. // and a null pointer will be returned.
TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
int index, TF_Status*); int index,
TF_Status* status);
// Retrieves the type of the device at the given index. // Retrieves the type of the device at the given index.
// //
@ -1506,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
// If index is out of bounds, an error code will be set in the status object, // If index is out of bounds, an error code will be set in the status object,
// and a null pointer will be returned. // and a null pointer will be returned.
TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
int index, TF_Status*); int index,
TF_Status* status);
// Retrieve the amount of memory associated with a given device. // Retrieve the amount of memory associated with a given device.
// //
// If index is out of bounds, an error code will be set in the status object, // If index is out of bounds, an error code will be set in the status object,
// and -1 will be returned. // and -1 will be returned.
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
const TF_DeviceList* list, int index, TF_Status*); const TF_DeviceList* list, int index, TF_Status* status);
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels // Load plugins containing custom ops and kernels

View File

@ -474,6 +474,8 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"

View File

@ -283,6 +283,17 @@ cc_library(
], ],
) )
cc_library(
name = "py_exception_registry",
srcs = ["lib/core/py_exception_registry.cc"],
hdrs = ["lib/core/py_exception_registry.h"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
"//util/python:python_headers",
],
)
cc_library( cc_library(
name = "kernel_registry", name = "kernel_registry",
srcs = ["util/kernel_registry.cc"], srcs = ["util/kernel_registry.cc"],
@ -3313,6 +3324,7 @@ tf_py_wrap_cc(
"grappler/model_analyzer.i", "grappler/model_analyzer.i",
"grappler/tf_optimizer.i", "grappler/tf_optimizer.i",
"lib/core/bfloat16.i", "lib/core/bfloat16.i",
"lib/core/py_exception_registry.i",
"lib/core/py_func.i", "lib/core/py_func.i",
"lib/core/strings.i", "lib/core/strings.i",
"lib/io/file_io.i", "lib/io/file_io.i",
@ -3344,6 +3356,7 @@ tf_py_wrap_cc(
":kernel_registry", ":kernel_registry",
":numpy_lib", ":numpy_lib",
":safe_ptr", ":safe_ptr",
":py_exception_registry",
":py_func_lib", ":py_func_lib",
":py_record_reader_lib", ":py_record_reader_lib",
":py_record_writer_lib", ":py_record_writer_lib",

View File

@ -27,7 +27,6 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device from tensorflow.python.framework import device
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -629,14 +628,12 @@ class BaseSession(SessionInterface):
self._session = None self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try: try:
with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api:
if self._created_with_new_api: # pylint: disable=protected-access
# pylint: disable=protected-access self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, # pylint: enable=protected-access
status) else:
# pylint: enable=protected-access self._session = tf_session.TF_NewDeprecatedSession(opts)
else:
self._session = tf_session.TF_NewDeprecatedSession(opts, status)
finally: finally:
tf_session.TF_DeleteSessionOptions(opts) tf_session.TF_DeleteSessionOptions(opts)
@ -663,22 +660,20 @@ class BaseSession(SessionInterface):
Returns: Returns:
A list of devices in the session. A list of devices in the session.
""" """
with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api:
if self._created_with_new_api: raw_device_list = tf_session.TF_SessionListDevices(self._session)
raw_device_list = tf_session.TF_SessionListDevices( else:
self._session, status) raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
else: self._session)
raw_device_list = tf_session.TF_DeprecatedSessionListDevices( device_list = []
self._session, status) size = tf_session.TF_DeviceListCount(raw_device_list)
device_list = [] for i in range(size):
size = tf_session.TF_DeviceListCount(raw_device_list) name = tf_session.TF_DeviceListName(raw_device_list, i)
for i in range(size): device_type = tf_session.TF_DeviceListType(raw_device_list, i)
name = tf_session.TF_DeviceListName(raw_device_list, i, status) memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
device_type = tf_session.TF_DeviceListType(raw_device_list, i, status) device_list.append(_DeviceAttributes(name, device_type, memory))
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status) tf_session.TF_DeleteDeviceList(raw_device_list)
device_list.append(_DeviceAttributes(name, device_type, memory)) return device_list
tf_session.TF_DeleteDeviceList(raw_device_list)
return device_list
def close(self): def close(self):
"""Closes this session. """Closes this session.
@ -692,15 +687,13 @@ class BaseSession(SessionInterface):
if self._created_with_new_api: if self._created_with_new_api:
if self._session and not self._closed: if self._session and not self._closed:
self._closed = True self._closed = True
with errors.raise_exception_on_not_ok_status() as status: tf_session.TF_CloseSession(self._session)
tf_session.TF_CloseSession(self._session, status)
else: else:
with self._extend_lock: with self._extend_lock:
if self._opened and not self._closed: if self._opened and not self._closed:
self._closed = True self._closed = True
with errors.raise_exception_on_not_ok_status() as status: tf_session.TF_CloseDeprecatedSession(self._session)
tf_session.TF_CloseDeprecatedSession(self._session, status)
def __del__(self): def __del__(self):
# cleanly ignore all exceptions # cleanly ignore all exceptions
@ -710,11 +703,10 @@ class BaseSession(SessionInterface):
pass pass
if self._session is not None: if self._session is not None:
try: try:
status = c_api_util.ScopedTFStatus()
if self._created_with_new_api: if self._created_with_new_api:
tf_session.TF_DeleteSession(self._session, status) tf_session.TF_DeleteSession(self._session)
else: else:
tf_session.TF_DeleteDeprecatedSession(self._session, status) tf_session.TF_DeleteDeprecatedSession(self._session)
except AttributeError: except AttributeError:
# At shutdown, `c_api_util` or `tf_session` may have been garbage # At shutdown, `c_api_util` or `tf_session` may have been garbage
# collected, causing the above method calls to fail. In this case, # collected, causing the above method calls to fail. In this case,
@ -1031,11 +1023,11 @@ class BaseSession(SessionInterface):
# Set up a graph with feeds and fetches for partial run. # Set up a graph with feeds and fetches for partial run.
def _setup_fn(session, feed_list, fetch_list, target_list): def _setup_fn(session, feed_list, fetch_list, target_list):
self._extend_graph() self._extend_graph()
with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api:
if self._created_with_new_api: return tf_session.TF_SessionPRunSetup_wrapper(
return tf_session.TF_SessionPRunSetup_wrapper( session, feed_list, fetch_list, target_list)
session, feed_list, fetch_list, target_list, status) else:
else: with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRunSetup(session, feed_list, fetch_list, return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
target_list, status) target_list, status)
@ -1345,8 +1337,7 @@ class BaseSession(SessionInterface):
def _extend_graph(self): def _extend_graph(self):
if self._created_with_new_api: if self._created_with_new_api:
with self._graph._lock: # pylint: disable=protected-access with self._graph._lock: # pylint: disable=protected-access
with errors.raise_exception_on_not_ok_status() as status: tf_session.ExtendSession(self._session)
tf_session.ExtendSession(self._session, status)
else: else:
# Ensure any changes to the graph are reflected in the runtime. # Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock: with self._extend_lock:
@ -1412,22 +1403,22 @@ class BaseSession(SessionInterface):
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
run_metadata): run_metadata):
with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api:
if self._created_with_new_api: return tf_session.TF_SessionRun_wrapper(
return tf_session.TF_SessionRun_wrapper( self._session, options, feed_dict, fetch_list, target_list,
self._session, options, feed_dict, fetch_list, target_list, run_metadata)
run_metadata, status) else:
else: with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_Run( return tf_session.TF_Run(
self._session, options, feed_dict, fetch_list, target_list, self._session, options, feed_dict, fetch_list, target_list,
status, run_metadata) status, run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api:
if self._created_with_new_api: return tf_session.TF_SessionPRun_wrapper(
return tf_session.TF_SessionPRun_wrapper( self._session, handle, feed_dict, fetch_list)
self._session, handle, feed_dict, fetch_list, status) else:
else: with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRun( return tf_session.TF_PRun(
self._session, handle, feed_dict, fetch_list, status) self._session, handle, feed_dict, fetch_list, status)

View File

@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -42,21 +41,13 @@ class SessionListDevicesTestMethods(object):
def testInvalidDeviceNumber(self): def testInvalidDeviceNumber(self):
opts = tf_session.TF_NewSessionOptions() opts = tf_session.TF_NewSessionOptions()
with errors.raise_exception_on_not_ok_status() as status: c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts)
c_session = tf_session.TF_NewSession( raw_device_list = tf_session.TF_SessionListDevices(c_session)
ops.get_default_graph()._c_graph, opts, status)
raw_device_list = tf_session.TF_SessionListDevices(
c_session, status)
size = tf_session.TF_DeviceListCount(raw_device_list) size = tf_session.TF_DeviceListCount(raw_device_list)
# Test that invalid device numbers return -1 rather than a Swig-wrapped with self.assertRaises(errors.InvalidArgumentError):
# pointer. tf_session.TF_DeviceListMemoryBytes(raw_device_list, size)
status_no_exception = c_api_util.ScopedTFStatus()
memory = tf_session.TF_DeviceListMemoryBytes(
raw_device_list, size, status_no_exception)
self.assertEqual(memory, -1)
tf_session.TF_DeleteDeviceList(raw_device_list) tf_session.TF_DeleteDeviceList(raw_device_list)
with errors.raise_exception_on_not_ok_status() as status: tf_session.TF_CloseSession(c_session)
tf_session.TF_CloseSession(c_session, status)
def testListDevicesGrpcSession(self): def testListDevicesGrpcSession(self):
server = server_lib.Server.create_local_server() server = server_lib.Server.create_local_server()

View File

@ -18,11 +18,12 @@ limitations under the License.
%{ %{
#include "tensorflow/c/python_api.h" #include "tensorflow/c/python_api.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/framework/session_state.h" #include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
// Helper function to convert a Python list of Tensors to a C++ vector of // Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs. // TF_Outputs.
@ -352,6 +353,27 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
reinterpret_cast<const char*>($1.data), $1.length); reinterpret_cast<const char*>($1.data), $1.length);
} }
// Typemaps to automatically raise a Python exception from bad output TF_Status.
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
// raise_exception_on_not_ok_status (currently it only affects the C API).
%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) {
status = TF_NewStatus();
$1 = status;
}
%typemap(argout) TF_Status* status {
TF_Code code = TF_GetCode($1);
if (code != TF_OK) {
PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code);
// Arguments to OpError.
PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1));
TF_DeleteStatus($1);
SWIG_SetErrorObj(exc, exc_args);
SWIG_fail;
}
TF_DeleteStatus($1);
}
// Converts input Python list of wrapped TF_Outputs into a single array // Converts input Python list of wrapped TF_Outputs into a single array
%typemap(in) (const TF_Output* inputs, int num_inputs) %typemap(in) (const TF_Output* inputs, int num_inputs)
(std::vector<TF_Output> inputs) { (std::vector<TF_Output> inputs) {
@ -499,9 +521,8 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
_TF_SetTarget(opts, target) _TF_SetTarget(opts, target)
if config is not None: if config is not None:
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status: config_str = config.SerializeToString()
config_str = config.SerializeToString() _TF_SetConfig(opts, config_str)
_TF_SetConfig(opts, config_str, status)
return opts return opts
%} %}
@ -758,3 +779,7 @@ def TF_Reset(target, containers=None, config=None):
%include "tensorflow/python/client/tf_session_helper.h" %include "tensorflow/python/client/tf_session_helper.h"
%unignoreall %unignoreall
// Clear "TF_Status* status" typemap so it doesn't affect other modules and
// unexpectedly remove the TF_Status* argument from wrappers.
%clear TF_Status* status;

View File

@ -136,8 +136,7 @@ string EqualAttrValueWrapper(const string& actual, const string& expected);
// //
// If shape is unknown, sets unknown_shape to true. // If shape is unknown, sets unknown_shape to true.
tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper( tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
TF_Graph* graph, TF_Output output, TF_Status* out_status, TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape);
bool* unknown_shape);
// Runs the graph associated with the session starting with the supplied inputs. // Runs the graph associated with the session starting with the supplied inputs.
// On success, `py_outputs` is populated with a numpy ndarray for each output // On success, `py_outputs` is populated with a numpy ndarray for each output
@ -149,7 +148,7 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
const std::vector<PyObject*>& input_ndarrays, const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs, const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets, const std::vector<TF_Operation*>& targets,
TF_Buffer* run_metadata, TF_Status* out_status, TF_Buffer* run_metadata, TF_Status* status,
std::vector<PyObject*>* py_outputs); std::vector<PyObject*>* py_outputs);
// Set up the graph with the intended feeds (inputs) and fetches (output) for // Set up the graph with the intended feeds (inputs) and fetches (output) for
@ -165,8 +164,7 @@ void TF_SessionPRunSetup_wrapper(TF_Session* session,
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& inputs,
const std::vector<TF_Output>& outputs, const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets, const std::vector<TF_Operation*>& targets,
const char** out_handle, const char** out_handle, TF_Status* status);
TF_Status* out_status);
// Continue to run the graph with additional feeds and fetches. The // Continue to run the graph with additional feeds and fetches. The
// execution state is uniquely identified by the handle. // execution state is uniquely identified by the handle.
@ -182,7 +180,7 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& inputs,
const std::vector<PyObject*>& input_ndarrays, const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs, const std::vector<TF_Output>& outputs,
TF_Status* out_status, TF_Status* status,
std::vector<PyObject*>* py_outputs); std::vector<PyObject*>* py_outputs);
// Retrieves the inputs of this operation. // Retrieves the inputs of this operation.
@ -204,7 +202,7 @@ TF_Function* TF_GraphToFunction_wrapper(
const std::vector<TF_Operation*>* opers, const std::vector<TF_Operation*>* opers,
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs, const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
const NameVector& output_names, const TF_FunctionOptions* opts, const NameVector& output_names, const TF_FunctionOptions* opts,
const char* description, TF_Status* out_status); const char* description, TF_Status* status);
// Set the shapes and types for the output's handle. // Set the shapes and types for the output's handle.
// //

View File

@ -244,13 +244,9 @@ class Context(object):
try: try:
self._num_gpus = 0 self._num_gpus = 0
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
with errors.raise_exception_on_not_ok_status() as status: dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
dev_name = pywrap_tensorflow.TF_DeviceListName(
device_list, i, status)
self._context_devices.append(pydev.canonical_name(dev_name)) self._context_devices.append(pydev.canonical_name(dev_name))
with errors.raise_exception_on_not_ok_status() as status: dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
dev_type = pywrap_tensorflow.TF_DeviceListType(
device_list, i, status)
if dev_type == "GPU": if dev_type == "GPU":
self._num_gpus += 1 self._num_gpus += 1

View File

@ -34,7 +34,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -79,14 +78,10 @@ def capture_value(tensor_map, value, dtype, name):
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [[d.size for d in s.dim] shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes] if not s.unknown_rank else None for s in shapes]
with errors.raise_exception_on_not_ok_status() as status: pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( captured_value._op._graph._c_graph, # pylint: disable=protected-access
captured_value._op._graph._c_graph, # pylint: disable=protected-access captured_value._as_tf_output(), # pylint: disable=protected-access
captured_value._as_tf_output(), # pylint: disable=protected-access shapes, ranks, types)
shapes,
ranks,
types,
status)
tensor_map[ops.tensor_id(value)] = (value, captured_value) tensor_map[ops.tensor_id(value)] = (value, captured_value)
else: else:
@ -275,23 +270,20 @@ class _EagerDefinedFunction(object):
inputs: the tensors in the graph to be used as inputs to the function inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function outputs: the tensors in the graph which will be outputs to the function
""" """
with errors.raise_exception_on_not_ok_status() as status: fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( graph._c_graph, # pylint: disable=protected-access
graph._c_graph, # pylint: disable=protected-access compat.as_str(name),
compat.as_str(name), False,
False, [o._c_op for o in operations], # pylint: disable=protected-access
[o._c_op for o in operations], # pylint: disable=protected-access [t._as_tf_output() for t in inputs], # pylint: disable=protected-access
[t._as_tf_output() for t in inputs], # pylint: disable=protected-access [t._as_tf_output() for t in outputs], # pylint: disable=protected-access
[t._as_tf_output() for t in outputs], # pylint: disable=protected-access [],
[], None,
None, compat.as_str(""))
compat.as_str(""),
status)
# TODO(apassos) avoid creating a FunctionDef (specially to grab the # TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it. # signature, but also in general it's nice not to depend on it.
with c_api_util.tf_buffer() as buffer_: with c_api_util.tf_buffer() as buffer_:
with errors.raise_exception_on_not_ok_status() as status: pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef() function_def = function_pb2.FunctionDef()
function_def.ParseFromString(compat.as_bytes(proto_data)) function_def.ParseFromString(compat.as_bytes(proto_data))

View File

@ -473,6 +473,8 @@ _CODE_TO_EXCEPTION_CLASS = {
DATA_LOSS: DataLossError, DATA_LOSS: DataLossError,
} }
c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS)
_EXCEPTION_CLASS_TO_CODE = dict(( _EXCEPTION_CLASS_TO_CODE = dict((
(class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items())) (class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items()))
@ -499,6 +501,7 @@ def _make_specific_exception(node_def, op, message, error_code):
# Named like a function for backwards compatibility with the # Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid # @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead. # some object creation overhead.
# TODO(b/77295559): expand use of TF_Status* SWIG typemap and deprecate this.
@tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name @tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name
class raise_exception_on_not_ok_status(object): class raise_exception_on_not_ok_status(object):
"""Context manager to check for C API status.""" """Context manager to check for C API status."""

View File

@ -30,7 +30,6 @@ from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -275,8 +274,7 @@ class _DefinedFunction(object):
self._create_definition_if_needed() self._create_definition_if_needed()
if self._c_func: if self._c_func:
with c_api_util.tf_buffer() as buf: with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_FunctionToFunctionDef(self._c_func, buf)
c_api.TF_FunctionToFunctionDef(self._c_func, buf, status)
fdef = function_pb2.FunctionDef() fdef = function_pb2.FunctionDef()
proto_data = c_api.TF_GetBuffer(buf) proto_data = c_api.TF_GetBuffer(buf)
fdef.ParseFromString(compat.as_bytes(proto_data)) fdef.ParseFromString(compat.as_bytes(proto_data))
@ -399,18 +397,16 @@ class _DefinedFunction(object):
if self._out_names else []) if self._out_names else [])
description = self._func.__doc__ or None description = self._func.__doc__ or None
# pylint: disable=protected-access # pylint: disable=protected-access
with errors.raise_exception_on_not_ok_status() as status: self._c_func = c_api.TF_GraphToFunction_wrapper(
self._c_func = c_api.TF_GraphToFunction_wrapper( temp_graph._c_graph,
temp_graph._c_graph, base_func_name,
base_func_name, self._func_name is None, # append_hash_to_fn_name
self._func_name is None, # append_hash_to_fn_name None, # opers
None, # opers [t._as_tf_output() for t in inputs],
[t._as_tf_output() for t in inputs], [t._as_tf_output() for t in outputs],
[t._as_tf_output() for t in outputs], output_names,
output_names, None, # opts
None, # opts description)
description,
status)
# pylint: enable=protected-access # pylint: enable=protected-access
self._set_c_attrs(kwargs_attr) self._set_c_attrs(kwargs_attr)
@ -433,9 +429,8 @@ class _DefinedFunction(object):
serialized = attr_value.SerializeToString() serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr. # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status. # It might be worth creating a convenient way to re-use the same status.
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name), serialized)
serialized, status)
def _create_hash_str(self, input_arg, output_arg, node_def): def _create_hash_str(self, input_arg, output_arg, node_def):
"""Creates an 8-character string unique to this input. """Creates an 8-character string unique to this input.
@ -830,8 +825,7 @@ def _from_definition(fdef, grad_func=None):
# pylint: disable=protected-access # pylint: disable=protected-access
if ops._USE_C_API: if ops._USE_C_API:
serialized = fdef.SerializeToString() serialized = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status: result._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status)
result._extra_inputs = [] result._extra_inputs = []
else: else:
result._definition = fdef result._definition = fdef

View File

@ -485,9 +485,8 @@ def import_graph_def(graph_def,
with graph._lock: # pylint: disable=protected-access with graph._lock: # pylint: disable=protected-access
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
try: try:
with errors.raise_exception_on_not_ok_status() as status: results = c_api.TF_GraphImportGraphDefWithResults(
results = c_api.TF_GraphImportGraphDefWithResults( graph._c_graph, serialized, options) # pylint: disable=protected-access
graph._c_graph, serialized, options, status) # pylint: disable=protected-access
except errors.InvalidArgumentError as e: except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility. # Convert to ValueError for backwards compatibility.
raise ValueError(str(e)) raise ValueError(str(e))

View File

@ -26,7 +26,6 @@ import threading # pylint: disable=unused-import
from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
from tensorflow.python import pywrap_tensorflow as py_tf from tensorflow.python import pywrap_tensorflow as py_tf
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -54,8 +53,7 @@ def load_op_library(library_filename):
Raises: Raises:
RuntimeError: when unable to load the library or get the python wrappers. RuntimeError: when unable to load the library or get the python wrappers.
""" """
with errors_impl.raise_exception_on_not_ok_status() as status: lib_handle = py_tf.TF_LoadLibrary(library_filename)
lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
op_list_str = py_tf.TF_GetOpList(lib_handle) op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList() op_list = op_def_pb2.OpList()
@ -99,5 +97,4 @@ def load_file_system_library(library_filename):
Raises: Raises:
RuntimeError: when unable to load the library. RuntimeError: when unable to load the library.
""" """
with errors_impl.raise_exception_on_not_ok_status() as status: py_tf.TF_LoadLibrary(library_filename)
lib_handle = py_tf.TF_LoadLibrary(library_filename, status)

View File

@ -373,15 +373,12 @@ class Tensor(_TensorLike):
""" """
graph = self._op._graph._c_graph # pylint: disable=protected-access graph = self._op._graph._c_graph # pylint: disable=protected-access
if graph and _USE_C_SHAPES: if graph and _USE_C_SHAPES:
with errors.raise_exception_on_not_ok_status() as status: num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
status)
if num_dims == -1: if num_dims == -1:
dim_list = None dim_list = None
else: else:
with errors.raise_exception_on_not_ok_status() as status: dim_list = c_api.TF_GraphGetTensorShape_wrapper(
dim_list = c_api.TF_GraphGetTensorShape_wrapper( graph, self._as_tf_output(), num_dims)
graph, self._as_tf_output(), num_dims, status)
dim_list = [None if i == -1 else i for i in dim_list] dim_list = [None if i == -1 else i for i in dim_list]
return tensor_shape.TensorShape(dim_list) return tensor_shape.TensorShape(dim_list)
return self._shape_val return self._shape_val
@ -489,13 +486,11 @@ class Tensor(_TensorLike):
else: else:
dim_list.append(dim.value) dim_list.append(dim.value)
try: try:
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_GraphSetTensorShape_wrapper(
c_api.TF_GraphSetTensorShape_wrapper( self._op._graph._c_graph, # pylint: disable=protected-access
self._op._graph._c_graph, # pylint: disable=protected-access self._as_tf_output(),
self._as_tf_output(), dim_list,
dim_list, unknown_shape)
unknown_shape,
status)
except errors.InvalidArgumentError as e: except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility. # Convert to ValueError for backwards compatibility.
raise ValueError(str(e)) raise ValueError(str(e))
@ -1514,13 +1509,10 @@ def _create_c_op(graph, node_def, inputs, control_inputs):
serialized = attr_value.SerializeToString() serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr. # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status. # It might be worth creating a convenient way to re-use the same status.
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
c_api.TF_SetAttrValueProto(op_desc,
compat.as_str(name), serialized, status)
try: try:
with errors.raise_exception_on_not_ok_status() as status: c_op = c_api.TF_FinishOperation(op_desc)
c_op = c_api.TF_FinishOperation(op_desc, status)
except errors.InvalidArgumentError as e: except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility. # Convert to ValueError for backwards compatibility.
raise ValueError(str(e)) raise ValueError(str(e))
@ -1943,12 +1935,10 @@ class Operation(object):
if self._c_op: if self._c_op:
# Reset cached inputs. # Reset cached inputs.
self._inputs_val = None self._inputs_val = None
with errors.raise_exception_on_not_ok_status() as status: c_api.UpdateEdge(
c_api.UpdateEdge( self._graph._c_graph, # pylint: disable=protected-access
self._graph._c_graph, # pylint: disable=protected-access tensor._as_tf_output(), # pylint: disable=protected-access
tensor._as_tf_output(), # pylint: disable=protected-access self._tf_input(index))
self._tf_input(index),
status)
else: else:
self._inputs_val[index].consumers().remove(self) self._inputs_val[index].consumers().remove(self)
self._inputs_val[index] = tensor self._inputs_val[index] = tensor
@ -2169,8 +2159,7 @@ class Operation(object):
# pylint: enable=line-too-long # pylint: enable=line-too-long
if self._c_op: if self._c_op:
with c_api_util.tf_buffer() as buf: with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_OperationToNodeDef(self._c_op, buf)
c_api.TF_OperationToNodeDef(self._c_op, buf, status)
data = c_api.TF_GetBuffer(buf) data = c_api.TF_GetBuffer(buf)
node_def = node_def_pb2.NodeDef() node_def = node_def_pb2.NodeDef()
node_def.ParseFromString(compat.as_bytes(data)) node_def.ParseFromString(compat.as_bytes(data))
@ -2228,11 +2217,9 @@ class Operation(object):
buf = c_api.TF_NewBufferFromString( buf = c_api.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString())) compat.as_bytes(attr_value.SerializeToString()))
try: try:
with errors.raise_exception_on_not_ok_status() as status: # pylint: disable=protected-access
# pylint: disable=protected-access c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, # pylint: enable=protected-access
status)
# pylint: enable=protected-access
finally: finally:
c_api.TF_DeleteBuffer(buf) c_api.TF_DeleteBuffer(buf)
else: else:
@ -2254,8 +2241,7 @@ class Operation(object):
if self._c_op: if self._c_op:
try: try:
with c_api_util.tf_buffer() as buf: with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
data = c_api.TF_GetBuffer(buf) data = c_api.TF_GetBuffer(buf)
except errors.InvalidArgumentError as e: except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility. # Convert to ValueError for backwards compatibility.
@ -2469,11 +2455,10 @@ def _set_shapes_for_outputs_c_api(op):
# The C API computes the shapes when the TF_Operation is created. Fetch the # The C API computes the shapes when the TF_Operation is created. Fetch the
# output shapes from the C object. # output shapes from the C object.
for output in op.outputs: for output in op.outputs:
with errors.raise_exception_on_not_ok_status() as status: # pylint: disable=protected-access
# pylint: disable=protected-access shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper( op._graph._c_graph, output._as_tf_output())
op._graph._c_graph, output._as_tf_output(), status) # pylint: enable=protected-access
# pylint: enable=protected-access
if unknown_shape: if unknown_shape:
output.set_shape(tensor_shape.unknown_shape()) output.set_shape(tensor_shape.unknown_shape())
elif not shape_vector: elif not shape_vector:
@ -2994,8 +2979,7 @@ class Graph(object):
# pylint: enable=line-too-long # pylint: enable=line-too-long
if self._c_graph: if self._c_graph:
with c_api_util.tf_buffer() as buf: with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_GraphVersions(self._c_graph, buf)
c_api.TF_GraphVersions(self._c_graph, buf, status)
data = c_api.TF_GetBuffer(buf) data = c_api.TF_GetBuffer(buf)
version_def = versions_pb2.VersionDef() version_def = versions_pb2.VersionDef()
version_def.ParseFromString(compat.as_bytes(data)) version_def.ParseFromString(compat.as_bytes(data))
@ -3098,8 +3082,7 @@ class Graph(object):
if self._c_graph: if self._c_graph:
with self._lock: with self._lock:
with c_api_util.tf_buffer() as buf: with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status: c_api.TF_GraphToGraphDef(self._c_graph, buf)
c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
data = c_api.TF_GetBuffer(buf) data = c_api.TF_GetBuffer(buf)
graph = graph_pb2.GraphDef() graph = graph_pb2.GraphDef()
graph.ParseFromString(compat.as_bytes(data)) graph.ParseFromString(compat.as_bytes(data))
@ -3208,14 +3191,10 @@ class Graph(object):
# remove this when all functions are generated using the C API by default # remove this when all functions are generated using the C API by default
# as this will be unnecessary. # as this will be unnecessary.
if not function._c_func: if not function._c_func:
with errors.raise_exception_on_not_ok_status() as status: serialized = function.definition.SerializeToString()
serialized = function.definition.SerializeToString() function._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
function._c_func = c_api.TF_FunctionImportFunctionDef( gradient = function._grad_func._c_func if function._grad_func else None
serialized, status) c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient)
with errors.raise_exception_on_not_ok_status() as status:
gradient = function._grad_func._c_func if function._grad_func else None
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
status)
else: else:
# If there is already a function with the same name, raise an error # If there is already a function with the same name, raise an error
# if bodies are different. Else, do nothing. The C API version above # if bodies are different. Else, do nothing. The C API version above
@ -3732,11 +3711,9 @@ class Graph(object):
"""Returns the `OpDef` proto for `type`. `type` is a string.""" """Returns the `OpDef` proto for `type`. `type` is a string."""
if self._c_graph: if self._c_graph:
with c_api_util.tf_buffer() as buf: with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status: # pylint: disable=protected-access
# pylint: disable=protected-access c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
c_api.TF_GraphGetOpDef(self._c_graph, # pylint: enable=protected-access
compat.as_bytes(type), buf, status)
# pylint: enable=protected-access
data = c_api.TF_GetBuffer(buf) data = c_api.TF_GetBuffer(buf)
op_def = op_def_pb2.OpDef() op_def = op_def_pb2.OpDef()
op_def.ParseFromString(compat.as_bytes(data)) op_def.ParseFromString(compat.as_bytes(data))

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -83,9 +82,8 @@ def smart_constant_value(pred):
# wanted to limit the change hidden behind _USE_C_API). # wanted to limit the change hidden behind _USE_C_API).
# pylint: disable=protected-access # pylint: disable=protected-access
if pred_value is None and ops._USE_C_API: if pred_value is None and ops._USE_C_API:
with errors.raise_exception_on_not_ok_status() as status: pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
pred_value = c_api.TF_TryEvaluateConstant_wrapper( pred._as_tf_output())
pred.graph._c_graph, pred._as_tf_output(), status)
# pylint: enable=protected-access # pylint: enable=protected-access
else: else:

View File

@ -0,0 +1,50 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include <Python.h>
namespace tensorflow {
PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr;
void PyExceptionRegistry::Init(PyObject* code_to_exc_type_map) {
DCHECK(singleton_ == nullptr) << "PyExceptionRegistry::Init() already called";
singleton_ = new PyExceptionRegistry;
DCHECK(PyDict_Check(code_to_exc_type_map));
PyObject* key;
PyObject* value;
Py_ssize_t pos = 0;
while (PyDict_Next(code_to_exc_type_map, &pos, &key, &value)) {
TF_Code code = static_cast<TF_Code>(PyLong_AsLong(key));
singleton_->exc_types_[code] = value;
// The exception classes should also have the lifetime of the process, but
// incref just in case.
Py_INCREF(value);
}
}
PyObject* PyExceptionRegistry::Lookup(TF_Code code) {
DCHECK(singleton_ != nullptr) << "Must call PyExceptionRegistry::Init() "
"before PyExceptionRegistry::Lookup()";
DCHECK_NE(code, TF_OK);
DCHECK(singleton_->exc_types_.find(code) != singleton_->exc_types_.end())
<< "Unknown error code passed to PyExceptionRegistry::Lookup: " << code;
return singleton_->exc_types_[code];
}
} // namespace tensorflow

View File

@ -0,0 +1,73 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
#include <map>
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/platform/logging.h"
#ifndef PyObject_HEAD
struct _object;
typedef _object PyObject;
#endif
namespace tensorflow {
// Global registry mapping C API error codes to the corresponding custom Python
// exception type. This is used to expose the exception types to C extension
// code (i.e. so we can raise custom exceptions via SWIG).
//
// Init() must be called exactly once at the beginning of the process before
// Lookup() can be used.
//
// Example usage:
// TF_Status* status = TF_NewStatus();
// TF_Foo(..., status);
//
// if (TF_GetCode(status) != TF_OK) {
// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status));
// // Arguments to OpError base class. Set `node_def` and `op` to None.
// PyObject* args =
// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status));
// PyErr_SetObject(exc_type, args);
// Py_DECREF(args);
// TF_DeleteStatus(status);
// return NULL;
// }
class PyExceptionRegistry {
public:
// Initializes the process-wide registry. Should be called exactly once near
// the beginning of the process. The arguments are the various Python
// exception types (e.g. `cancelled_exc` corresponds to
// errors.CancelledError).
static void Init(PyObject* code_to_exc_type_map);
// Returns the Python exception type corresponding to `code`. Init() must be
// called before using this function. `code` should not be TF_OK.
static PyObject* Lookup(TF_Code code);
private:
static PyExceptionRegistry* singleton_;
PyExceptionRegistry() = default;
// Maps error codes to the corresponding Python exception type.
std::map<TF_Code, PyObject*> exc_types_;
};
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_

View File

@ -0,0 +1,28 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/python/lib/core/py_exception_registry.h"
%}
%ignoreall
%unignore tensorflow::PyExceptionRegistry;
%unignore tensorflow::PyExceptionRegistry::Init;
%include "tensorflow/python/lib/core/py_exception_registry.h"
%unignoreall

View File

@ -25,6 +25,7 @@ limitations under the License.
%include "tensorflow/python/util/tfprof.i" %include "tensorflow/python/util/tfprof.i"
%include "tensorflow/python/lib/core/py_func.i" %include "tensorflow/python/lib/core/py_func.i"
%include "tensorflow/python/lib/core/py_exception_registry.i"
%include "tensorflow/python/lib/io/py_record_reader.i" %include "tensorflow/python/lib/io/py_record_reader.i"
%include "tensorflow/python/lib/io/py_record_writer.i" %include "tensorflow/python/lib/io/py_record_writer.i"
@ -54,4 +55,3 @@ limitations under the License.
%include "tensorflow/python/grappler/tf_optimizer.i" %include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i" %include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i" %include "tensorflow/python/grappler/model_analyzer.i"