This commit is contained in:
VoVAllen 2020-02-24 09:06:29 +00:00
parent 7aa5009b35
commit 61da5aaff3
2 changed files with 7 additions and 6 deletions

View File

@ -158,15 +158,13 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h,
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape =
reinterpret_cast<std::int64_t*>(shape_arr->data());
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
// There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data
// Here we choose option 2, since some framework didn't handle the strides
// argument properly
dlm_tensor->dl_tensor.strides =
reinterpret_cast<std::int64_t*>(stride_arr->data());
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return &tf_dlm_tensor_ctx->tensor;
@ -266,6 +264,9 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
return false;
}
for (int i = ndim - 2; i >= 0; --i) {
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
return false;
@ -309,7 +310,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
total_bytes *= dims[i];
}
if ((dl_tensor->strides != nullptr) &&
if (dl_tensor->strides != nullptr &&
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
num_dims)) {
status->status = tensorflow::errors::InvalidArgument(

View File

@ -22,7 +22,7 @@ limitations under the License.
namespace tensorflow {
TF_CAPI_EXPORT extern const char* const kDlTensorCapsuleName = "dltensor";
const char* const kDlTensorCapsuleName = "dltensor";
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status);