fix
This commit is contained in:
parent
7aa5009b35
commit
61da5aaff3
@ -158,15 +158,13 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h,
|
||||
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||
}
|
||||
|
||||
dlm_tensor->dl_tensor.shape =
|
||||
reinterpret_cast<std::int64_t*>(shape_arr->data());
|
||||
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
||||
// There are two ways to represent compact row-major data
|
||||
// 1) nullptr indicates tensor is compact and row-majored.
|
||||
// 2) fill in the strides array as the real case for compact row-major data
|
||||
// Here we choose option 2, since some framework didn't handle the strides
|
||||
// argument properly
|
||||
dlm_tensor->dl_tensor.strides =
|
||||
reinterpret_cast<std::int64_t*>(stride_arr->data());
|
||||
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
||||
dlm_tensor->dl_tensor.byte_offset =
|
||||
0; // TF doesn't handle the strides and byte_offsets here
|
||||
return &tf_dlm_tensor_ctx->tensor;
|
||||
@ -266,6 +264,9 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||
|
||||
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
||||
int ndim) {
|
||||
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
|
||||
return false;
|
||||
}
|
||||
for (int i = ndim - 2; i >= 0; --i) {
|
||||
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
|
||||
return false;
|
||||
@ -309,7 +310,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
total_bytes *= dims[i];
|
||||
}
|
||||
|
||||
if ((dl_tensor->strides != nullptr) &&
|
||||
if (dl_tensor->strides != nullptr &&
|
||||
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
|
||||
num_dims)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TF_CAPI_EXPORT extern const char* const kDlTensorCapsuleName = "dltensor";
|
||||
const char* const kDlTensorCapsuleName = "dltensor";
|
||||
|
||||
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user