Raise exception in SWIG on bad TF_Status from C API.

This change provides an alternative mechanism to
tf.raise_exception_on_not_ok_status(), which is inefficient and
error-prone (people often use the status multiple times in the with
block, but it's only checked when the context manager exits). Instead,
it uses SWIG to automatically raise an exception when a C API method
fails. Note that this removes the status argument from affected
methods.

For now, I've only applied this typemap to C API methods. It would be
good to expand this to all uses of raise_exception_on_not_ok_status.

PiperOrigin-RevId: 191121016
This commit is contained in:
Skye Wanderman-Milne 2018-03-30 14:56:08 -07:00 committed by TensorFlower Gardener
parent 15c10899c9
commit 97731cb122
19 changed files with 324 additions and 195 deletions

View File

@ -1496,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list);
// If index is out of bounds, an error code will be set in the status object,
// and a null pointer will be returned.
TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
int index, TF_Status*);
int index,
TF_Status* status);
// Retrieves the type of the device at the given index.
//
@ -1506,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
// If index is out of bounds, an error code will be set in the status object,
// and a null pointer will be returned.
TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
int index, TF_Status*);
int index,
TF_Status* status);
// Retrieve the amount of memory associated with a given device.
//
// If index is out of bounds, an error code will be set in the status object,
// and -1 will be returned.
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
const TF_DeviceList* list, int index, TF_Status*);
const TF_DeviceList* list, int index, TF_Status* status);
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels

View File

@ -474,6 +474,8 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"

View File

