This commit is contained in:
VoVAllen 2020-02-24 09:06:29 +00:00
parent 7aa5009b35
commit 61da5aaff3
2 changed files with 7 additions and 6 deletions

View File

@ -158,15 +158,13 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h,
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
} }
dlm_tensor->dl_tensor.shape = dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
reinterpret_cast<std::int64_t*>(shape_arr->data());
// There are two ways to represent compact row-major data // There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored. // 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data // 2) fill in the strides array as the real case for compact row-major data
// Here we choose option 2, since some framework didn't handle the strides // Here we choose option 2, since some framework didn't handle the strides
// argument properly // argument properly
dlm_tensor->dl_tensor.strides = dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
reinterpret_cast<std::int64_t*>(stride_arr->data());
dlm_tensor->dl_tensor.byte_offset = dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here 0; // TF doesn't handle the strides and byte_offsets here
return &tf_dlm_tensor_ctx->tensor; return &tf_dlm_tensor_ctx->tensor;
@ -266,6 +264,9 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) { int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
return false;
}
for (int i = ndim - 2; i >= 0; --i) { for (int i = ndim - 2; i >= 0; --i) {
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) { if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
return false; return false;
@ -309,7 +310,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
total_bytes *= dims[i]; total_bytes *= dims[i];
} }
if ((dl_tensor->strides != nullptr) && if (dl_tensor->strides != nullptr &&
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides, !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
num_dims)) { num_dims)) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(

View File

@ -22,7 +22,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
TF_CAPI_EXPORT extern const char* const kDlTensorCapsuleName = "dltensor"; const char* const kDlTensorCapsuleName = "dltensor";
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status);