Add a hint parameter to TransferLiteralToDeviceAsync that the implementation can use to accelerate transfers.

PiperOrigin-RevId: 215362667
This commit is contained in:
A. Unique TensorFlower 2018-10-02 03:36:14 -07:00 committed by TensorFlower Gardener
parent 44da41e490
commit f22037abf5
5 changed files with 33 additions and 10 deletions

View File

@ -75,8 +75,9 @@ XlaTransferManager::XlaTransferManager(
}
}
Status XlaTransferManager::TransferLiteralToDevice(
const Tensor& host_tensor, Tensor* device_tensor) const {
Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
Tensor* device_tensor,
bool buffer_is_fresh) const {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
@ -97,8 +98,11 @@ Status XlaTransferManager::TransferLiteralToDevice(
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
}
xla::TransferManager::TransferToDeviceHint hint =
buffer_is_fresh ? xla::TransferManager::kBufferUndefined
: xla::TransferManager::kNoHint;
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
host_to_device_stream_.get(), *literal, shaped_buffer));
host_to_device_stream_.get(), *literal, shaped_buffer, hint));
if (UseMultipleStreams()) {
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
@ -165,6 +169,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
bool buffer_is_fresh = false;
if (!xla_tensor->has_shaped_buffer()) {
Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
@ -173,6 +178,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
done(s);
return;
}
buffer_is_fresh = true;
}
Status status;
@ -183,7 +189,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor,
buffer_is_fresh);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);

View File

@ -67,7 +67,8 @@ class XlaTransferManager {
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
Tensor* device_tensor) const;
Tensor* device_tensor,
bool buffer_is_fresh) const;
void TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor,
const StatusCallback& done) const;

View File

@ -98,7 +98,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal(
Status GenericTransferManager::TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) {
const ShapedBuffer& device_buffer, TransferToDeviceHint /*hint*/) {
const Shape& shape = literal.shape();
VLOG(2) << "transferring literal shape to device: "
<< ShapeUtil::HumanString(shape)

View File

@ -45,9 +45,10 @@ class GenericTransferManager : public TransferManager {
MutableBorrowingLiteral literal,
std::function<void(Status)> done) override;
Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) override;
Status TransferLiteralToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const ShapedBuffer& device_buffer,
TransferToDeviceHint hint) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;

View File

@ -89,6 +89,16 @@ class TransferManager {
const LiteralSlice& literal,
const ShapedBuffer& device_buffer);
// Optional hint that a caller may pass to TransferLiteralToDeviceAsync.
enum TransferToDeviceHint {
  // The caller has no information to offer about the destination buffer.
  kNoHint = 0,

  // The destination buffer is undefined on the device, meaning it can be
  // transferred to eagerly rather than waiting for Stream ordering.
  kBufferUndefined = 1,
};
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
@ -96,9 +106,13 @@ class TransferManager {
//
// This operation is performed asynchronously on the given stream. It returns
// once the transfer is enqueued.
//
// The optional hint can allow implementations to optimize transfers. It is
// not mandatory for an implementation to obey the hint.
virtual Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) = 0;
const ShapedBuffer& device_buffer,
TransferToDeviceHint hint = kNoHint) = 0;
// Convenience methods for transferring an array to or from the device at a
// known address. This avoids having to construct a ShapedBuffer just to