Added handling of the status flag when creating TFE tensors.
This commit is contained in:
parent
5598a3dec7
commit
26c4a2c1ec
@ -151,10 +151,10 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
return TF_SessionListDevices(ctx->session, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
// TODO: Add status argument and check on it.
|
||||
tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return new TFE_TensorHandle(tensor, nullptr);
|
||||
}
|
||||
|
||||
|
@ -43,7 +43,7 @@ extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
||||
// placed in memory of different devices or remote address spaces.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
|
||||
extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status);
|
||||
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
|
||||
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
|
||||
@ -153,7 +153,7 @@ class Tensor;
|
||||
|
||||
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
|
||||
TFE_TensorHandle* h, TF_Status* status);
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, TF_Status* status);
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_H_
|
||||
|
@ -34,7 +34,9 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
return th;
|
||||
}
|
||||
@ -383,7 +385,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status), TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
@ -347,7 +347,9 @@ TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
|
||||
tensorflow::Tensor t;
|
||||
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
|
||||
if (cppstatus.ok()) {
|
||||
return TFE_NewTensorHandle(t);
|
||||
TFE_TensorHandle* tensor = TFE_NewTensorHandle(t, cppstatus);
|
||||
if (!cppstatus.ok()) return nullptr;
|
||||
return tensor;
|
||||
} else {
|
||||
tensorflow::mutex_lock l(exception_class_mutex);
|
||||
auto msg = tensorflow::strings::StrCat(
|
||||
|
Loading…
Reference in New Issue
Block a user