This commit is contained in:
VoVAllen 2020-02-21 15:22:48 +00:00
parent 48a353cdaf
commit 59d8c5b6c0

View File

@ -30,7 +30,6 @@ namespace {
struct TFDLManagedTensorCtx {
TensorReference* handle;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor;
};
@ -145,14 +144,14 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
(*stride_arr)[i] = 1;
(*shape_arr)[i] = tensor->dim_size(i);
}
tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast<std::int64_t*>(shape_arr->data());
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast<std::int64_t*>(stride_arr->data());
tf_dlm_tensor_ctx->tensor.dl_tensor.shape =
reinterpret_cast<std::int64_t*>(shape_arr->data());
tf_dlm_tensor_ctx->tensor.dl_tensor.strides =
nullptr; // NULL indicates tensor is compact and row-majored.
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return &tf_dlm_tensor_ctx->tensor;