Inline status in EagerTensor
PiperOrigin-RevId: 267284080
Commit: 2c39806e95
Parent: b3cde1ae7e
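The change replaces EagerTensor's heap-allocated `TF_Status*` member with an inline `TF_Status` value, so the per-tensor `TF_NewStatus`/`TF_DeleteStatus` calls disappear and callers pass `&self->status` where they previously passed `self->status`. A minimal, hypothetical sketch of that before/after shape (the `Status`, `TensorBefore`, and `TensorAfter` names are stand-ins, not TensorFlow types):

// Hypothetical sketch of the change: the status moves from a heap-allocated
// pointer member to an inline value member.
#include <string>

struct Status {
  int code = 0;  // 0 == OK
  std::string message;
};

// Before: every tensor allocates a status in init and frees it in dealloc.
struct TensorBefore {
  Status* status = nullptr;
};

// After: the status is part of the object itself; no separate new/delete,
// and "cleanup" becomes an in-place reset to OK.
struct TensorAfter {
  Status status;
};

int main() {
  TensorBefore a;
  a.status = new Status();  // analogue of self->status = TF_NewStatus();
  delete a.status;          // analogue of TF_DeleteStatus(self->status);

  TensorAfter b;            // nothing to allocate or free
  b.status = Status{};      // resetting is just assignment, no allocation
}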
@@ -443,7 +443,7 @@ typedef struct EagerTensor {
   // Status objects on different functions that operate on EagerTensor and need
   // to use a TF_Status object. However note that accesses to `status` are not
   // thread-safe.
-  TF_Status* status;
+  TF_Status status;
 
   // The eager Context (from eager/context.py) used by this Tensor.
   // This is currently used only to make sure context outlives TensorHandles.
@@ -503,7 +503,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
   self->handle_data = Py_None;
   Py_INCREF(Py_None);
   self->tensor_shape = Py_None;
-  self->status = TF_NewStatus();
+  self->status.status = tensorflow::Status::OK();
   self->dict = nullptr;
   self->weakreflist = nullptr;
   self->context = nullptr;
@@ -543,7 +543,6 @@ void EagerTensor_dealloc(EagerTensor* self) {
   // Needs to happen before any actual destruction.
   PyObject_ClearWeakRefs((PyObject*)self);
 
-  TF_DeleteStatus(self->status);
   Py_DECREF(self->handle_data);
   Py_DECREF(self->tensor_shape);
   // If an attribute dictionary has been created, release it. Note that this
@@ -579,21 +578,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
 // Getter for `_shape_tuple`.
 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
   auto handle = self->handle;
-  int n = TFE_TensorHandleNumDims(handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  int n = TFE_TensorHandleNumDims(handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
   PyObject* shape = PyTuple_New(n);
   if (PyErr_Occurred()) return nullptr;
   for (int i = 0; i < n; ++i) {
     PyObject* dim =
-        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
-    if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr) ||
+        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
+    if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr) ||
         dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
       // Cleanup self->status before returning.
-      TF_SetStatus(self->status, TF_OK, "");
+      self->status.status = tensorflow::Status::OK();
       Py_DECREF(shape);
       if (dim != nullptr) Py_DECREF(dim);
       PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
@@ -605,10 +604,10 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
 
 // Getter for `_rank`.
 static PyObject* EagerTensor_rank(EagerTensor* self) {
-  int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
 #if PY_MAJOR_VERSION < 3
@@ -621,10 +620,10 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
 // Getter for `_num_elements`.
 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
   auto handle = self->handle;
-  int n = TFE_TensorHandleNumElements(handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  int n = TFE_TensorHandleNumElements(handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
   return PyLong_FromLongLong(n);
@@ -670,10 +669,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
   // Note that this is a shallow copy and will share the underlying buffer
   // if copying to the same device.
   TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
-      self->handle, GetContextHandle(self->context), device_name, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_RuntimeError)) {
+      self->handle, GetContextHandle(self->context), device_name,
+      &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
 
@@ -686,11 +686,11 @@ static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
 // other.
 // Note that if `self` is not on CPU, we raise an Exception.
 static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
-  auto* py_array = TFE_TensorHandleToNumpy(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, nullptr)) {
+  auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
     Py_XDECREF(py_array);
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   } else {
     return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
@@ -699,10 +699,10 @@ static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
 
 // Getter `device`.
 static PyObject* EagerTensor_device(EagerTensor* self) {
-  const char* device = TFE_TensorHandleDeviceName(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+  const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
 #if PY_MAJOR_VERSION >= 3
@@ -715,10 +715,10 @@ static PyObject* EagerTensor_device(EagerTensor* self) {
 // Getter `backing_device`.
 static PyObject* EagerTensor_backing_device(EagerTensor* self) {
   const char* device =
-      TFE_TensorHandleBackingDeviceName(self->handle, self->status);
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+      TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return nullptr;
   }
 #if PY_MAJOR_VERSION >= 3
@@ -779,10 +779,10 @@ static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
   // DT_STRING so the following is only slightly slower than a NumPy-free
   // implementation.
   auto py_array = tensorflow::make_safe(
-      TFE_TensorHandleToNumpy(self->handle, self->status));
-  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_BufferError)) {
+      TFE_TensorHandleToNumpy(self->handle, &self->status));
+  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
     // Cleanup self->status before returning.
-    TF_SetStatus(self->status, TF_OK, "");
+    self->status.status = tensorflow::Status::OK();
     return -1;
   }
   if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
@@ -899,7 +899,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
   Py_INCREF(Py_None);
   t->tensor_shape = Py_None;
   t->handle = handle;
-  t->status = TF_NewStatus();
+  t->status.status = tensorflow::Status::OK();
   t->weakreflist = nullptr;
   PyObject* py_context = GetPyEagerContext();
   if (py_context == nullptr) {
@@ -927,17 +927,16 @@ tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
       reinterpret_cast<const EagerTensor*>(tensor)->handle));
 }
 
-tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
   DCHECK(EagerTensor_CheckExact(tensor));
-  const EagerTensor* as_c_eager_tensor =
-      reinterpret_cast<const EagerTensor*>(tensor);
+  EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
   tensorflow::int64 result = TFE_TensorHandleNumElements(
-      as_c_eager_tensor->handle, as_c_eager_tensor->status);
+      as_c_eager_tensor->handle, &as_c_eager_tensor->status);
 
-  if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+  if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
                                       PyExc_ValueError)) {
     // Cleanup status before returning.
-    TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+    as_c_eager_tensor->status.status = tensorflow::Status::OK();
     return -1;
   }
 
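Every getter above follows the same idiom: call a C API with `&self->status`, raise a Python exception if the call failed, then reset the status to OK so the next call on the same tensor starts clean. A minimal sketch of that idiom under stand-in names (`Status`, `Tensor`, `CApiNumDims`, and `RaiseIfError` are illustrative, not TensorFlow APIs). The last hunk below updates the corresponding `PyEagerTensor_NumElements` declaration to match.

// Hypothetical sketch of the reuse-and-reset idiom used by the getters above.
#include <cstdio>

struct Status {
  int code = 0;               // 0 == OK
  void Reset() { code = 0; }  // analogue of resetting self->status to OK
};

struct Tensor {
  Status status;  // one inlined, reusable status per tensor
  int dims = 3;
};

// A C-style call that reports failure through an out-parameter status.
int CApiNumDims(const Tensor& t, Status* s) {
  if (t.dims < 0) { s->code = 1; return -1; }
  return t.dims;
}

// Returns true if the status carries an error (analogue of
// MaybeRaiseExceptionFromTFStatus, which would also set a Python exception).
bool RaiseIfError(const Status& s) { return s.code != 0; }

int NumDims(Tensor* self) {
  int n = CApiNumDims(*self, &self->status);  // pass the inlined status
  if (RaiseIfError(self->status)) {
    self->status.Reset();  // clear the error so the next call starts clean
    return -1;
  }
  return n;
}

int main() {
  Tensor t;
  std::printf("%d\n", NumDims(&t));  // prints 3
}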
@@ -23,7 +23,7 @@ limitations under the License.
 bool EagerTensor_CheckExact(const PyObject* o);
 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
-tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
 
 namespace tensorflow {
 
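The declaration drops `const` because, with the status stored by value, `&as_c_eager_tensor->status` obtained through a `const EagerTensor*` is a pointer to const and cannot be passed to APIs that mutate the status. A small stand-alone sketch of that constraint (types and names are illustrative, not TensorFlow's):

// Illustrative only: why the parameter lost its const.
struct Status { int code = 0; };
struct Tensor { Status status; };

void Consume(Status* s) { s->code = 0; }  // API that mutates the status

void AfterStyle(Tensor* t) {
  Consume(&t->status);  // fine: t is non-const, so &t->status is Status*
}

void BeforeStyle(const Tensor* t) {
  // With the old `Status* status` member, `t->status` was a mutable pointer
  // even through a const Tensor*. With an inline member, the line below is a
  // compile error: cannot convert 'const Status*' to 'Status*'.
  // Consume(&t->status);
  (void)t;
}

int main() {
  Tensor t;
  AfterStyle(&t);
  BeforeStyle(&t);
}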