Throw relevant exceptions based on status when copying Eager tensors.

Instead of blindly throwing a RuntimeError, throw a registered OpError exception based on the status when executing EagerTensor `.numpy()`.

PiperOrigin-RevId: 313405387
Change-Id: I6ee8e804f96c9baf0c1af77a958bb1f4b26e614b
This commit is contained in:
Haoyu Zhang 2020-05-27 09:37:50 -07:00 committed by TensorFlower Gardener
parent 96ba1c3609
commit b847ff9b30
2 changed files with 20 additions and 1 deletions

View File

@ -2,6 +2,8 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
load(
@ -28,6 +30,10 @@ cc_library(
"pywrap_tensor_conversion.h",
"pywrap_tfe.h",
],
copts = ["-fexceptions"],
features = [
"-use_header_modules", # Required for pybind11
],
visibility = [
"//learning/deepmind/courier:__subpackages__",
"//tensorflow:internal",
@ -54,6 +60,7 @@ cc_library(
"//tensorflow/python:ndarray_tensor",
"//tensorflow/python:ndarray_tensor_bridge",
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_exception_registry",
"//tensorflow/python:py_seq_tensor",
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
@ -63,6 +70,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
"@pybind11",
],
)

View File

@ -21,9 +21,11 @@ limitations under the License.
#include <cmath>
#include "structmember.h" // NOLINT // For PyMemberDef
#include "pybind11/pybind11.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
@ -300,7 +303,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
strstr(device_name, "/device:CPU:0") != nullptr) {
handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
device_name, status.get()));
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
const TF_Code code = TF_GetCode(status.get());
if (code != TF_OK) {
// Instead of raising a generic RuntimeError, raise an exception type
// based on the status error code.
PyObject* exception = PyExceptionRegistry::Lookup(code);
PyErr_SetObject(exception,
pybind11::make_tuple(pybind11::none(), pybind11::none(),
TF_Message(status.get()))
.ptr());
return nullptr;
}
}