Raise error type corresponding to status code instead of generic RuntimeError.
PiperOrigin-RevId: 317367352 Change-Id: I35378b88a33269ac225632ae848398b819c694a1
This commit is contained in:
parent
c575e2ba93
commit
72d30dfb8b
@ -180,6 +180,15 @@ int ConvertDeviceName(PyObject* obj, const char** dst) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RaiseExceptionTypeFromTFStatus(TF_Status* status) {
|
||||||
|
TF_Code code = TF_GetCode(status);
|
||||||
|
PyObject* exception = tensorflow::PyExceptionRegistry::Lookup(code);
|
||||||
|
PyErr_SetObject(exception,
|
||||||
|
pybind11::make_tuple(pybind11::none(), pybind11::none(),
|
||||||
|
TF_Message(status))
|
||||||
|
.ptr());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -305,13 +314,7 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
|
|||||||
device_name, status.get()));
|
device_name, status.get()));
|
||||||
const TF_Code code = TF_GetCode(status.get());
|
const TF_Code code = TF_GetCode(status.get());
|
||||||
if (code != TF_OK) {
|
if (code != TF_OK) {
|
||||||
// Instead of raising a generic RuntimeError, raise an exception type
|
RaiseExceptionTypeFromTFStatus(status.get());
|
||||||
// based on the status error code.
|
|
||||||
PyObject* exception = PyExceptionRegistry::Lookup(code);
|
|
||||||
PyErr_SetObject(exception,
|
|
||||||
pybind11::make_tuple(pybind11::none(), pybind11::none(),
|
|
||||||
TF_Message(status.get()))
|
|
||||||
.ptr());
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -512,7 +515,9 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
|
|||||||
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
||||||
auto handle = self->handle;
|
auto handle = self->handle;
|
||||||
int n = TFE_TensorHandleNumDims(handle, &self->status);
|
int n = TFE_TensorHandleNumDims(handle, &self->status);
|
||||||
if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
|
TF_Code code = TF_GetCode(&self->status);
|
||||||
|
if (code != TF_OK) {
|
||||||
|
RaiseExceptionTypeFromTFStatus(&self->status);
|
||||||
// Cleanup self->status before returning.
|
// Cleanup self->status before returning.
|
||||||
self->status.status = tensorflow::Status::OK();
|
self->status.status = tensorflow::Status::OK();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -522,13 +527,18 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
|||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
PyObject* dim =
|
PyObject* dim =
|
||||||
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
|
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
|
||||||
if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr) ||
|
code = TF_GetCode(&self->status);
|
||||||
dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
|
if (code != TF_OK || dim == nullptr ||
|
||||||
|
PyTuple_SetItem(shape, i, dim) != 0) {
|
||||||
|
if (code != TF_OK) {
|
||||||
|
RaiseExceptionTypeFromTFStatus(&self->status);
|
||||||
|
} else {
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
|
||||||
|
}
|
||||||
// Cleanup self->status before returning.
|
// Cleanup self->status before returning.
|
||||||
self->status.status = tensorflow::Status::OK();
|
self->status.status = tensorflow::Status::OK();
|
||||||
Py_DECREF(shape);
|
Py_DECREF(shape);
|
||||||
if (dim != nullptr) Py_DECREF(dim);
|
if (dim != nullptr) Py_DECREF(dim);
|
||||||
PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user