Inline status in EagerTensor
PiperOrigin-RevId: 267284080
parent b3cde1ae7e
commit 2c39806e95
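This change replaces the heap-allocated `TF_Status*` member of EagerTensor with an inline `TF_Status` value, so per-tensor status handling no longer needs TF_NewStatus/TF_DeleteStatus. A minimal sketch of the pattern applied throughout, under the assumption (visible in the new code below) that `TF_Status` exposes its wrapped `tensorflow::Status` as a public `status` member:

    // Sketch only; `handle` stands for a TFE_TensorHandle* held by the tensor.
    // Before: heap-allocated status, created, reset and freed via the C API.
    TF_Status* status = TF_NewStatus();
    int n = TFE_TensorHandleNumDims(handle, status);
    if (TF_GetCode(status) != TF_OK) TF_SetStatus(status, TF_OK, "");
    TF_DeleteStatus(status);

    // After: status stored by value; callers pass its address and reset it
    // by reassigning the wrapped tensorflow::Status.
    TF_Status status;
    int n = TFE_TensorHandleNumDims(handle, &status);
    if (!status.status.ok()) status.status = tensorflow::Status::OK();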
@@ -443,7 +443,7 @@ typedef struct EagerTensor {
   // Status objects on different functions that operate on EagerTensor and need
   // to use a TF_Status object. However note that accesses to `status` are not
   // thread-safe.
-  TF_Status* status;
+  TF_Status status;

   // The eager Context (from eager/context.py) used by this Tensor.
   // This is currently used only to make sure context outlives TensorHandles.
@@ -503,7 +503,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
   self->handle_data = Py_None;
   Py_INCREF(Py_None);
   self->tensor_shape = Py_None;
-  self->status = TF_NewStatus();
+  self->status.status = tensorflow::Status::OK();
   self->dict = nullptr;
   self->weakreflist = nullptr;
   self->context = nullptr;
@@ -543,7 +543,6 @@ void EagerTensor_dealloc(EagerTensor* self) {
   // Needs to happen before any actual destruction.
   PyObject_ClearWeakRefs((PyObject*)self);

-  TF_DeleteStatus(self->status);
   Py_DECREF(self->handle_data);
   Py_DECREF(self->tensor_shape);
   // If an attribute dictionary has been created, release it. Note that this
@@ -579,21 +578,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
 // Getter for `_shape_tuple`.
 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
   auto handle = self->handle;
-  int n = TFE_TensorHandleNumDims(handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  int n = TFE_TensorHandleNumDims(handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
   PyObject* shape = PyTuple_New(n);
   if (PyErr_Occurred()) return nullptr;
   for (int i = 0; i < n; ++i) {
     PyObject* dim =
-        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
-    if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) ||
+        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
+    if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr) ||
         dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
       // Cleanup self->status before returning.
-      TF_SetStatus(self->status, TF_OK, "");
+      self->status.status = tensorflow::Status::OK();
       Py_DECREF(shape);
       if (dim != nullptr) Py_DECREF(dim);
       PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
@@ -605,10 +604,10 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {

 // Getter for `_rank`.
 static PyObject* EagerTensor_rank(EagerTensor* self) {
-  int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
 #if PY_MAJOR_VERSION < 3
@@ -621,10 +620,10 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
 // Getter for `_num_elements`.
 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
   auto handle = self->handle;
-  int n = TFE_TensorHandleNumElements(handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  int n = TFE_TensorHandleNumElements(handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
   return PyLong_FromLongLong(n);
@@ -670,10 +669,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
   // Note that this is a shallow copy and will share the underlying buffer
   // if copying to the same device.
   TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
-      self->handle, GetContextHandle(self->context), device_name, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_RuntimeError)) {
+      self->handle, GetContextHandle(self->context), device_name,
+      &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }

@@ -686,11 +686,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
 // other.
 // Note that if `self` is not on CPU, we raise an Exception.
 static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
-  auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     Py_XDECREF(py_array);
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   } else {
     return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
@@ -699,10 +699,10 @@ static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {

 // Getter `device`.
 static PyObject* EagerTensor_device(EagerTensor* self) {
-  const char* device = TFE_TensorHandleDeviceName(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+  const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
 #if PY_MAJOR_VERSION >= 3
@@ -715,10 +715,10 @@ static PyObject* EagerTensor_device(EagerTensor* self) {
 // Getter `backing_device`.
 static PyObject* EagerTensor_backing_device(EagerTensor* self) {
   const char* device =
-      TFE_TensorHandleBackingDeviceName(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+      TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
 #if PY_MAJOR_VERSION >= 3
@@ -779,10 +779,10 @@ static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
   // DT_STRING so the following is only slightly slower than a NumPy-free
   // implementation.
   auto py_array = tensorflow::make_safe(
-      TFE_TensorHandleToNumpy(self->handle, self->status));
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_BufferError)) {
+      TFE_TensorHandleToNumpy(self->handle, &self->status));
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return -1;
   }
   if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
@@ -899,7 +899,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
     Py_INCREF(Py_None);
     t->tensor_shape = Py_None;
     t->handle = handle;
-    t->status = TF_NewStatus();
+    t->status.status = tensorflow::Status::OK();
     t->weakreflist = nullptr;
     PyObject* py_context = GetPyEagerContext();
     if (py_context == nullptr) {
@@ -927,17 +927,16 @@ tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
       reinterpret_cast<const EagerTensor*>(tensor)->handle));
 }

-tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
   DCHECK(EagerTensor_CheckExact(tensor));
-  const EagerTensor* as_c_eager_tensor =
-      reinterpret_cast<const EagerTensor*>(tensor);
+  EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
   tensorflow::int64 result = TFE_TensorHandleNumElements(
-      as_c_eager_tensor->handle, as_c_eager_tensor->status);
+      as_c_eager_tensor->handle, &as_c_eager_tensor->status);

-  if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+  if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
                                       PyExc_ValueError)) {
     // Cleanup status before returning.
-    TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+    as_c_eager_tensor->status.status = tensorflow::Status::OK();
     return -1;
   }

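Because the status is now held by value, callers take its address (`&self->status`, `&as_c_eager_tensor->status`); that is also why `PyEagerTensor_NumElements` above drops the `const` qualifier from its `PyObject* tensor` parameter, and the declaration in the header below is updated to match.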
@@ -23,7 +23,7 @@ limitations under the License.
 bool EagerTensor_CheckExact(const PyObject* o);
 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
-tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);

 namespace tensorflow {

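The destructor simplification follows the same logic: the inline `TF_Status` is destroyed together with the EagerTensor, so the explicit `TF_DeleteStatus` call in `EagerTensor_dealloc` is no longer needed. As the struct comment notes, accesses to `status` remain non-thread-safe.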
|
Loading…
x
Reference in New Issue
Block a user