Added handling of the status flag when creating TFE tensors.

This commit is contained in:
Anthony Platanios 2017-08-28 14:31:02 -04:00 committed by Martin Wicke
parent 5598a3dec7
commit 26c4a2c1ec
4 changed files with 13 additions and 8 deletions

View File

@ -151,10 +151,10 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
return TF_SessionListDevices(ctx->session, status); return TF_SessionListDevices(ctx->session, status);
} }
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) { TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor; tensorflow::Tensor tensor;
// TODO: Add status argument and check on it. status->status = tensorflow::TF_TensorToTensor(t, &tensor);
tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr;
return new TFE_TensorHandle(tensor, nullptr); return new TFE_TensorHandle(tensor, nullptr);
} }

View File

@ -43,7 +43,7 @@ extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
// placed in memory of different devices or remote address spaces. // placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle; typedef struct TFE_TensorHandle TFE_TensorHandle;
extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t); extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status);
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
@ -153,7 +153,7 @@ class Tensor;
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* h, TF_Status* status); TFE_TensorHandle* h, TF_Status* status);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t); TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, TF_Status* status);
#endif #endif
#endif // TENSORFLOW_C_EAGER_C_API_H_ #endif // TENSORFLOW_C_EAGER_C_API_H_

View File

@ -34,7 +34,9 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
TF_Tensor* t = TF_AllocateTensor( TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandle(t); TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t); TF_DeleteTensor(t);
return th; return th;
} }
@ -383,7 +385,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle); value_handle(TFE_NewTensorHandle(t.get(), status), TFE_DeleteTensorHandle);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status); TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr; if (TF_GetCode(status) != TF_OK) return nullptr;

View File

@ -347,7 +347,9 @@ TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
tensorflow::Tensor t; tensorflow::Tensor t;
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t); auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
if (cppstatus.ok()) { if (cppstatus.ok()) {
return TFE_NewTensorHandle(t); TFE_TensorHandle* tensor = TFE_NewTensorHandle(t, cppstatus);
if (!cppstatus.ok()) return nullptr;
return tensor;
} else { } else {
tensorflow::mutex_lock l(exception_class_mutex); tensorflow::mutex_lock l(exception_class_mutex);
auto msg = tensorflow::strings::StrCat( auto msg = tensorflow::strings::StrCat(