@ -283,6 +283,17 @@ cc_library(
],
)
cc_library(
name = "py_exception_registry",
srcs = ["lib/core/py_exception_registry.cc"],
hdrs = ["lib/core/py_exception_registry.h"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
"//util/python:python_headers",
],
)
cc_library(
name = "kernel_registry",
srcs = ["util/kernel_registry.cc"],
@ -3313,6 +3324,7 @@ tf_py_wrap_cc(
"grappler/model_analyzer.i",
"grappler/tf_optimizer.i",
"lib/core/bfloat16.i",
"lib/core/py_exception_registry.i",
"lib/core/py_func.i",
"lib/core/strings.i",
"lib/io/file_io.i",
@ -3344,6 +3356,7 @@ tf_py_wrap_cc(
":kernel_registry",
":numpy_lib",
":safe_ptr",
":py_exception_registry",
":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",

View File

@ -27,7 +27,6 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -629,14 +628,12 @@ class BaseSession(SessionInterface):
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
# pylint: disable=protected-access
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts,
status)
# pylint: enable=protected-access
else:
self._session = tf_session.TF_NewDeprecatedSession(opts, status)
if self._created_with_new_api:
# pylint: disable=protected-access
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
# pylint: enable=protected-access
else:
self._session = tf_session.TF_NewDeprecatedSession(opts)
finally:
tf_session.TF_DeleteSessionOptions(opts)
@ -663,22 +660,20 @@ class BaseSession(SessionInterface):
Returns:
A list of devices in the session.
"""
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
raw_device_list = tf_session.TF_SessionListDevices(
self._session, status)
else:
raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
self._session, status)
device_list = []
size = tf_session.TF_DeviceListCount(raw_device_list)
for i in range(size):
name = tf_session.TF_DeviceListName(raw_device_list, i, status)
device_type = tf_session.TF_DeviceListType(raw_device_list, i, status)
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status)
device_list.append(_DeviceAttributes(name, device_type, memory))
tf_session.TF_DeleteDeviceList(raw_device_list)
return device_list
if self._created_with_new_api:
raw_device_list = tf_session.TF_SessionListDevices(self._session)
else:
raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
self._session)
device_list = []
size = tf_session.TF_DeviceListCount(raw_device_list)
for i in range(size):
name = tf_session.TF_DeviceListName(raw_device_list, i)
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
device_list.append(_DeviceAttributes(name, device_type, memory))
tf_session.TF_DeleteDeviceList(raw_device_list)
return device_list
def close(self):
"""Closes this session.
@ -692,15 +687,13 @@ class BaseSession(SessionInterface):
if self._created_with_new_api:
if self._session and not self._closed:
self._closed = True
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_CloseSession(self._session, status)
tf_session.TF_CloseSession(self._session)
else:
with self._extend_lock:
if self._opened and not self._closed:
self._closed = True
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_CloseDeprecatedSession(self._session, status)
tf_session.TF_CloseDeprecatedSession(self._session)
def __del__(self):
# cleanly ignore all exceptions
@ -710,11 +703,10 @@ class BaseSession(SessionInterface):
pass
if self._session is not None:
try:
status = c_api_util.ScopedTFStatus()
if self._created_with_new_api:
tf_session.TF_DeleteSession(self._session, status)
tf_session.TF_DeleteSession(self._session)
else:
tf_session.TF_DeleteDeprecatedSession(self._session, status)
tf_session.TF_DeleteDeprecatedSession(self._session)
except AttributeError:
# At shutdown, `c_api_util` or `tf_session` may have been garbage
# collected, causing the above method calls to fail. In this case,
@ -1031,11 +1023,11 @@ class BaseSession(SessionInterface):
# Set up a graph with feeds and fetches for partial run.
def _setup_fn(session, feed_list, fetch_list, target_list):
self._extend_graph()
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
return tf_session.TF_SessionPRunSetup_wrapper(
session, feed_list, fetch_list, target_list, status)
else:
if self._created_with_new_api:
return tf_session.TF_SessionPRunSetup_wrapper(
session, feed_list, fetch_list, target_list)
else:
with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
target_list, status)
@ -1345,8 +1337,7 @@ class BaseSession(SessionInterface):
def _extend_graph(self):
if self._created_with_new_api:
with self._graph._lock: # pylint: disable=protected-access
with errors.raise_exception_on_not_ok_status() as status:
tf_session.ExtendSession(self._session, status)
tf_session.ExtendSession(self._session)
else:
# Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock:
@ -1412,22 +1403,22 @@ class BaseSession(SessionInterface):
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
run_metadata):
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
return tf_session.TF_SessionRun_wrapper(
self._session, options, feed_dict, fetch_list, target_list,
run_metadata, status)
else:
if self._created_with_new_api:
return tf_session.TF_SessionRun_wrapper(
self._session, options, feed_dict, fetch_list, target_list,
run_metadata)
else:
with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_Run(
self._session, options, feed_dict, fetch_list, target_list,
status, run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
return tf_session.TF_SessionPRun_wrapper(
self._session, handle, feed_dict, fetch_list, status)
else:
if self._created_with_new_api:
return tf_session.TF_SessionPRun_wrapper(
self._session, handle, feed_dict, fetch_list)
else:
with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRun(
self._session, handle, feed_dict, fetch_list, status)

View File

@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.client import session
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -42,21 +41,13 @@ class SessionListDevicesTestMethods(object):
def testInvalidDeviceNumber(self):
opts = tf_session.TF_NewSessionOptions()
with errors.raise_exception_on_not_ok_status() as status:
c_session = tf_session.TF_NewSession(
ops.get_default_graph()._c_graph, opts, status)
raw_device_list = tf_session.TF_SessionListDevices(
c_session, status)
c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts)
raw_device_list = tf_session.TF_SessionListDevices(c_session)
size = tf_session.TF_DeviceListCount(raw_device_list)
# Test that invalid device numbers return -1 rather than a Swig-wrapped
# pointer.
status_no_exception = c_api_util.ScopedTFStatus()
memory = tf_session.TF_DeviceListMemoryBytes(
raw_device_list, size, status_no_exception)
self.assertEqual(memory, -1)
with self.assertRaises(errors.InvalidArgumentError):
tf_session.TF_DeviceListMemoryBytes(raw_device_list, size)
tf_session.TF_DeleteDeviceList(raw_device_list)
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_CloseSession(c_session, status)
tf_session.TF_CloseSession(c_session)
def testListDevicesGrpcSession(self):
server = server_lib.Server.create_local_server()

View File

@ -18,11 +18,12 @@ limitations under the License.
%{
#include "tensorflow/c/python_api.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
// Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs.
@ -352,6 +353,27 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
reinterpret_cast<const char*>($1.data), $1.length);
}
// Typemaps to automatically raise a Python exception from bad output TF_Status.
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
// raise_exception_on_not_ok_status (currently it only affects the C API).
%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) {
status = TF_NewStatus();
$1 = status;
}
%typemap(argout) TF_Status* status {
TF_Code code = TF_GetCode($1);
if (code != TF_OK) {
PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code);
// Arguments to OpError.
PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1));
TF_DeleteStatus($1);
SWIG_SetErrorObj(exc, exc_args);
SWIG_fail;
}
TF_DeleteStatus($1);
}
// Converts input Python list of wrapped TF_Outputs into a single array
%typemap(in) (const TF_Output* inputs, int num_inputs)
(std::vector<TF_Output> inputs) {
@ -499,9 +521,8 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
_TF_SetTarget(opts, target)
if config is not None:
from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status:
config_str = config.SerializeToString()
_TF_SetConfig(opts, config_str, status)
config_str = config.SerializeToString()
_TF_SetConfig(opts, config_str)
return opts
%}
@ -758,3 +779,7 @@ def TF_Reset(target, containers=None, config=None):
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
// Clear "TF_Status* status" typemap so it doesn't affect other modules and
// unexpectedly remove the TF_Status* argument from wrappers.
%clear TF_Status* status;

