diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 031545531f1..0789eab6270 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -180,6 +180,15 @@ int ConvertDeviceName(PyObject* obj, const char** dst) { return 1; } +void RaiseExceptionTypeFromTFStatus(TF_Status* status) { + TF_Code code = TF_GetCode(status); + PyObject* exception = tensorflow::PyExceptionRegistry::Lookup(code); + PyErr_SetObject(exception, + pybind11::make_tuple(pybind11::none(), pybind11::none(), + TF_Message(status)) + .ptr()); +} + } // namespace namespace tensorflow { @@ -305,13 +314,7 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, device_name, status.get())); const TF_Code code = TF_GetCode(status.get()); if (code != TF_OK) { - // Instead of raising a generic RuntimeError, raise an exception type - // based on the status error code. - PyObject* exception = PyExceptionRegistry::Lookup(code); - PyErr_SetObject(exception, - pybind11::make_tuple(pybind11::none(), pybind11::none(), - TF_Message(status.get())) - .ptr()); + RaiseExceptionTypeFromTFStatus(status.get()); return nullptr; } } @@ -512,7 +515,9 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) { static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { auto handle = self->handle; int n = TFE_TensorHandleNumDims(handle, &self->status); - if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) { + TF_Code code = TF_GetCode(&self->status); + if (code != TF_OK) { + RaiseExceptionTypeFromTFStatus(&self->status); // Cleanup self->status before returning. self->status.status = tensorflow::Status::OK(); return nullptr; @@ -522,13 +527,18 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { for (int i = 0; i < n; ++i) { PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status)); - if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr) || - dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { + code = TF_GetCode(&self->status); + if (code != TF_OK || dim == nullptr || + PyTuple_SetItem(shape, i, dim) != 0) { + if (code != TF_OK) { + RaiseExceptionTypeFromTFStatus(&self->status); + } else { + PyErr_SetString(PyExc_RuntimeError, "Error while creating shape"); + } // Cleanup self->status before returning. self->status.status = tensorflow::Status::OK(); Py_DECREF(shape); if (dim != nullptr) Py_DECREF(dim); - PyErr_SetString(PyExc_RuntimeError, "Error while creating shape"); return nullptr; } }