Inline status in EagerTensor

PiperOrigin-RevId: 267284080
This commit is contained in:
Gaurav Jain 2019-09-04 19:49:45 -07:00 committed by TensorFlower Gardener
parent b3cde1ae7e
commit 2c39806e95
2 changed files with 37 additions and 38 deletions

View File

@ -443,7 +443,7 @@ typedef struct EagerTensor {
// Status objects on different functions that operate on EagerTensor and need // Status objects on different functions that operate on EagerTensor and need
// to use a TF_Status object. However note that accesses to `status` are not // to use a TF_Status object. However note that accesses to `status` are not
// thread-safe. // thread-safe.
TF_Status* status; TF_Status status;
// The eager Context (from eager/context.py) used by this Tensor. // The eager Context (from eager/context.py) used by this Tensor.
// This is currently used only to make sure context outlives TensorHandles. // This is currently used only to make sure context outlives TensorHandles.
@ -503,7 +503,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
self->handle_data = Py_None; self->handle_data = Py_None;
Py_INCREF(Py_None); Py_INCREF(Py_None);
self->tensor_shape = Py_None; self->tensor_shape = Py_None;
self->status = TF_NewStatus(); self->status.status = tensorflow::Status::OK();
self->dict = nullptr; self->dict = nullptr;
self->weakreflist = nullptr; self->weakreflist = nullptr;
self->context = nullptr; self->context = nullptr;
@ -543,7 +543,6 @@ void EagerTensor_dealloc(EagerTensor* self) {
// Needs to happen before any actual destruction. // Needs to happen before any actual destruction.
PyObject_ClearWeakRefs((PyObject*)self); PyObject_ClearWeakRefs((PyObject*)self);
TF_DeleteStatus(self->status);
Py_DECREF(self->handle_data); Py_DECREF(self->handle_data);
Py_DECREF(self->tensor_shape); Py_DECREF(self->tensor_shape);
// If an attribute dictionary has been created, release it. Note that this // If an attribute dictionary has been created, release it. Note that this
@ -579,21 +578,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
// Getter for `_shape_tuple`. // Getter for `_shape_tuple`.
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
auto handle = self->handle; auto handle = self->handle;
int n = TFE_TensorHandleNumDims(handle, self->status); int n = TFE_TensorHandleNumDims(handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return nullptr; return nullptr;
} }
PyObject* shape = PyTuple_New(n); PyObject* shape = PyTuple_New(n);
if (PyErr_Occurred()) return nullptr; if (PyErr_Occurred()) return nullptr;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
PyObject* dim = PyObject* dim =
PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status)); PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) || if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr) ||
dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
Py_DECREF(shape); Py_DECREF(shape);
if (dim != nullptr) Py_DECREF(dim); if (dim != nullptr) Py_DECREF(dim);
PyErr_SetString(PyExc_RuntimeError, "Error while creating shape"); PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
@ -605,10 +604,10 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
// Getter for `_rank`. // Getter for `_rank`.
static PyObject* EagerTensor_rank(EagerTensor* self) { static PyObject* EagerTensor_rank(EagerTensor* self) {
int num_dims = TFE_TensorHandleNumDims(self->handle, self->status); int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return nullptr; return nullptr;
} }
#if PY_MAJOR_VERSION < 3 #if PY_MAJOR_VERSION < 3
@ -621,10 +620,10 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
// Getter for `_num_elements`. // Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) { static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle; auto handle = self->handle;
int n = TFE_TensorHandleNumElements(handle, self->status); int n = TFE_TensorHandleNumElements(handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return nullptr; return nullptr;
} }
return PyLong_FromLongLong(n); return PyLong_FromLongLong(n);
@ -670,10 +669,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
// Note that this is a shallow copy and will share the underlying buffer // Note that this is a shallow copy and will share the underlying buffer
// if copying to the same device. // if copying to the same device.
TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice( TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
self->handle, GetContextHandle(self->context), device_name, self->status); self->handle, GetContextHandle(self->context), device_name,
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_RuntimeError)) { &self->status);
if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return nullptr; return nullptr;
} }
@ -686,11 +686,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
// other. // other.
// Note that if `self` is not on CPU, we raise an Exception. // Note that if `self` is not on CPU, we raise an Exception.
static PyObject* EagerTensor_numpy_internal(EagerTensor* self) { static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status); auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) { if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
Py_XDECREF(py_array); Py_XDECREF(py_array);
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return nullptr; return nullptr;
} else { } else {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array)); return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
@ -699,10 +699,10 @@ static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
// Getter `device`. // Getter `device`.
static PyObject* EagerTensor_device(EagerTensor* self) { static PyObject* EagerTensor_device(EagerTensor* self) {
const char* device = TFE_TensorHandleDeviceName(self->handle, self->status); const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return nullptr; return nullptr;
} }
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
@ -715,10 +715,10 @@ static PyObject* EagerTensor_device(EagerTensor* self) {
// Getter `backing_device`. // Getter `backing_device`.
static PyObject* EagerTensor_backing_device(EagerTensor* self) { static PyObject* EagerTensor_backing_device(EagerTensor* self) {
const char* device = const char* device =
TFE_TensorHandleBackingDeviceName(self->handle, self->status); TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return nullptr; return nullptr;
} }
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
@ -779,10 +779,10 @@ static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
// DT_STRING so the following is only slightly slower than a NumPy-free // DT_STRING so the following is only slightly slower than a NumPy-free
// implementation. // implementation.
auto py_array = tensorflow::make_safe( auto py_array = tensorflow::make_safe(
TFE_TensorHandleToNumpy(self->handle, self->status)); TFE_TensorHandleToNumpy(self->handle, &self->status));
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_BufferError)) { if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
// Cleanup self->status before returning. // Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, ""); self->status.status = tensorflow::Status::OK();
return -1; return -1;
} }
if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) { if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
@ -899,7 +899,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
Py_INCREF(Py_None); Py_INCREF(Py_None);
t->tensor_shape = Py_None; t->tensor_shape = Py_None;
t->handle = handle; t->handle = handle;
t->status = TF_NewStatus(); t->status.status = tensorflow::Status::OK();
t->weakreflist = nullptr; t->weakreflist = nullptr;
PyObject* py_context = GetPyEagerContext(); PyObject* py_context = GetPyEagerContext();
if (py_context == nullptr) { if (py_context == nullptr) {
@ -927,17 +927,16 @@ tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
reinterpret_cast<const EagerTensor*>(tensor)->handle)); reinterpret_cast<const EagerTensor*>(tensor)->handle));
} }
tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) { tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
DCHECK(EagerTensor_CheckExact(tensor)); DCHECK(EagerTensor_CheckExact(tensor));
const EagerTensor* as_c_eager_tensor = EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
reinterpret_cast<const EagerTensor*>(tensor);
tensorflow::int64 result = TFE_TensorHandleNumElements( tensorflow::int64 result = TFE_TensorHandleNumElements(
as_c_eager_tensor->handle, as_c_eager_tensor->status); as_c_eager_tensor->handle, &as_c_eager_tensor->status);
if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status, if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
PyExc_ValueError)) { PyExc_ValueError)) {
// Cleanup status before returning. // Cleanup status before returning.
TF_SetStatus(as_c_eager_tensor->status, TF_OK, ""); as_c_eager_tensor->status.status = tensorflow::Status::OK();
return -1; return -1;
} }

View File

@ -23,7 +23,7 @@ limitations under the License.
bool EagerTensor_CheckExact(const PyObject* o); bool EagerTensor_CheckExact(const PyObject* o);
tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor); tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor); tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor); tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
namespace tensorflow { namespace tensorflow {