diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 749a7c3054a..720b81a5097 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -27,6 +27,17 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), host_tensor.shape(), &xla_shape)); + return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal); +} + +Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, + const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + const auto& tshape = host_tensor.shape(); + TF_RET_CHECK(tshape.IsFullyDefined() && + tshape.dims() == xla_shape.dimensions_size() && + tshape.dim_sizes() == xla_shape.dimensions()) + << "Provided xla::Shape must have the same dims as the Tensor shape."; *literal = xla::BorrowingLiteral( static_cast(DMAHelper::base(&host_tensor)), xla_shape); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index a153dddee61..b1fdf47f5b6 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,12 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, + const Tensor& host_tensor, + xla::BorrowingLiteral* literal); // Returns a Literal with the contents of 'host_tensor', backed by its own // storage (i.e., not reusing 'host_tensor's buffers.)