View File

@ -136,8 +136,7 @@ string EqualAttrValueWrapper(const string& actual, const string& expected);
//
// If shape is unknown, sets unknown_shape to true.
tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
TF_Graph* graph, TF_Output output, TF_Status* out_status,
bool* unknown_shape);
TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape);
// Runs the graph associated with the session starting with the supplied inputs.
// On success, `py_outputs` is populated with a numpy ndarray for each output
@ -149,7 +148,7 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets,
TF_Buffer* run_metadata, TF_Status* out_status,
TF_Buffer* run_metadata, TF_Status* status,
std::vector<PyObject*>* py_outputs);
// Set up the graph with the intended feeds (inputs) and fetches (output) for
@ -165,8 +164,7 @@ void TF_SessionPRunSetup_wrapper(TF_Session* session,
const std::vector<TF_Output>& inputs,
const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets,
const char** out_handle,
TF_Status* out_status);
const char** out_handle, TF_Status* status);
// Continue to run the graph with additional feeds and fetches. The
// execution state is uniquely identified by the handle.
@ -182,7 +180,7 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
const std::vector<TF_Output>& inputs,
const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs,
TF_Status* out_status,
TF_Status* status,
std::vector<PyObject*>* py_outputs);
// Retrieves the inputs of this operation.
@ -204,7 +202,7 @@ TF_Function* TF_GraphToFunction_wrapper(
const std::vector<TF_Operation*>* opers,
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
const NameVector& output_names, const TF_FunctionOptions* opts,
const char* description, TF_Status* out_status);
const char* description, TF_Status* status);
// Set the shapes and types for the output's handle.
//

View File

@ -244,13 +244,9 @@ class Context(object):
try:
self._num_gpus = 0
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
with errors.raise_exception_on_not_ok_status() as status:
dev_name = pywrap_tensorflow.TF_DeviceListName(
device_list, i, status)
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
self._context_devices.append(pydev.canonical_name(dev_name))
with errors.raise_exception_on_not_ok_status() as status:
dev_type = pywrap_tensorflow.TF_DeviceListType(
device_list, i, status)
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
if dev_type == "GPU":
self._num_gpus += 1

View File

