Add a HostToBorrowingLiteral function that takes an xla::Shape
PiperOrigin-RevId: 296514358 Change-Id: I5a69538960c064300baabe7ef2cd3c021f156bb1
This commit is contained in:
parent
e406991c37
commit
917ebfe5fc
|
@ -27,6 +27,17 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
|
|||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
|
||||
host_tensor.shape(), &xla_shape));
|
||||
return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal);
|
||||
}
|
||||
|
||||
Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
|
||||
const Tensor& host_tensor,
|
||||
xla::BorrowingLiteral* literal) {
|
||||
const auto& tshape = host_tensor.shape();
|
||||
TF_RET_CHECK(tshape.IsFullyDefined() &&
|
||||
tshape.dims() == xla_shape.dimensions_size() &&
|
||||
tshape.dim_sizes() == xla_shape.dimensions())
|
||||
<< "Provided xla::Shape must have the same dims as the Tensor shape.";
|
||||
*literal = xla::BorrowingLiteral(
|
||||
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
|
||||
return Status::OK();
|
||||
|
|
|
@ -30,6 +30,12 @@ namespace tensorflow {
|
|||
// 'host_tensor'.
|
||||
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
|
||||
xla::BorrowingLiteral* literal);
|
||||
// Similar as above, except the literal shape is explicitly provided and used
|
||||
// instead of obtaining it from the 'host_tensor'. The provided literal shape
|
||||
// 'xla_shape' must be compatible with the shape of 'host_tensor'.
|
||||
Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
|
||||
const Tensor& host_tensor,
|
||||
xla::BorrowingLiteral* literal);
|
||||
|
||||
// Returns a Literal with the contents of 'host_tensor', backed by its own
|
||||
// storage (i.e., not reusing 'host_tensor's buffers.)
|
||||
|
|
Loading…
Reference in New Issue