fix
This commit is contained in:
parent
48a353cdaf
commit
59d8c5b6c0
@ -30,7 +30,6 @@ namespace {
|
|||||||
struct TFDLManagedTensorCtx {
|
struct TFDLManagedTensorCtx {
|
||||||
TensorReference* handle;
|
TensorReference* handle;
|
||||||
std::vector<int64_t> shape;
|
std::vector<int64_t> shape;
|
||||||
std::vector<int64_t> strides;
|
|
||||||
DLManagedTensor tensor;
|
DLManagedTensor tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -145,14 +144,14 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
|
|||||||
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||||
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||||
shape_arr->resize(ndim);
|
shape_arr->resize(ndim);
|
||||||
stride_arr->resize(ndim);
|
|
||||||
for (int i = 0; i < ndim; i++) {
|
for (int i = 0; i < ndim; i++) {
|
||||||
(*shape_arr)[i] = tensor->dim_size(i);
|
(*shape_arr)[i] = tensor->dim_size(i);
|
||||||
(*stride_arr)[i] = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast<std::int64_t*>(shape_arr->data());
|
tf_dlm_tensor_ctx->tensor.dl_tensor.shape =
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast<std::int64_t*>(stride_arr->data());
|
reinterpret_cast<std::int64_t*>(shape_arr->data());
|
||||||
|
tf_dlm_tensor_ctx->tensor.dl_tensor.strides =
|
||||||
|
nullptr; // NULL indicates tensor is compact and row-majored.
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
|
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
|
||||||
0; // TF doesn't handle the strides and byte_offsets here
|
0; // TF doesn't handle the strides and byte_offsets here
|
||||||
return &tf_dlm_tensor_ctx->tensor;
|
return &tf_dlm_tensor_ctx->tensor;
|
||||||
|
Loading…
Reference in New Issue
Block a user