Add a hint parameter to TransferLiteralToDeviceAsync that the implementation can use to accelerate transfers.

PiperOrigin-RevId: 215362667
This commit is contained in:
A. Unique TensorFlower 2018-10-02 03:36:14 -07:00 committed by TensorFlower Gardener
parent 44da41e490
commit f22037abf5
5 changed files with 33 additions and 10 deletions

View File

@ -75,8 +75,9 @@ XlaTransferManager::XlaTransferManager(
} }
} }
Status XlaTransferManager::TransferLiteralToDevice( Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
const Tensor& host_tensor, Tensor* device_tensor) const { Tensor* device_tensor,
bool buffer_is_fresh) const {
xla::Shape xla_shape; xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape)); host_tensor.shape(), &xla_shape));
@ -97,8 +98,11 @@ Status XlaTransferManager::TransferLiteralToDevice(
// synchronized. // synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get()); host_to_device_stream_->ThenWaitFor(stream_.get());
} }
xla::TransferManager::TransferToDeviceHint hint =
buffer_is_fresh ? xla::TransferManager::kBufferUndefined
: xla::TransferManager::kNoHint;
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
host_to_device_stream_.get(), *literal, shaped_buffer)); host_to_device_stream_.get(), *literal, shaped_buffer, hint));
if (UseMultipleStreams()) { if (UseMultipleStreams()) {
auto event = std::make_shared<se::Event>(stream_->parent()); auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
@ -165,6 +169,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return; return;
} }
TensorShape shape = shape_or_status.ValueOrDie(); TensorShape shape = shape_or_status.ValueOrDie();
bool buffer_is_fresh = false;
if (!xla_tensor->has_shaped_buffer()) { if (!xla_tensor->has_shaped_buffer()) {
Status s = Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
@ -173,6 +178,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
done(s); done(s);
return; return;
} }
buffer_is_fresh = true;
} }
Status status; Status status;
@ -183,7 +189,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
"Tensor::CopyFrom failed when copying from CPU to XLA device")); "Tensor::CopyFrom failed when copying from CPU to XLA device"));
return; return;
} }
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor,
buffer_is_fresh);
} else { } else {
se::DeviceMemoryBase dev_dst_ptr = se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor); XlaTensor::DeviceMemoryFromTensor(*device_tensor);

View File

@ -67,7 +67,8 @@ class XlaTransferManager {
private: private:
Status TransferLiteralToDevice(const Tensor& host_tensor, Status TransferLiteralToDevice(const Tensor& host_tensor,
Tensor* device_tensor) const; Tensor* device_tensor,
bool buffer_is_fresh) const;
void TransferLiteralFromDevice(Tensor* host_tensor, void TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor, const Tensor& device_tensor,
const StatusCallback& done) const; const StatusCallback& done) const;

View File

@ -98,7 +98,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal(
Status GenericTransferManager::TransferLiteralToDeviceAsync( Status GenericTransferManager::TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal, se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) { const ShapedBuffer& device_buffer, TransferToDeviceHint /*hint*/) {
const Shape& shape = literal.shape(); const Shape& shape = literal.shape();
VLOG(2) << "transferring literal shape to device: " VLOG(2) << "transferring literal shape to device: "
<< ShapeUtil::HumanString(shape) << ShapeUtil::HumanString(shape)

View File

@ -45,9 +45,10 @@ class GenericTransferManager : public TransferManager {
MutableBorrowingLiteral literal, MutableBorrowingLiteral literal,
std::function<void(Status)> done) override; std::function<void(Status)> done) override;
Status TransferLiteralToDeviceAsync( Status TransferLiteralToDeviceAsync(se::Stream* stream,
se::Stream* stream, const LiteralSlice& literal, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) override; const ShapedBuffer& device_buffer,
TransferToDeviceHint hint) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor, Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override; const LiteralSlice& literal) override;

View File

@ -89,6 +89,16 @@ class TransferManager {
const LiteralSlice& literal, const LiteralSlice& literal,
const ShapedBuffer& device_buffer); const ShapedBuffer& device_buffer);
// Hint type given to TransferLiteralToDeviceAsync.
enum TransferToDeviceHint {
// No hint available.
kNoHint,
// The destination buffer is undefined on the device, meaning it can be
// transferred to eagerly rather than waiting for Stream ordering.
kBufferUndefined,
};
// Transfers the given literal into the previously allocated device memory // Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape // represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
@ -96,9 +106,13 @@ class TransferManager {
// //
// This operation is performed asynchronously on the given stream. It returns // This operation is performed asynchronously on the given stream. It returns
// once the transfer is enqueued. // once the transfer is enqueued.
//
// The optional hint can allow implementations to optimize transfers. It is
// not mandatory for an implementation to obey the hint.
virtual Status TransferLiteralToDeviceAsync( virtual Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal, se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) = 0; const ShapedBuffer& device_buffer,
TransferToDeviceHint hint = kNoHint) = 0;
// Convenience methods for transferring an array to or from the device at a // Convenience methods for transferring an array to or from the device at a
// known address. This avoids having to construct a ShapedBuffer just to // known address. This avoids having to construct a ShapedBuffer just to