Add a HostToBorrowingLiteral function that takes an xla::Shape

PiperOrigin-RevId: 296514358
Change-Id: I5a69538960c064300baabe7ef2cd3c021f156bb1
This commit is contained in:
A. Unique TensorFlower 2020-02-21 14:38:48 -08:00 committed by TensorFlower Gardener
parent e406991c37
commit 917ebfe5fc
2 changed files with 17 additions and 0 deletions

View File

@ -27,6 +27,17 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal);
}
Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
const Tensor& host_tensor,
xla::BorrowingLiteral* literal) {
const auto& tshape = host_tensor.shape();
TF_RET_CHECK(tshape.IsFullyDefined() &&
tshape.dims() == xla_shape.dimensions_size() &&
tshape.dim_sizes() == xla_shape.dimensions())
<< "Provided xla::Shape must have the same dims as the Tensor shape.";
*literal = xla::BorrowingLiteral(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
return Status::OK();

View File

@ -30,6 +30,12 @@ namespace tensorflow {
// 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal);
// Similar as above, except the literal shape is explicitly provided and used
// instead of obtaining it from the 'host_tensor'. The provided literal shape
// 'xla_shape' must be compatible with the shape of 'host_tensor'.
Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
const Tensor& host_tensor,
xla::BorrowingLiteral* literal);
// Returns a Literal with the contents of 'host_tensor', backed by its own
// storage (i.e., not reusing 'host_tensor's buffers.)