Add a HostToBorrowingLiteral function that takes an xla::Shape
PiperOrigin-RevId: 296514358 Change-Id: I5a69538960c064300baabe7ef2cd3c021f156bb1
This commit is contained in:
parent
e406991c37
commit
917ebfe5fc
@ -27,6 +27,17 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
|
|||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
|
||||||
host_tensor.shape(), &xla_shape));
|
host_tensor.shape(), &xla_shape));
|
||||||
|
return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
|
||||||
|
const Tensor& host_tensor,
|
||||||
|
xla::BorrowingLiteral* literal) {
|
||||||
|
const auto& tshape = host_tensor.shape();
|
||||||
|
TF_RET_CHECK(tshape.IsFullyDefined() &&
|
||||||
|
tshape.dims() == xla_shape.dimensions_size() &&
|
||||||
|
tshape.dim_sizes() == xla_shape.dimensions())
|
||||||
|
<< "Provided xla::Shape must have the same dims as the Tensor shape.";
|
||||||
*literal = xla::BorrowingLiteral(
|
*literal = xla::BorrowingLiteral(
|
||||||
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
|
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -30,6 +30,12 @@ namespace tensorflow {
|
|||||||
// 'host_tensor'.
|
// 'host_tensor'.
|
||||||
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
|
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
|
||||||
xla::BorrowingLiteral* literal);
|
xla::BorrowingLiteral* literal);
|
||||||
|
// Similar as above, except the literal shape is explicitly provided and used
|
||||||
|
// instead of obtaining it from the 'host_tensor'. The provided literal shape
|
||||||
|
// 'xla_shape' must be compatible with the shape of 'host_tensor'.
|
||||||
|
Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
|
||||||
|
const Tensor& host_tensor,
|
||||||
|
xla::BorrowingLiteral* literal);
|
||||||
|
|
||||||
// Returns a Literal with the contents of 'host_tensor', backed by its own
|
// Returns a Literal with the contents of 'host_tensor', backed by its own
|
||||||
// storage (i.e., not reusing 'host_tensor's buffers.)
|
// storage (i.e., not reusing 'host_tensor's buffers.)
|
||||||
|
Loading…
Reference in New Issue
Block a user