From 26c4a2c1ec5497227440c19177ae854861b17a31 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Mon, 28 Aug 2017 14:31:02 -0400 Subject: [PATCH] Added handling of the status flag when creating TFE tensors. --- tensorflow/c/eager/c_api.cc | 6 +++--- tensorflow/c/eager/c_api.h | 4 ++-- tensorflow/c/eager/c_api_test.cc | 7 +++++-- tensorflow/python/eager/pywrap_tfe_src.cc | 4 +++- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 13a1825aaed..e70539ceefa 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -151,10 +151,10 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { return TF_SessionListDevices(ctx->session, status); } -TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) { +TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; - // TODO: Add status argument and check on it. - tensorflow::TF_TensorToTensor(t, &tensor); + status->status = tensorflow::TF_TensorToTensor(t, &tensor); + if (!status->status.ok()) return nullptr; return new TFE_TensorHandle(tensor, nullptr); } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 476c9288f89..24a80a8f5b5 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -43,7 +43,7 @@ extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, // placed in memory of different devices or remote address spaces. typedef struct TFE_TensorHandle TFE_TensorHandle; -extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t); +extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status); extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); @@ -153,7 +153,7 @@ class Tensor; const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status); -TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t); +TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, TF_Status* status); #endif #endif // TENSORFLOW_C_EAGER_C_API_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 6f5c21c9472..d19583a3abe 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -34,7 +34,9 @@ TFE_TensorHandle* TestMatrixTensorHandle() { TF_Tensor* t = TF_AllocateTensor( TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TFE_TensorHandle* th = TFE_NewTensorHandle(t); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); return th; } @@ -383,7 +385,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); std::unique_ptr - value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle); + value_handle(TFE_NewTensorHandle(t.get(), status), TFE_DeleteTensorHandle); + if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpAddInput(op, value_handle.get(), status); if (TF_GetCode(status) != TF_OK) return nullptr; diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 655e3ec8491..8146d9de141 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -347,7 +347,9 @@ TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) { tensorflow::Tensor t; auto cppstatus = tensorflow::NdarrayToTensor(obj, &t); if (cppstatus.ok()) { - return TFE_NewTensorHandle(t); + TFE_TensorHandle* tensor = TFE_NewTensorHandle(t, cppstatus); + if (!cppstatus.ok()) return nullptr; + return tensor; } else { tensorflow::mutex_lock l(exception_class_mutex); auto msg = tensorflow::strings::StrCat(