diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index fdc439da4b5..794550e840a 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -158,15 +158,13 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h, (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; } - dlm_tensor->dl_tensor.shape = - reinterpret_cast(shape_arr->data()); + dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; // There are two ways to represent compact row-major data // 1) nullptr indicates tensor is compact and row-majored. // 2) fill in the strides array as the real case for compact row-major data // Here we choose option 2, since some framework didn't handle the strides // argument properly - dlm_tensor->dl_tensor.strides = - reinterpret_cast(stride_arr->data()); + dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; @@ -266,6 +264,9 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, int ndim) { + if (ndim >= 1 && stride_arr[ndim - 1] != 1) { + return false; + } for (int i = ndim - 2; i >= 0; --i) { if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) { return false; @@ -309,7 +310,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { total_bytes *= dims[i]; } - if ((dl_tensor->strides != nullptr) && + if (dl_tensor->strides != nullptr && !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides, num_dims)) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index b563bc24495..f656f4393f6 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -TF_CAPI_EXPORT extern const char* const kDlTensorCapsuleName = "dltensor"; +const char* const kDlTensorCapsuleName = "dltensor"; TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status);