@ -34,7 +34,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -79,14 +78,10 @@ def capture_value(tensor_map, value, dtype, name):
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes]
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
captured_value._op._graph._c_graph, # pylint: disable=protected-access
captured_value._as_tf_output(), # pylint: disable=protected-access
shapes,
ranks,
types,
status)
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
captured_value._op._graph._c_graph, # pylint: disable=protected-access
captured_value._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
tensor_map[ops.tensor_id(value)] = (value, captured_value)
else:
@ -275,23 +270,20 @@ class _EagerDefinedFunction(object):
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
"""
with errors.raise_exception_on_not_ok_status() as status:
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
False,
[o._c_op for o in operations], # pylint: disable=protected-access
[t._as_tf_output() for t in inputs], # pylint: disable=protected-access
[t._as_tf_output() for t in outputs], # pylint: disable=protected-access
[],
None,
compat.as_str(""),
status)
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
False,
[o._c_op for o in operations], # pylint: disable=protected-access
[t._as_tf_output() for t in inputs], # pylint: disable=protected-access
[t._as_tf_output() for t in outputs], # pylint: disable=protected-access
[],
None,
compat.as_str(""))
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
with c_api_util.tf_buffer() as buffer_:
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef()
function_def.ParseFromString(compat.as_bytes(proto_data))

View File

@ -473,6 +473,8 @@ _CODE_TO_EXCEPTION_CLASS = {
DATA_LOSS: DataLossError,
}
c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS)
_EXCEPTION_CLASS_TO_CODE = dict((
(class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items()))
@ -499,6 +501,7 @@ def _make_specific_exception(node_def, op, message, error_code):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
# TODO(b/77295559): expand use of TF_Status* SWIG typemap and deprecate this.
@tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name
class raise_exception_on_not_ok_status(object):
"""Context manager to check for C API status."""

View File

@ -30,7 +30,6 @@ from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -275,8 +274,7 @@ class _DefinedFunction(object):
self._create_definition_if_needed()
if self._c_func:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_FunctionToFunctionDef(self._c_func, buf, status)
c_api.TF_FunctionToFunctionDef(self._c_func, buf)
fdef = function_pb2.FunctionDef()
proto_data = c_api.TF_GetBuffer(buf)
fdef.ParseFromString(compat.as_bytes(proto_data))
@ -399,18 +397,16 @@ class _DefinedFunction(object):
if self._out_names else [])
description = self._func.__doc__ or None
# pylint: disable=protected-access
with errors.raise_exception_on_not_ok_status() as status:
self._c_func = c_api.TF_GraphToFunction_wrapper(
temp_graph._c_graph,
base_func_name,
self._func_name is None, # append_hash_to_fn_name
None, # opers
[t._as_tf_output() for t in inputs],
[t._as_tf_output() for t in outputs],
output_names,
None, # opts
description,
status)
self._c_func = c_api.TF_GraphToFunction_wrapper(
temp_graph._c_graph,
base_func_name,
self._func_name is None, # append_hash_to_fn_name
None, # opers
[t._as_tf_output() for t in inputs],
[t._as_tf_output() for t in outputs],
output_names,
None, # opts
description)
# pylint: enable=protected-access
self._set_c_attrs(kwargs_attr)
@ -433,9 +429,8 @@ class _DefinedFunction(object):
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
serialized, status)
c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
serialized)
def _create_hash_str(self, input_arg, output_arg, node_def):
"""Creates an 8-character string unique to this input.
@ -830,8 +825,7 @@ def _from_definition(fdef, grad_func=None):
# pylint: disable=protected-access
if ops._USE_C_API:
serialized = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status:
result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status)
result._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
result._extra_inputs = []
else:
result._definition = fdef

View File

@ -485,9 +485,8 @@ def import_graph_def(graph_def,
with graph._lock: # pylint: disable=protected-access
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
try:
with errors.raise_exception_on_not_ok_status() as status:
results = c_api.TF_GraphImportGraphDefWithResults(
graph._c_graph, serialized, options, status) # pylint: disable=protected-access
results = c_api.TF_GraphImportGraphDefWithResults(
graph._c_graph, serialized, options) # pylint: disable=protected-access
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))

View File

@ -26,7 +26,6 @@ import threading # pylint: disable=unused-import
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
from tensorflow.python import pywrap_tensorflow as py_tf
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@ -54,8 +53,7 @@ def load_op_library(library_filename):
Raises:
RuntimeError: when unable to load the library or get the python wrappers.
"""
with errors_impl.raise_exception_on_not_ok_status() as status:
lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
lib_handle = py_tf.TF_LoadLibrary(library_filename)
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
@ -99,5 +97,4 @@ def load_file_system_library(library_filename):
Raises:
RuntimeError: when unable to load the library.
"""
with errors_impl.raise_exception_on_not_ok_status() as status:
lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
py_tf.TF_LoadLibrary(library_filename)

View File

