Throw relevant exceptions based on status when copying Eager tensors.
Instead of blindly throwing a RuntimeError, throw a registered OpError exception based on the status when executing EagerTensor `.numpy()`. PiperOrigin-RevId: 313405387 Change-Id: I6ee8e804f96c9baf0c1af77a958bb1f4b26e614b
This commit is contained in:
parent
96ba1c3609
commit
b847ff9b30
@ -2,6 +2,8 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
|
||||
load(
|
||||
@ -28,6 +30,10 @@ cc_library(
|
||||
"pywrap_tensor_conversion.h",
|
||||
"pywrap_tfe.h",
|
||||
],
|
||||
copts = ["-fexceptions"],
|
||||
features = [
|
||||
"-use_header_modules", # Required for pybind11
|
||||
],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
"//tensorflow:internal",
|
||||
@ -54,6 +60,7 @@ cc_library(
|
||||
"//tensorflow/python:ndarray_tensor",
|
||||
"//tensorflow/python:ndarray_tensor_bridge",
|
||||
"//tensorflow/python:numpy_lib",
|
||||
"//tensorflow/python:py_exception_registry",
|
||||
"//tensorflow/python:py_seq_tensor",
|
||||
"//tensorflow/python:safe_ptr",
|
||||
"//third_party/py/numpy:headers",
|
||||
@ -63,6 +70,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -21,9 +21,11 @@ limitations under the License.
|
||||
#include <cmath>
|
||||
|
||||
#include "structmember.h" // NOLINT // For PyMemberDef
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -32,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
|
||||
@ -300,7 +303,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
|
||||
strstr(device_name, "/device:CPU:0") != nullptr) {
|
||||
handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
|
||||
device_name, status.get()));
|
||||
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
|
||||
const TF_Code code = TF_GetCode(status.get());
|
||||
if (code != TF_OK) {
|
||||
// Instead of raising a generic RuntimeError, raise an exception type
|
||||
// based on the status error code.
|
||||
PyObject* exception = PyExceptionRegistry::Lookup(code);
|
||||
PyErr_SetObject(exception,
|
||||
pybind11::make_tuple(pybind11::none(), pybind11::none(),
|
||||
TF_Message(status.get()))
|
||||
.ptr());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user