Inline status in EagerTensor

PiperOrigin-RevId: 267284080
This commit is contained in:
Gaurav Jain 2019-09-04 19:49:45 -07:00 committed by TensorFlower Gardener
parent b3cde1ae7e
commit 2c39806e95
2 changed files with 37 additions and 38 deletions

View File

@ -443,7 +443,7 @@ typedef struct EagerTensor {
// Status objects on different functions that operate on EagerTensor and need
// to use a TF_Status object. However note that accesses to `status` are not
// thread-safe.
TF_Status* status;
TF_Status status;
// The eager Context (from eager/context.py) used by this Tensor.
// This is currently used only to make sure context outlives TensorHandles.
@ -503,7 +503,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
self->handle_data = Py_None;
Py_INCREF(Py_None);
self->tensor_shape = Py_None;
self->status = TF_NewStatus();
self->status.status = tensorflow::Status::OK();
self->dict = nullptr;
self->weakreflist = nullptr;
self->context = nullptr;
@ -543,7 +543,6 @@ void EagerTensor_dealloc(EagerTensor* self) {
// Needs to happen before any actual destruction.
PyObject_ClearWeakRefs((PyObject*)self);
TF_DeleteStatus(self->status);
Py_DECREF(self->handle_data);
Py_DECREF(self->tensor_shape);
// If an attribute dictionary has been created, release it. Note that this
@ -579,21 +578,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
// Getter for `_shape_tuple`.
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
auto handle = self->handle;
int n = TFE_TensorHandleNumDims(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
int n = TFE_TensorHandleNumDims(handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return nullptr;
}
PyObject* shape = PyTuple_New(n);
if (PyErr_Occurred()) return nullptr;
for (int i = 0; i < n; ++i) {
PyObject* dim =
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) ||
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr) ||
dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
Py_DECREF(shape);
if (dim != nullptr) Py_DECREF(dim);
PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
@ -605,10 +604,10 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
// Getter for `_rank`.
static PyObject* EagerTensor_rank(EagerTensor* self) {
int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return nullptr;
}
#if PY_MAJOR_VERSION < 3
@ -621,10 +620,10 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle;
int n = TFE_TensorHandleNumElements(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
int n = TFE_TensorHandleNumElements(handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return nullptr;
}
return PyLong_FromLongLong(n);
@ -670,10 +669,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
// Note that this is a shallow copy and will share the underlying buffer
// if copying to the same device.
TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
self->handle, GetContextHandle(self->context), device_name, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_RuntimeError)) {
self->handle, GetContextHandle(self->context), device_name,
&self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return nullptr;
}
@ -686,11 +686,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
// other.
// Note that if `self` is not on CPU, we raise an Exception.
static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
Py_XDECREF(py_array);
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return nullptr;
} else {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
@ -699,10 +699,10 @@ static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
// Getter `device`.
static PyObject* EagerTensor_device(EagerTensor* self) {
const char* device = TFE_TensorHandleDeviceName(self->handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return nullptr;
}
#if PY_MAJOR_VERSION >= 3
@ -715,10 +715,10 @@ static PyObject* EagerTensor_device(EagerTensor* self) {
// Getter `backing_device`.
static PyObject* EagerTensor_backing_device(EagerTensor* self) {
const char* device =
TFE_TensorHandleBackingDeviceName(self->handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return nullptr;
}
#if PY_MAJOR_VERSION >= 3
@ -779,10 +779,10 @@ static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
// DT_STRING so the following is only slightly slower than a NumPy-free
// implementation.
auto py_array = tensorflow::make_safe(
TFE_TensorHandleToNumpy(self->handle, self->status));
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_BufferError)) {
TFE_TensorHandleToNumpy(self->handle, &self->status));
if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
self->status.status = tensorflow::Status::OK();
return -1;
}
if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
@ -899,7 +899,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
Py_INCREF(Py_None);
t->tensor_shape = Py_None;
t->handle = handle;
t->status = TF_NewStatus();
t->status.status = tensorflow::Status::OK();
t->weakreflist = nullptr;
PyObject* py_context = GetPyEagerContext();
if (py_context == nullptr) {
@ -927,17 +927,16 @@ tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
reinterpret_cast<const EagerTensor*>(tensor)->handle));
}
tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
DCHECK(EagerTensor_CheckExact(tensor));
const EagerTensor* as_c_eager_tensor =
reinterpret_cast<const EagerTensor*>(tensor);
EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
tensorflow::int64 result = TFE_TensorHandleNumElements(
as_c_eager_tensor->handle, as_c_eager_tensor->status);
as_c_eager_tensor->handle, &as_c_eager_tensor->status);
if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
PyExc_ValueError)) {
// Cleanup status before returning.
TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
as_c_eager_tensor->status.status = tensorflow::Status::OK();
return -1;
}

View File

@ -23,7 +23,7 @@ limitations under the License.
bool EagerTensor_CheckExact(const PyObject* o);
tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
namespace tensorflow {