@ -373,15 +373,12 @@ class Tensor(_TensorLike):
"""
graph = self._op._graph._c_graph # pylint: disable=protected-access
if graph and _USE_C_SHAPES:
with errors.raise_exception_on_not_ok_status() as status:
num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
status)
num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
if num_dims == -1:
dim_list = None
else:
with errors.raise_exception_on_not_ok_status() as status:
dim_list = c_api.TF_GraphGetTensorShape_wrapper(
graph, self._as_tf_output(), num_dims, status)
dim_list = c_api.TF_GraphGetTensorShape_wrapper(
graph, self._as_tf_output(), num_dims)
dim_list = [None if i == -1 else i for i in dim_list]
return tensor_shape.TensorShape(dim_list)
return self._shape_val
@ -489,13 +486,11 @@ class Tensor(_TensorLike):
else:
dim_list.append(dim.value)
try:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_GraphSetTensorShape_wrapper(
self._op._graph._c_graph, # pylint: disable=protected-access
self._as_tf_output(),
dim_list,
unknown_shape,
status)
c_api.TF_GraphSetTensorShape_wrapper(
self._op._graph._c_graph, # pylint: disable=protected-access
self._as_tf_output(),
dim_list,
unknown_shape)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@ -1514,13 +1509,10 @@ def _create_c_op(graph, node_def, inputs, control_inputs):
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_SetAttrValueProto(op_desc,
compat.as_str(name), serialized, status)
c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
try:
with errors.raise_exception_on_not_ok_status() as status:
c_op = c_api.TF_FinishOperation(op_desc, status)
c_op = c_api.TF_FinishOperation(op_desc)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@ -1943,12 +1935,10 @@ class Operation(object):
if self._c_op:
# Reset cached inputs.
self._inputs_val = None
with errors.raise_exception_on_not_ok_status() as status:
c_api.UpdateEdge(
self._graph._c_graph, # pylint: disable=protected-access
tensor._as_tf_output(), # pylint: disable=protected-access
self._tf_input(index),
status)
c_api.UpdateEdge(
self._graph._c_graph, # pylint: disable=protected-access
tensor._as_tf_output(), # pylint: disable=protected-access
self._tf_input(index))
else:
self._inputs_val[index].consumers().remove(self)
self._inputs_val[index] = tensor
@ -2169,8 +2159,7 @@ class Operation(object):
# pylint: enable=line-too-long
if self._c_op:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_OperationToNodeDef(self._c_op, buf, status)
c_api.TF_OperationToNodeDef(self._c_op, buf)
data = c_api.TF_GetBuffer(buf)
node_def = node_def_pb2.NodeDef()
node_def.ParseFromString(compat.as_bytes(data))
@ -2228,11 +2217,9 @@ class Operation(object):
buf = c_api.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString()))
try:
with errors.raise_exception_on_not_ok_status() as status:
# pylint: disable=protected-access
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
status)
# pylint: enable=protected-access
# pylint: disable=protected-access
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
# pylint: enable=protected-access
finally:
c_api.TF_DeleteBuffer(buf)
else:
@ -2254,8 +2241,7 @@ class Operation(object):
if self._c_op:
try:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
data = c_api.TF_GetBuffer(buf)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
@ -2469,11 +2455,10 @@ def _set_shapes_for_outputs_c_api(op):
# The C API computes the shapes when the TF_Operation is created. Fetch the
# output shapes from the C object.
for output in op.outputs:
with errors.raise_exception_on_not_ok_status() as status:
# pylint: disable=protected-access
shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
op._graph._c_graph, output._as_tf_output(), status)
# pylint: enable=protected-access
# pylint: disable=protected-access
shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
op._graph._c_graph, output._as_tf_output())
# pylint: enable=protected-access
if unknown_shape:
output.set_shape(tensor_shape.unknown_shape())
elif not shape_vector:
@ -2994,8 +2979,7 @@ class Graph(object):
# pylint: enable=line-too-long
if self._c_graph:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_GraphVersions(self._c_graph, buf, status)
c_api.TF_GraphVersions(self._c_graph, buf)
data = c_api.TF_GetBuffer(buf)
version_def = versions_pb2.VersionDef()
version_def.ParseFromString(compat.as_bytes(data))
@ -3098,8 +3082,7 @@ class Graph(object):
if self._c_graph:
with self._lock:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
c_api.TF_GraphToGraphDef(self._c_graph, buf)
data = c_api.TF_GetBuffer(buf)
graph = graph_pb2.GraphDef()
graph.ParseFromString(compat.as_bytes(data))
@ -3208,14 +3191,10 @@ class Graph(object):
# remove this when all functions are generated using the C API by default
# as this will be unnecessary.
if not function._c_func:
with errors.raise_exception_on_not_ok_status() as status:
serialized = function.definition.SerializeToString()
function._c_func = c_api.TF_FunctionImportFunctionDef(
serialized, status)
with errors.raise_exception_on_not_ok_status() as status:
gradient = function._grad_func._c_func if function._grad_func else None
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
status)
serialized = function.definition.SerializeToString()
function._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
gradient = function._grad_func._c_func if function._grad_func else None
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient)
else:
# If there is already a function with the same name, raise an error
# if bodies are different. Else, do nothing. The C API version above
@ -3732,11 +3711,9 @@ class Graph(object):
"""Returns the `OpDef` proto for `type`. `type` is a string."""
if self._c_graph:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
# pylint: disable=protected-access
c_api.TF_GraphGetOpDef(self._c_graph,
compat.as_bytes(type), buf, status)
# pylint: enable=protected-access
# pylint: disable=protected-access
c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
# pylint: enable=protected-access
data = c_api.TF_GetBuffer(buf)
op_def = op_def_pb2.OpDef()
op_def.ParseFromString(compat.as_bytes(data))

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
@ -83,9 +82,8 @@ def smart_constant_value(pred):
# wanted to limit the change hidden behind _USE_C_API).
# pylint: disable=protected-access
if pred_value is None and ops._USE_C_API:
with errors.raise_exception_on_not_ok_status() as status:
pred_value = c_api.TF_TryEvaluateConstant_wrapper(
pred.graph._c_graph, pred._as_tf_output(), status)
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
pred._as_tf_output())
# pylint: enable=protected-access
else:

View File

@ -0,0 +1,50 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include <Python.h>
namespace tensorflow {
PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr;
void PyExceptionRegistry::Init(PyObject* code_to_exc_type_map) {
DCHECK(singleton_ == nullptr) << "PyExceptionRegistry::Init() already called";
singleton_ = new PyExceptionRegistry;
DCHECK(PyDict_Check(code_to_exc_type_map));
PyObject* key;
PyObject* value;
Py_ssize_t pos = 0;
while (PyDict_Next(code_to_exc_type_map, &pos, &key, &value)) {
TF_Code code = static_cast<TF_Code>(PyLong_AsLong(key));
singleton_->exc_types_[code] = value;
// The exception classes should also have the lifetime of the process, but
// incref just in case.
Py_INCREF(value);
}
}
PyObject* PyExceptionRegistry::Lookup(TF_Code code) {
DCHECK(singleton_ != nullptr) << "Must call PyExceptionRegistry::Init() "
"before PyExceptionRegistry::Lookup()";
DCHECK_NE(code, TF_OK);
DCHECK(singleton_->exc_types_.find(code) != singleton_->exc_types_.end())
<< "Unknown error code passed to PyExceptionRegistry::Lookup: " << code;
return singleton_->exc_types_[code];
}
} // namespace tensorflow

View File

@ -0,0 +1,73 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
#include <map>
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/platform/logging.h"
#ifndef PyObject_HEAD
struct _object;
typedef _object PyObject;
#endif
namespace tensorflow {
// Global registry mapping C API error codes to the corresponding custom Python
// exception type. This is used to expose the exception types to C extension
// code (i.e. so we can raise custom exceptions via SWIG).
//
// Init() must be called exactly once at the beginning of the process before
// Lookup() can be used.
//
// Example usage:
// TF_Status* status = TF_NewStatus();
// TF_Foo(..., status);
//
// if (TF_GetCode(status) != TF_OK) {
// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status));
// // Arguments to OpError base class. Set `node_def` and `op` to None.
// PyObject* args =
// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status));
// PyErr_SetObject(exc_type, args);
// Py_DECREF(args);
// TF_DeleteStatus(status);
// return NULL;
// }
class PyExceptionRegistry {
public:
// Initializes the process-wide registry. Should be called exactly once near
// the beginning of the process. The arguments are the various Python
// exception types (e.g. `cancelled_exc` corresponds to
// errors.CancelledError).
static void Init(PyObject* code_to_exc_type_map);
// Returns the Python exception type corresponding to `code`. Init() must be
// called before using this function. `code` should not be TF_OK.
static PyObject* Lookup(TF_Code code);
private:
static PyExceptionRegistry* singleton_;
PyExceptionRegistry() = default;
// Maps error codes to the corresponding Python exception type.
std::map<TF_Code, PyObject*> exc_types_;
};
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_

View File

@ -0,0 +1,28 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/python/lib/core/py_exception_registry.h"
%}
%ignoreall
%unignore tensorflow::PyExceptionRegistry;
%unignore tensorflow::PyExceptionRegistry::Init;
%include "tensorflow/python/lib/core/py_exception_registry.h"
%unignoreall

View File

@ -25,6 +25,7 @@ limitations under the License.
%include "tensorflow/python/util/tfprof.i"
%include "tensorflow/python/lib/core/py_func.i"
%include "tensorflow/python/lib/core/py_exception_registry.i"
%include "tensorflow/python/lib/io/py_record_reader.i"
%include "tensorflow/python/lib/io/py_record_writer.i"
@ -54,4 +55,3 @@ limitations under the License.
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"