From b847ff9b3067a101296d1d857358b5bdeefd2342 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Wed, 27 May 2020 09:37:50 -0700 Subject: [PATCH] Throw relevant exceptions based on status when copying Eager tensors. Instead of blindly throwing a RuntimeError, throw a registered OpError exception based on the status when executing EagerTensor `.numpy()`. PiperOrigin-RevId: 313405387 Change-Id: I6ee8e804f96c9baf0c1af77a958bb1f4b26e614b --- tensorflow/python/eager/BUILD | 8 ++++++++ tensorflow/python/eager/pywrap_tensor.cc | 13 ++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index adc30eab5e1..a44d8a493c1 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -2,6 +2,8 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") load( @@ -28,6 +30,10 @@ cc_library( "pywrap_tensor_conversion.h", "pywrap_tfe.h", ], + copts = ["-fexceptions"], + features = [ + "-use_header_modules", # Required for pybind11 + ], visibility = [ "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", @@ -54,6 +60,7 @@ cc_library( "//tensorflow/python:ndarray_tensor", "//tensorflow/python:ndarray_tensor_bridge", "//tensorflow/python:numpy_lib", + "//tensorflow/python:py_exception_registry", "//tensorflow/python:py_seq_tensor", "//tensorflow/python:safe_ptr", "//third_party/py/numpy:headers", @@ -63,6 +70,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", + "@pybind11", ], ) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index b209ddb6162..031545531f1 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -21,9 +21,11 @@ limitations under the License. #include #include "structmember.h" // NOLINT // For PyMemberDef +#include "pybind11/pybind11.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" #include "tensorflow/python/lib/core/numpy.h" +#include "tensorflow/python/lib/core/py_exception_registry.h" #include "tensorflow/python/lib/core/py_seq_tensor.h" #include "tensorflow/python/lib/core/safe_ptr.h" @@ -300,7 +303,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, strstr(device_name, "/device:CPU:0") != nullptr) { handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx, device_name, status.get())); - if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) { + const TF_Code code = TF_GetCode(status.get()); + if (code != TF_OK) { + // Instead of raising a generic RuntimeError, raise an exception type + // based on the status error code. + PyObject* exception = PyExceptionRegistry::Lookup(code); + PyErr_SetObject(exception, + pybind11::make_tuple(pybind11::none(), pybind11::none(), + TF_Message(status.get())) + .ptr()); return nullptr; } }