Add a hint parameter to TransferLiteralToDeviceAsync that the implementation can use to accelerate transfers.
PiperOrigin-RevId: 215362667
This commit is contained in:
parent
44da41e490
commit
f22037abf5
@ -75,8 +75,9 @@ XlaTransferManager::XlaTransferManager(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status XlaTransferManager::TransferLiteralToDevice(
|
Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
|
||||||
const Tensor& host_tensor, Tensor* device_tensor) const {
|
Tensor* device_tensor,
|
||||||
|
bool buffer_is_fresh) const {
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
|
||||||
host_tensor.shape(), &xla_shape));
|
host_tensor.shape(), &xla_shape));
|
||||||
@ -97,8 +98,11 @@ Status XlaTransferManager::TransferLiteralToDevice(
|
|||||||
// synchronized.
|
// synchronized.
|
||||||
host_to_device_stream_->ThenWaitFor(stream_.get());
|
host_to_device_stream_->ThenWaitFor(stream_.get());
|
||||||
}
|
}
|
||||||
|
xla::TransferManager::TransferToDeviceHint hint =
|
||||||
|
buffer_is_fresh ? xla::TransferManager::kBufferUndefined
|
||||||
|
: xla::TransferManager::kNoHint;
|
||||||
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
|
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
|
||||||
host_to_device_stream_.get(), *literal, shaped_buffer));
|
host_to_device_stream_.get(), *literal, shaped_buffer, hint));
|
||||||
if (UseMultipleStreams()) {
|
if (UseMultipleStreams()) {
|
||||||
auto event = std::make_shared<se::Event>(stream_->parent());
|
auto event = std::make_shared<se::Event>(stream_->parent());
|
||||||
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
|
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
|
||||||
@ -165,6 +169,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
TensorShape shape = shape_or_status.ValueOrDie();
|
TensorShape shape = shape_or_status.ValueOrDie();
|
||||||
|
bool buffer_is_fresh = false;
|
||||||
if (!xla_tensor->has_shaped_buffer()) {
|
if (!xla_tensor->has_shaped_buffer()) {
|
||||||
Status s =
|
Status s =
|
||||||
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
||||||
@ -173,6 +178,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
|||||||
done(s);
|
done(s);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
buffer_is_fresh = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status status;
|
Status status;
|
||||||
@ -183,7 +189,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
|||||||
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
|
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
|
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor,
|
||||||
|
buffer_is_fresh);
|
||||||
} else {
|
} else {
|
||||||
se::DeviceMemoryBase dev_dst_ptr =
|
se::DeviceMemoryBase dev_dst_ptr =
|
||||||
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
|
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
|
||||||
|
@ -67,7 +67,8 @@ class XlaTransferManager {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Status TransferLiteralToDevice(const Tensor& host_tensor,
|
Status TransferLiteralToDevice(const Tensor& host_tensor,
|
||||||
Tensor* device_tensor) const;
|
Tensor* device_tensor,
|
||||||
|
bool buffer_is_fresh) const;
|
||||||
void TransferLiteralFromDevice(Tensor* host_tensor,
|
void TransferLiteralFromDevice(Tensor* host_tensor,
|
||||||
const Tensor& device_tensor,
|
const Tensor& device_tensor,
|
||||||
const StatusCallback& done) const;
|
const StatusCallback& done) const;
|
||||||
|
@ -98,7 +98,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal(
|
|||||||
|
|
||||||
Status GenericTransferManager::TransferLiteralToDeviceAsync(
|
Status GenericTransferManager::TransferLiteralToDeviceAsync(
|
||||||
se::Stream* stream, const LiteralSlice& literal,
|
se::Stream* stream, const LiteralSlice& literal,
|
||||||
const ShapedBuffer& device_buffer) {
|
const ShapedBuffer& device_buffer, TransferToDeviceHint /*hint*/) {
|
||||||
const Shape& shape = literal.shape();
|
const Shape& shape = literal.shape();
|
||||||
VLOG(2) << "transferring literal shape to device: "
|
VLOG(2) << "transferring literal shape to device: "
|
||||||
<< ShapeUtil::HumanString(shape)
|
<< ShapeUtil::HumanString(shape)
|
||||||
|
@ -45,9 +45,10 @@ class GenericTransferManager : public TransferManager {
|
|||||||
MutableBorrowingLiteral literal,
|
MutableBorrowingLiteral literal,
|
||||||
std::function<void(Status)> done) override;
|
std::function<void(Status)> done) override;
|
||||||
|
|
||||||
Status TransferLiteralToDeviceAsync(
|
Status TransferLiteralToDeviceAsync(se::Stream* stream,
|
||||||
se::Stream* stream, const LiteralSlice& literal,
|
const LiteralSlice& literal,
|
||||||
const ShapedBuffer& device_buffer) override;
|
const ShapedBuffer& device_buffer,
|
||||||
|
TransferToDeviceHint hint) override;
|
||||||
|
|
||||||
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
||||||
const LiteralSlice& literal) override;
|
const LiteralSlice& literal) override;
|
||||||
|
@ -89,6 +89,16 @@ class TransferManager {
|
|||||||
const LiteralSlice& literal,
|
const LiteralSlice& literal,
|
||||||
const ShapedBuffer& device_buffer);
|
const ShapedBuffer& device_buffer);
|
||||||
|
|
||||||
|
// Hint type given to TransferLiteralToDeviceAsync.
|
||||||
|
enum TransferToDeviceHint {
|
||||||
|
// No hint available.
|
||||||
|
kNoHint,
|
||||||
|
|
||||||
|
// The destination buffer is undefined on the device, meaning it can be
|
||||||
|
// transferred to eagerly rather than waiting for Stream ordering.
|
||||||
|
kBufferUndefined,
|
||||||
|
};
|
||||||
|
|
||||||
// Transfers the given literal into the previously allocated device memory
|
// Transfers the given literal into the previously allocated device memory
|
||||||
// represented by the given ShapedBuffer using the given executor. The shape
|
// represented by the given ShapedBuffer using the given executor. The shape
|
||||||
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
|
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
|
||||||
@ -96,9 +106,13 @@ class TransferManager {
|
|||||||
//
|
//
|
||||||
// This operation is performed asynchronously on the given stream. It returns
|
// This operation is performed asynchronously on the given stream. It returns
|
||||||
// once the transfer is enqueued.
|
// once the transfer is enqueued.
|
||||||
|
//
|
||||||
|
// The optional hint can allow implementations to optimize transfers. It is
|
||||||
|
// not mandatory for an implementation to obey the hint.
|
||||||
virtual Status TransferLiteralToDeviceAsync(
|
virtual Status TransferLiteralToDeviceAsync(
|
||||||
se::Stream* stream, const LiteralSlice& literal,
|
se::Stream* stream, const LiteralSlice& literal,
|
||||||
const ShapedBuffer& device_buffer) = 0;
|
const ShapedBuffer& device_buffer,
|
||||||
|
TransferToDeviceHint hint = kNoHint) = 0;
|
||||||
|
|
||||||
// Convenience methods for transferring an array to or from the device at a
|
// Convenience methods for transferring an array to or from the device at a
|
||||||
// known address. This avoids having to construct a ShapedBuffer just to
|
// known address. This avoids having to construct a ShapedBuffer just to
|
||||||
|
Loading…
Reference in New Issue
Block a user