This commit is contained in:
VoVAllen 2020-02-21 15:22:48 +00:00
parent 48a353cdaf
commit 59d8c5b6c0

View File

@ -30,7 +30,6 @@ namespace {
struct TFDLManagedTensorCtx { struct TFDLManagedTensorCtx {
TensorReference* handle; TensorReference* handle;
std::vector<int64_t> shape; std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor; DLManagedTensor tensor;
}; };
@ -145,14 +144,14 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides; std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim); shape_arr->resize(ndim);
stride_arr->resize(ndim);
for (int i = 0; i < ndim; i++) { for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i); (*shape_arr)[i] = tensor->dim_size(i);
(*stride_arr)[i] = 1;
} }
tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast<std::int64_t*>(shape_arr->data()); tf_dlm_tensor_ctx->tensor.dl_tensor.shape =
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast<std::int64_t*>(stride_arr->data()); reinterpret_cast<std::int64_t*>(shape_arr->data());
tf_dlm_tensor_ctx->tensor.dl_tensor.strides =
nullptr; // NULL indicates tensor is compact and row-majored.
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here 0; // TF doesn't handle the strides and byte_offsets here
return &tf_dlm_tensor_ctx->tensor; return &tf_dlm_tensor_ctx->tensor;