Raise exception in SWIG on bad TF_Status from C API.
This change provides an alternative mechanism to tf.raise_exception_on_not_ok_status(), which is inefficient and error-prone (people often use the status multiple times in the with block, but it's only checked when the context manager exits). Instead, it uses SWIG to automatically raise an exception when a C API method fails. Note that this removes the status argument from affected methods. For now, I've only applied this typemap to C API methods. It would be good to expand this to all uses of raise_exception_on_not_ok_status. PiperOrigin-RevId: 191121016
This commit is contained in:
parent
15c10899c9
commit
97731cb122
@ -1496,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list);
|
||||
// If index is out of bounds, an error code will be set in the status object,
|
||||
// and a null pointer will be returned.
|
||||
TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
|
||||
int index, TF_Status*);
|
||||
int index,
|
||||
TF_Status* status);
|
||||
|
||||
// Retrieves the type of the device at the given index.
|
||||
//
|
||||
@ -1506,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
|
||||
// If index is out of bounds, an error code will be set in the status object,
|
||||
// and a null pointer will be returned.
|
||||
TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
|
||||
int index, TF_Status*);
|
||||
int index,
|
||||
TF_Status* status);
|
||||
|
||||
// Retrieve the amount of memory associated with a given device.
|
||||
//
|
||||
// If index is out of bounds, an error code will be set in the status object,
|
||||
// and -1 will be returned.
|
||||
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
|
||||
const TF_DeviceList* list, int index, TF_Status*);
|
||||
const TF_DeviceList* list, int index, TF_Status* status);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Load plugins containing custom ops and kernels
|
||||
|
@ -474,6 +474,8 @@ set (pywrap_tensorflow_internal_src
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"
|
||||
|
@ -283,6 +283,17 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_exception_registry",
|
||||
srcs = ["lib/core/py_exception_registry.cc"],
|
||||
hdrs = ["lib/core/py_exception_registry.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kernel_registry",
|
||||
srcs = ["util/kernel_registry.cc"],
|
||||
@ -3313,6 +3324,7 @@ tf_py_wrap_cc(
|
||||
"grappler/model_analyzer.i",
|
||||
"grappler/tf_optimizer.i",
|
||||
"lib/core/bfloat16.i",
|
||||
"lib/core/py_exception_registry.i",
|
||||
"lib/core/py_func.i",
|
||||
"lib/core/strings.i",
|
||||
"lib/io/file_io.i",
|
||||
@ -3344,6 +3356,7 @@ tf_py_wrap_cc(
|
||||
":kernel_registry",
|
||||
":numpy_lib",
|
||||
":safe_ptr",
|
||||
":py_exception_registry",
|
||||
":py_func_lib",
|
||||
":py_record_reader_lib",
|
||||
":py_record_writer_lib",
|
||||
|
@ -27,7 +27,6 @@ import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_session
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
@ -629,14 +628,12 @@ class BaseSession(SessionInterface):
|
||||
self._session = None
|
||||
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if self._created_with_new_api:
|
||||
# pylint: disable=protected-access
|
||||
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts,
|
||||
status)
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
self._session = tf_session.TF_NewDeprecatedSession(opts, status)
|
||||
if self._created_with_new_api:
|
||||
# pylint: disable=protected-access
|
||||
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
self._session = tf_session.TF_NewDeprecatedSession(opts)
|
||||
finally:
|
||||
tf_session.TF_DeleteSessionOptions(opts)
|
||||
|
||||
@ -663,22 +660,20 @@ class BaseSession(SessionInterface):
|
||||
Returns:
|
||||
A list of devices in the session.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if self._created_with_new_api:
|
||||
raw_device_list = tf_session.TF_SessionListDevices(
|
||||
self._session, status)
|
||||
else:
|
||||
raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
|
||||
self._session, status)
|
||||
device_list = []
|
||||
size = tf_session.TF_DeviceListCount(raw_device_list)
|
||||
for i in range(size):
|
||||
name = tf_session.TF_DeviceListName(raw_device_list, i, status)
|
||||
device_type = tf_session.TF_DeviceListType(raw_device_list, i, status)
|
||||
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status)
|
||||
device_list.append(_DeviceAttributes(name, device_type, memory))
|
||||
tf_session.TF_DeleteDeviceList(raw_device_list)
|
||||
return device_list
|
||||
if self._created_with_new_api:
|
||||
raw_device_list = tf_session.TF_SessionListDevices(self._session)
|
||||
else:
|
||||
raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
|
||||
self._session)
|
||||
device_list = []
|
||||
size = tf_session.TF_DeviceListCount(raw_device_list)
|
||||
for i in range(size):
|
||||
name = tf_session.TF_DeviceListName(raw_device_list, i)
|
||||
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
|
||||
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
|
||||
device_list.append(_DeviceAttributes(name, device_type, memory))
|
||||
tf_session.TF_DeleteDeviceList(raw_device_list)
|
||||
return device_list
|
||||
|
||||
def close(self):
|
||||
"""Closes this session.
|
||||
@ -692,15 +687,13 @@ class BaseSession(SessionInterface):
|
||||
if self._created_with_new_api:
|
||||
if self._session and not self._closed:
|
||||
self._closed = True
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
tf_session.TF_CloseSession(self._session, status)
|
||||
tf_session.TF_CloseSession(self._session)
|
||||
|
||||
else:
|
||||
with self._extend_lock:
|
||||
if self._opened and not self._closed:
|
||||
self._closed = True
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
tf_session.TF_CloseDeprecatedSession(self._session, status)
|
||||
tf_session.TF_CloseDeprecatedSession(self._session)
|
||||
|
||||
def __del__(self):
|
||||
# cleanly ignore all exceptions
|
||||
@ -710,11 +703,10 @@ class BaseSession(SessionInterface):
|
||||
pass
|
||||
if self._session is not None:
|
||||
try:
|
||||
status = c_api_util.ScopedTFStatus()
|
||||
if self._created_with_new_api:
|
||||
tf_session.TF_DeleteSession(self._session, status)
|
||||
tf_session.TF_DeleteSession(self._session)
|
||||
else:
|
||||
tf_session.TF_DeleteDeprecatedSession(self._session, status)
|
||||
tf_session.TF_DeleteDeprecatedSession(self._session)
|
||||
except AttributeError:
|
||||
# At shutdown, `c_api_util` or `tf_session` may have been garbage
|
||||
# collected, causing the above method calls to fail. In this case,
|
||||
@ -1031,11 +1023,11 @@ class BaseSession(SessionInterface):
|
||||
# Set up a graph with feeds and fetches for partial run.
|
||||
def _setup_fn(session, feed_list, fetch_list, target_list):
|
||||
self._extend_graph()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if self._created_with_new_api:
|
||||
return tf_session.TF_SessionPRunSetup_wrapper(
|
||||
session, feed_list, fetch_list, target_list, status)
|
||||
else:
|
||||
if self._created_with_new_api:
|
||||
return tf_session.TF_SessionPRunSetup_wrapper(
|
||||
session, feed_list, fetch_list, target_list)
|
||||
else:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
|
||||
target_list, status)
|
||||
|
||||
@ -1345,8 +1337,7 @@ class BaseSession(SessionInterface):
|
||||
def _extend_graph(self):
|
||||
if self._created_with_new_api:
|
||||
with self._graph._lock: # pylint: disable=protected-access
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
tf_session.ExtendSession(self._session, status)
|
||||
tf_session.ExtendSession(self._session)
|
||||
else:
|
||||
# Ensure any changes to the graph are reflected in the runtime.
|
||||
with self._extend_lock:
|
||||
@ -1412,22 +1403,22 @@ class BaseSession(SessionInterface):
|
||||
|
||||
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
|
||||
run_metadata):
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if self._created_with_new_api:
|
||||
return tf_session.TF_SessionRun_wrapper(
|
||||
self._session, options, feed_dict, fetch_list, target_list,
|
||||
run_metadata, status)
|
||||
else:
|
||||
if self._created_with_new_api:
|
||||
return tf_session.TF_SessionRun_wrapper(
|
||||
self._session, options, feed_dict, fetch_list, target_list,
|
||||
run_metadata)
|
||||
else:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return tf_session.TF_Run(
|
||||
self._session, options, feed_dict, fetch_list, target_list,
|
||||
status, run_metadata)
|
||||
|
||||
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if self._created_with_new_api:
|
||||
return tf_session.TF_SessionPRun_wrapper(
|
||||
self._session, handle, feed_dict, fetch_list, status)
|
||||
else:
|
||||
if self._created_with_new_api:
|
||||
return tf_session.TF_SessionPRun_wrapper(
|
||||
self._session, handle, feed_dict, fetch_list)
|
||||
else:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return tf_session.TF_PRun(
|
||||
self._session, handle, feed_dict, fetch_list, status)
|
||||
|
||||
|
@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_session
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -42,21 +41,13 @@ class SessionListDevicesTestMethods(object):
|
||||
|
||||
def testInvalidDeviceNumber(self):
|
||||
opts = tf_session.TF_NewSessionOptions()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_session = tf_session.TF_NewSession(
|
||||
ops.get_default_graph()._c_graph, opts, status)
|
||||
raw_device_list = tf_session.TF_SessionListDevices(
|
||||
c_session, status)
|
||||
c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts)
|
||||
raw_device_list = tf_session.TF_SessionListDevices(c_session)
|
||||
size = tf_session.TF_DeviceListCount(raw_device_list)
|
||||
# Test that invalid device numbers return -1 rather than a Swig-wrapped
|
||||
# pointer.
|
||||
status_no_exception = c_api_util.ScopedTFStatus()
|
||||
memory = tf_session.TF_DeviceListMemoryBytes(
|
||||
raw_device_list, size, status_no_exception)
|
||||
self.assertEqual(memory, -1)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
tf_session.TF_DeviceListMemoryBytes(raw_device_list, size)
|
||||
tf_session.TF_DeleteDeviceList(raw_device_list)
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
tf_session.TF_CloseSession(c_session, status)
|
||||
tf_session.TF_CloseSession(c_session)
|
||||
|
||||
def testListDevicesGrpcSession(self):
|
||||
server = server_lib.Server.create_local_server()
|
||||
|
@ -18,11 +18,12 @@ limitations under the License.
|
||||
%{
|
||||
|
||||
#include "tensorflow/c/python_api.h"
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
#include "tensorflow/core/framework/session_state.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||
|
||||
// Helper function to convert a Python list of Tensors to a C++ vector of
|
||||
// TF_Outputs.
|
||||
@ -352,6 +353,27 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
|
||||
reinterpret_cast<const char*>($1.data), $1.length);
|
||||
}
|
||||
|
||||
// Typemaps to automatically raise a Python exception from bad output TF_Status.
|
||||
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
|
||||
// raise_exception_on_not_ok_status (currently it only affects the C API).
|
||||
%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) {
|
||||
status = TF_NewStatus();
|
||||
$1 = status;
|
||||
}
|
||||
|
||||
%typemap(argout) TF_Status* status {
|
||||
TF_Code code = TF_GetCode($1);
|
||||
if (code != TF_OK) {
|
||||
PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code);
|
||||
// Arguments to OpError.
|
||||
PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1));
|
||||
TF_DeleteStatus($1);
|
||||
SWIG_SetErrorObj(exc, exc_args);
|
||||
SWIG_fail;
|
||||
}
|
||||
TF_DeleteStatus($1);
|
||||
}
|
||||
|
||||
// Converts input Python list of wrapped TF_Outputs into a single array
|
||||
%typemap(in) (const TF_Output* inputs, int num_inputs)
|
||||
(std::vector<TF_Output> inputs) {
|
||||
@ -499,9 +521,8 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
|
||||
_TF_SetTarget(opts, target)
|
||||
if config is not None:
|
||||
from tensorflow.python.framework import errors
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
config_str = config.SerializeToString()
|
||||
_TF_SetConfig(opts, config_str, status)
|
||||
config_str = config.SerializeToString()
|
||||
_TF_SetConfig(opts, config_str)
|
||||
return opts
|
||||
%}
|
||||
|
||||
@ -758,3 +779,7 @@ def TF_Reset(target, containers=None, config=None):
|
||||
%include "tensorflow/python/client/tf_session_helper.h"
|
||||
|
||||
%unignoreall
|
||||
|
||||
// Clear "TF_Status* status" typemap so it doesn't affect other modules and
|
||||
// unexpectedly remove the TF_Status* argument from wrappers.
|
||||
%clear TF_Status* status;
|
||||
|
@ -136,8 +136,7 @@ string EqualAttrValueWrapper(const string& actual, const string& expected);
|
||||
//
|
||||
// If shape is unknown, sets unknown_shape to true.
|
||||
tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
|
||||
TF_Graph* graph, TF_Output output, TF_Status* out_status,
|
||||
bool* unknown_shape);
|
||||
TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape);
|
||||
|
||||
// Runs the graph associated with the session starting with the supplied inputs.
|
||||
// On success, `py_outputs` is populated with a numpy ndarray for each output
|
||||
@ -149,7 +148,7 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
|
||||
const std::vector<PyObject*>& input_ndarrays,
|
||||
const std::vector<TF_Output>& outputs,
|
||||
const std::vector<TF_Operation*>& targets,
|
||||
TF_Buffer* run_metadata, TF_Status* out_status,
|
||||
TF_Buffer* run_metadata, TF_Status* status,
|
||||
std::vector<PyObject*>* py_outputs);
|
||||
|
||||
// Set up the graph with the intended feeds (inputs) and fetches (output) for
|
||||
@ -165,8 +164,7 @@ void TF_SessionPRunSetup_wrapper(TF_Session* session,
|
||||
const std::vector<TF_Output>& inputs,
|
||||
const std::vector<TF_Output>& outputs,
|
||||
const std::vector<TF_Operation*>& targets,
|
||||
const char** out_handle,
|
||||
TF_Status* out_status);
|
||||
const char** out_handle, TF_Status* status);
|
||||
|
||||
// Continue to run the graph with additional feeds and fetches. The
|
||||
// execution state is uniquely identified by the handle.
|
||||
@ -182,7 +180,7 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
|
||||
const std::vector<TF_Output>& inputs,
|
||||
const std::vector<PyObject*>& input_ndarrays,
|
||||
const std::vector<TF_Output>& outputs,
|
||||
TF_Status* out_status,
|
||||
TF_Status* status,
|
||||
std::vector<PyObject*>* py_outputs);
|
||||
|
||||
// Retrieves the inputs of this operation.
|
||||
@ -204,7 +202,7 @@ TF_Function* TF_GraphToFunction_wrapper(
|
||||
const std::vector<TF_Operation*>* opers,
|
||||
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
|
||||
const NameVector& output_names, const TF_FunctionOptions* opts,
|
||||
const char* description, TF_Status* out_status);
|
||||
const char* description, TF_Status* status);
|
||||
|
||||
// Set the shapes and types for the output's handle.
|
||||
//
|
||||
|
@ -244,13 +244,9 @@ class Context(object):
|
||||
try:
|
||||
self._num_gpus = 0
|
||||
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
dev_name = pywrap_tensorflow.TF_DeviceListName(
|
||||
device_list, i, status)
|
||||
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
|
||||
self._context_devices.append(pydev.canonical_name(dev_name))
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
dev_type = pywrap_tensorflow.TF_DeviceListType(
|
||||
device_list, i, status)
|
||||
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
|
||||
if dev_type == "GPU":
|
||||
self._num_gpus += 1
|
||||
|
||||
|
@ -34,7 +34,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_module
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -79,14 +78,10 @@ def capture_value(tensor_map, value, dtype, name):
|
||||
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
|
||||
shapes = [[d.size for d in s.dim]
|
||||
if not s.unknown_rank else None for s in shapes]
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
captured_value._op._graph._c_graph, # pylint: disable=protected-access
|
||||
captured_value._as_tf_output(), # pylint: disable=protected-access
|
||||
shapes,
|
||||
ranks,
|
||||
types,
|
||||
status)
|
||||
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
captured_value._op._graph._c_graph, # pylint: disable=protected-access
|
||||
captured_value._as_tf_output(), # pylint: disable=protected-access
|
||||
shapes, ranks, types)
|
||||
|
||||
tensor_map[ops.tensor_id(value)] = (value, captured_value)
|
||||
else:
|
||||
@ -275,23 +270,20 @@ class _EagerDefinedFunction(object):
|
||||
inputs: the tensors in the graph to be used as inputs to the function
|
||||
outputs: the tensors in the graph which will be outputs to the function
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
|
||||
graph._c_graph, # pylint: disable=protected-access
|
||||
compat.as_str(name),
|
||||
False,
|
||||
[o._c_op for o in operations], # pylint: disable=protected-access
|
||||
[t._as_tf_output() for t in inputs], # pylint: disable=protected-access
|
||||
[t._as_tf_output() for t in outputs], # pylint: disable=protected-access
|
||||
[],
|
||||
None,
|
||||
compat.as_str(""),
|
||||
status)
|
||||
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
|
||||
graph._c_graph, # pylint: disable=protected-access
|
||||
compat.as_str(name),
|
||||
False,
|
||||
[o._c_op for o in operations], # pylint: disable=protected-access
|
||||
[t._as_tf_output() for t in inputs], # pylint: disable=protected-access
|
||||
[t._as_tf_output() for t in outputs], # pylint: disable=protected-access
|
||||
[],
|
||||
None,
|
||||
compat.as_str(""))
|
||||
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
|
||||
# signature, but also in general it's nice not to depend on it.
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
|
||||
pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
|
||||
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
function_def = function_pb2.FunctionDef()
|
||||
function_def.ParseFromString(compat.as_bytes(proto_data))
|
||||
|
@ -473,6 +473,8 @@ _CODE_TO_EXCEPTION_CLASS = {
|
||||
DATA_LOSS: DataLossError,
|
||||
}
|
||||
|
||||
c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS)
|
||||
|
||||
_EXCEPTION_CLASS_TO_CODE = dict((
|
||||
(class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items()))
|
||||
|
||||
@ -499,6 +501,7 @@ def _make_specific_exception(node_def, op, message, error_code):
|
||||
# Named like a function for backwards compatibility with the
|
||||
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
|
||||
# some object creation overhead.
|
||||
# TODO(b/77295559): expand use of TF_Status* SWIG typemap and deprecate this.
|
||||
@tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name
|
||||
class raise_exception_on_not_ok_status(object):
|
||||
"""Context manager to check for C API status."""
|
||||
|
@ -30,7 +30,6 @@ from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import graph_to_function_def
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -275,8 +274,7 @@ class _DefinedFunction(object):
|
||||
self._create_definition_if_needed()
|
||||
if self._c_func:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_FunctionToFunctionDef(self._c_func, buf, status)
|
||||
c_api.TF_FunctionToFunctionDef(self._c_func, buf)
|
||||
fdef = function_pb2.FunctionDef()
|
||||
proto_data = c_api.TF_GetBuffer(buf)
|
||||
fdef.ParseFromString(compat.as_bytes(proto_data))
|
||||
@ -399,18 +397,16 @@ class _DefinedFunction(object):
|
||||
if self._out_names else [])
|
||||
description = self._func.__doc__ or None
|
||||
# pylint: disable=protected-access
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._c_func = c_api.TF_GraphToFunction_wrapper(
|
||||
temp_graph._c_graph,
|
||||
base_func_name,
|
||||
self._func_name is None, # append_hash_to_fn_name
|
||||
None, # opers
|
||||
[t._as_tf_output() for t in inputs],
|
||||
[t._as_tf_output() for t in outputs],
|
||||
output_names,
|
||||
None, # opts
|
||||
description,
|
||||
status)
|
||||
self._c_func = c_api.TF_GraphToFunction_wrapper(
|
||||
temp_graph._c_graph,
|
||||
base_func_name,
|
||||
self._func_name is None, # append_hash_to_fn_name
|
||||
None, # opers
|
||||
[t._as_tf_output() for t in inputs],
|
||||
[t._as_tf_output() for t in outputs],
|
||||
output_names,
|
||||
None, # opts
|
||||
description)
|
||||
# pylint: enable=protected-access
|
||||
self._set_c_attrs(kwargs_attr)
|
||||
|
||||
@ -433,9 +429,8 @@ class _DefinedFunction(object):
|
||||
serialized = attr_value.SerializeToString()
|
||||
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
|
||||
# It might be worth creating a convenient way to re-use the same status.
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
|
||||
serialized, status)
|
||||
c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
|
||||
serialized)
|
||||
|
||||
def _create_hash_str(self, input_arg, output_arg, node_def):
|
||||
"""Creates an 8-character string unique to this input.
|
||||
@ -830,8 +825,7 @@ def _from_definition(fdef, grad_func=None):
|
||||
# pylint: disable=protected-access
|
||||
if ops._USE_C_API:
|
||||
serialized = fdef.SerializeToString()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status)
|
||||
result._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
|
||||
result._extra_inputs = []
|
||||
else:
|
||||
result._definition = fdef
|
||||
|
@ -485,9 +485,8 @@ def import_graph_def(graph_def,
|
||||
with graph._lock: # pylint: disable=protected-access
|
||||
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
results = c_api.TF_GraphImportGraphDefWithResults(
|
||||
graph._c_graph, serialized, options, status) # pylint: disable=protected-access
|
||||
results = c_api.TF_GraphImportGraphDefWithResults(
|
||||
graph._c_graph, serialized, options) # pylint: disable=protected-access
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
|
@ -26,7 +26,6 @@ import threading # pylint: disable=unused-import
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
|
||||
from tensorflow.python import pywrap_tensorflow as py_tf
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -54,8 +53,7 @@ def load_op_library(library_filename):
|
||||
Raises:
|
||||
RuntimeError: when unable to load the library or get the python wrappers.
|
||||
"""
|
||||
with errors_impl.raise_exception_on_not_ok_status() as status:
|
||||
lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
|
||||
lib_handle = py_tf.TF_LoadLibrary(library_filename)
|
||||
|
||||
op_list_str = py_tf.TF_GetOpList(lib_handle)
|
||||
op_list = op_def_pb2.OpList()
|
||||
@ -99,5 +97,4 @@ def load_file_system_library(library_filename):
|
||||
Raises:
|
||||
RuntimeError: when unable to load the library.
|
||||
"""
|
||||
with errors_impl.raise_exception_on_not_ok_status() as status:
|
||||
lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
|
||||
py_tf.TF_LoadLibrary(library_filename)
|
||||
|
@ -373,15 +373,12 @@ class Tensor(_TensorLike):
|
||||
"""
|
||||
graph = self._op._graph._c_graph # pylint: disable=protected-access
|
||||
if graph and _USE_C_SHAPES:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
|
||||
status)
|
||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
|
||||
if num_dims == -1:
|
||||
dim_list = None
|
||||
else:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
dim_list = c_api.TF_GraphGetTensorShape_wrapper(
|
||||
graph, self._as_tf_output(), num_dims, status)
|
||||
dim_list = c_api.TF_GraphGetTensorShape_wrapper(
|
||||
graph, self._as_tf_output(), num_dims)
|
||||
dim_list = [None if i == -1 else i for i in dim_list]
|
||||
return tensor_shape.TensorShape(dim_list)
|
||||
return self._shape_val
|
||||
@ -489,13 +486,11 @@ class Tensor(_TensorLike):
|
||||
else:
|
||||
dim_list.append(dim.value)
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_GraphSetTensorShape_wrapper(
|
||||
self._op._graph._c_graph, # pylint: disable=protected-access
|
||||
self._as_tf_output(),
|
||||
dim_list,
|
||||
unknown_shape,
|
||||
status)
|
||||
c_api.TF_GraphSetTensorShape_wrapper(
|
||||
self._op._graph._c_graph, # pylint: disable=protected-access
|
||||
self._as_tf_output(),
|
||||
dim_list,
|
||||
unknown_shape)
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
@ -1514,13 +1509,10 @@ def _create_c_op(graph, node_def, inputs, control_inputs):
|
||||
serialized = attr_value.SerializeToString()
|
||||
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
|
||||
# It might be worth creating a convenient way to re-use the same status.
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_SetAttrValueProto(op_desc,
|
||||
compat.as_str(name), serialized, status)
|
||||
c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
|
||||
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_op = c_api.TF_FinishOperation(op_desc, status)
|
||||
c_op = c_api.TF_FinishOperation(op_desc)
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
@ -1943,12 +1935,10 @@ class Operation(object):
|
||||
if self._c_op:
|
||||
# Reset cached inputs.
|
||||
self._inputs_val = None
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.UpdateEdge(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
self._tf_input(index),
|
||||
status)
|
||||
c_api.UpdateEdge(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
self._tf_input(index))
|
||||
else:
|
||||
self._inputs_val[index].consumers().remove(self)
|
||||
self._inputs_val[index] = tensor
|
||||
@ -2169,8 +2159,7 @@ class Operation(object):
|
||||
# pylint: enable=line-too-long
|
||||
if self._c_op:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_OperationToNodeDef(self._c_op, buf, status)
|
||||
c_api.TF_OperationToNodeDef(self._c_op, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
node_def = node_def_pb2.NodeDef()
|
||||
node_def.ParseFromString(compat.as_bytes(data))
|
||||
@ -2228,11 +2217,9 @@ class Operation(object):
|
||||
buf = c_api.TF_NewBufferFromString(
|
||||
compat.as_bytes(attr_value.SerializeToString()))
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
# pylint: disable=protected-access
|
||||
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
|
||||
status)
|
||||
# pylint: enable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
|
||||
# pylint: enable=protected-access
|
||||
finally:
|
||||
c_api.TF_DeleteBuffer(buf)
|
||||
else:
|
||||
@ -2254,8 +2241,7 @@ class Operation(object):
|
||||
if self._c_op:
|
||||
try:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
|
||||
c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
@ -2469,11 +2455,10 @@ def _set_shapes_for_outputs_c_api(op):
|
||||
# The C API computes the shapes when the TF_Operation is created. Fetch the
|
||||
# output shapes from the C object.
|
||||
for output in op.outputs:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
# pylint: disable=protected-access
|
||||
shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
|
||||
op._graph._c_graph, output._as_tf_output(), status)
|
||||
# pylint: enable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
|
||||
op._graph._c_graph, output._as_tf_output())
|
||||
# pylint: enable=protected-access
|
||||
if unknown_shape:
|
||||
output.set_shape(tensor_shape.unknown_shape())
|
||||
elif not shape_vector:
|
||||
@ -2994,8 +2979,7 @@ class Graph(object):
|
||||
# pylint: enable=line-too-long
|
||||
if self._c_graph:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_GraphVersions(self._c_graph, buf, status)
|
||||
c_api.TF_GraphVersions(self._c_graph, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
version_def = versions_pb2.VersionDef()
|
||||
version_def.ParseFromString(compat.as_bytes(data))
|
||||
@ -3098,8 +3082,7 @@ class Graph(object):
|
||||
if self._c_graph:
|
||||
with self._lock:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
|
||||
c_api.TF_GraphToGraphDef(self._c_graph, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
graph = graph_pb2.GraphDef()
|
||||
graph.ParseFromString(compat.as_bytes(data))
|
||||
@ -3208,14 +3191,10 @@ class Graph(object):
|
||||
# remove this when all functions are generated using the C API by default
|
||||
# as this will be unnecessary.
|
||||
if not function._c_func:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
serialized = function.definition.SerializeToString()
|
||||
function._c_func = c_api.TF_FunctionImportFunctionDef(
|
||||
serialized, status)
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
gradient = function._grad_func._c_func if function._grad_func else None
|
||||
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
|
||||
status)
|
||||
serialized = function.definition.SerializeToString()
|
||||
function._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
|
||||
gradient = function._grad_func._c_func if function._grad_func else None
|
||||
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient)
|
||||
else:
|
||||
# If there is already a function with the same name, raise an error
|
||||
# if bodies are different. Else, do nothing. The C API version above
|
||||
@ -3732,11 +3711,9 @@ class Graph(object):
|
||||
"""Returns the `OpDef` proto for `type`. `type` is a string."""
|
||||
if self._c_graph:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
# pylint: disable=protected-access
|
||||
c_api.TF_GraphGetOpDef(self._c_graph,
|
||||
compat.as_bytes(type), buf, status)
|
||||
# pylint: enable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
|
||||
# pylint: enable=protected-access
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
op_def = op_def_pb2.OpDef()
|
||||
op_def.ParseFromString(compat.as_bytes(data))
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -83,9 +82,8 @@ def smart_constant_value(pred):
|
||||
# wanted to limit the change hidden behind _USE_C_API).
|
||||
# pylint: disable=protected-access
|
||||
if pred_value is None and ops._USE_C_API:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pred_value = c_api.TF_TryEvaluateConstant_wrapper(
|
||||
pred.graph._c_graph, pred._as_tf_output(), status)
|
||||
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
|
||||
pred._as_tf_output())
|
||||
# pylint: enable=protected-access
|
||||
|
||||
else:
|
||||
|
50
tensorflow/python/lib/core/py_exception_registry.cc
Normal file
50
tensorflow/python/lib/core/py_exception_registry.cc
Normal file
@ -0,0 +1,50 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr;
|
||||
|
||||
void PyExceptionRegistry::Init(PyObject* code_to_exc_type_map) {
|
||||
DCHECK(singleton_ == nullptr) << "PyExceptionRegistry::Init() already called";
|
||||
singleton_ = new PyExceptionRegistry;
|
||||
|
||||
DCHECK(PyDict_Check(code_to_exc_type_map));
|
||||
PyObject* key;
|
||||
PyObject* value;
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next(code_to_exc_type_map, &pos, &key, &value)) {
|
||||
TF_Code code = static_cast<TF_Code>(PyLong_AsLong(key));
|
||||
singleton_->exc_types_[code] = value;
|
||||
// The exception classes should also have the lifetime of the process, but
|
||||
// incref just in case.
|
||||
Py_INCREF(value);
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* PyExceptionRegistry::Lookup(TF_Code code) {
|
||||
DCHECK(singleton_ != nullptr) << "Must call PyExceptionRegistry::Init() "
|
||||
"before PyExceptionRegistry::Lookup()";
|
||||
DCHECK_NE(code, TF_OK);
|
||||
DCHECK(singleton_->exc_types_.find(code) != singleton_->exc_types_.end())
|
||||
<< "Unknown error code passed to PyExceptionRegistry::Lookup: " << code;
|
||||
return singleton_->exc_types_[code];
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
73
tensorflow/python/lib/core/py_exception_registry.h
Normal file
73
tensorflow/python/lib/core/py_exception_registry.h
Normal file
@ -0,0 +1,73 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
|
||||
#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
#ifndef PyObject_HEAD
|
||||
struct _object;
|
||||
typedef _object PyObject;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Global registry mapping C API error codes to the corresponding custom Python
|
||||
// exception type. This is used to expose the exception types to C extension
|
||||
// code (i.e. so we can raise custom exceptions via SWIG).
|
||||
//
|
||||
// Init() must be called exactly once at the beginning of the process before
|
||||
// Lookup() can be used.
|
||||
//
|
||||
// Example usage:
|
||||
// TF_Status* status = TF_NewStatus();
|
||||
// TF_Foo(..., status);
|
||||
//
|
||||
// if (TF_GetCode(status) != TF_OK) {
|
||||
// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status));
|
||||
// // Arguments to OpError base class. Set `node_def` and `op` to None.
|
||||
// PyObject* args =
|
||||
// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status));
|
||||
// PyErr_SetObject(exc_type, args);
|
||||
// Py_DECREF(args);
|
||||
// TF_DeleteStatus(status);
|
||||
// return NULL;
|
||||
// }
|
||||
class PyExceptionRegistry {
|
||||
public:
|
||||
// Initializes the process-wide registry. Should be called exactly once near
|
||||
// the beginning of the process. The arguments are the various Python
|
||||
// exception types (e.g. `cancelled_exc` corresponds to
|
||||
// errors.CancelledError).
|
||||
static void Init(PyObject* code_to_exc_type_map);
|
||||
|
||||
// Returns the Python exception type corresponding to `code`. Init() must be
|
||||
// called before using this function. `code` should not be TF_OK.
|
||||
static PyObject* Lookup(TF_Code code);
|
||||
|
||||
private:
|
||||
static PyExceptionRegistry* singleton_;
|
||||
PyExceptionRegistry() = default;
|
||||
|
||||
// Maps error codes to the corresponding Python exception type.
|
||||
std::map<TF_Code, PyObject*> exc_types_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
|
28
tensorflow/python/lib/core/py_exception_registry.i
Normal file
28
tensorflow/python/lib/core/py_exception_registry.i
Normal file
@ -0,0 +1,28 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow::PyExceptionRegistry;
|
||||
%unignore tensorflow::PyExceptionRegistry::Init;
|
||||
|
||||
%include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||
%unignoreall
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
%include "tensorflow/python/util/tfprof.i"
|
||||
|
||||
%include "tensorflow/python/lib/core/py_func.i"
|
||||
%include "tensorflow/python/lib/core/py_exception_registry.i"
|
||||
|
||||
%include "tensorflow/python/lib/io/py_record_reader.i"
|
||||
%include "tensorflow/python/lib/io/py_record_writer.i"
|
||||
@ -54,4 +55,3 @@ limitations under the License.
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
%include "tensorflow/python/grappler/cost_analyzer.i"
|
||||
%include "tensorflow/python/grappler/model_analyzer.i"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user