fix
This commit is contained in:
parent
7aa5009b35
commit
61da5aaff3
@ -158,15 +158,13 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h,
|
|||||||
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
dlm_tensor->dl_tensor.shape =
|
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
||||||
reinterpret_cast<std::int64_t*>(shape_arr->data());
|
|
||||||
// There are two ways to represent compact row-major data
|
// There are two ways to represent compact row-major data
|
||||||
// 1) nullptr indicates tensor is compact and row-majored.
|
// 1) nullptr indicates tensor is compact and row-majored.
|
||||||
// 2) fill in the strides array as the real case for compact row-major data
|
// 2) fill in the strides array as the real case for compact row-major data
|
||||||
// Here we choose option 2, since some framework didn't handle the strides
|
// Here we choose option 2, since some framework didn't handle the strides
|
||||||
// argument properly
|
// argument properly
|
||||||
dlm_tensor->dl_tensor.strides =
|
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
||||||
reinterpret_cast<std::int64_t*>(stride_arr->data());
|
|
||||||
dlm_tensor->dl_tensor.byte_offset =
|
dlm_tensor->dl_tensor.byte_offset =
|
||||||
0; // TF doesn't handle the strides and byte_offsets here
|
0; // TF doesn't handle the strides and byte_offsets here
|
||||||
return &tf_dlm_tensor_ctx->tensor;
|
return &tf_dlm_tensor_ctx->tensor;
|
||||||
@ -266,6 +264,9 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
|||||||
|
|
||||||
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
|
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
for (int i = ndim - 2; i >= 0; --i) {
|
for (int i = ndim - 2; i >= 0; --i) {
|
||||||
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
|
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
|
||||||
return false;
|
return false;
|
||||||
@ -309,7 +310,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
|||||||
total_bytes *= dims[i];
|
total_bytes *= dims[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((dl_tensor->strides != nullptr) &&
|
if (dl_tensor->strides != nullptr &&
|
||||||
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
|
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
|
||||||
num_dims)) {
|
num_dims)) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern const char* const kDlTensorCapsuleName = "dltensor";
|
const char* const kDlTensorCapsuleName = "dltensor";
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status);
|
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user