diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index af83c792e5e..e083652978f 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -75,8 +75,9 @@ XlaTransferManager::XlaTransferManager(
   }
 }
 
-Status XlaTransferManager::TransferLiteralToDevice(
-    const Tensor& host_tensor, Tensor* device_tensor) const {
+Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
+                                                   Tensor* device_tensor,
+                                                   bool buffer_is_fresh) const {
   xla::Shape xla_shape;
   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
                                            host_tensor.shape(), &xla_shape));
@@ -97,8 +98,11 @@ Status XlaTransferManager::TransferLiteralToDevice(
     // synchronized.
     host_to_device_stream_->ThenWaitFor(stream_.get());
   }
+  xla::TransferManager::TransferToDeviceHint hint =
+      buffer_is_fresh ? xla::TransferManager::kBufferUndefined
+                      : xla::TransferManager::kNoHint;
   TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
-      host_to_device_stream_.get(), *literal, shaped_buffer));
+      host_to_device_stream_.get(), *literal, shaped_buffer, hint));
   if (UseMultipleStreams()) {
     auto event = std::make_shared<se::Event>(stream_->parent());
     TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
@@ -165,6 +169,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
       return;
     }
     TensorShape shape = shape_or_status.ValueOrDie();
+    bool buffer_is_fresh = false;
     if (!xla_tensor->has_shaped_buffer()) {
       Status s = xla_tensor->AllocateShapedBuffer(
           device_tensor->dtype(), shape, client_,
@@ -173,6 +178,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
         done(s);
         return;
       }
+      buffer_is_fresh = true;
     }
 
     Status status;
@@ -183,7 +189,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
             "Tensor::CopyFrom failed when copying from CPU to XLA device"));
         return;
       }
-      status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+      status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor,
+                                       buffer_is_fresh);
     } else {
       se::DeviceMemoryBase dev_dst_ptr =
           XlaTensor::DeviceMemoryFromTensor(*device_tensor);
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index df824212948..a4c0c296fcb 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -67,7 +67,8 @@ class XlaTransferManager {
 
  private:
   Status TransferLiteralToDevice(const Tensor& host_tensor,
-                                 Tensor* device_tensor) const;
+                                 Tensor* device_tensor,
+                                 bool buffer_is_fresh) const;
   void TransferLiteralFromDevice(Tensor* host_tensor,
                                  const Tensor& device_tensor,
                                  const StatusCallback& done) const;
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index bec02e14f95..f92fde7f46a 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -98,7 +98,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal(
 
 Status GenericTransferManager::TransferLiteralToDeviceAsync(
     se::Stream* stream, const LiteralSlice& literal,
-    const ShapedBuffer& device_buffer) {
+    const ShapedBuffer& device_buffer, TransferToDeviceHint /*hint*/) {
   const Shape& shape = literal.shape();
   VLOG(2) << "transferring literal shape to device: "
           << ShapeUtil::HumanString(shape)
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 86c8b1c145a..b1cba82b9fb 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -45,9 +45,10 @@ class GenericTransferManager : public TransferManager {
                                  MutableBorrowingLiteral literal,
                                  std::function<void(Status)> done) override;
 
-  Status TransferLiteralToDeviceAsync(
-      se::Stream* stream, const LiteralSlice& literal,
-      const ShapedBuffer& device_buffer) override;
+  Status TransferLiteralToDeviceAsync(se::Stream* stream,
+                                      const LiteralSlice& literal,
+                                      const ShapedBuffer& device_buffer,
+                                      TransferToDeviceHint hint) override;
 
   Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                  const LiteralSlice& literal) override;
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index f952e64af2b..9199e32d0f5 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -89,6 +89,16 @@ class TransferManager {
                                          const LiteralSlice& literal,
                                          const ShapedBuffer& device_buffer);
 
+  // Hint type given to TransferLiteralToDeviceAsync.
+  enum TransferToDeviceHint {
+    // No hint available.
+    kNoHint,
+
+    // The destination buffer is undefined on the device, meaning it can be
+    // transferred to eagerly rather than waiting for Stream ordering.
+    kBufferUndefined,
+  };
+
   // Transfers the given literal into the previously allocated device memory
   // represented by the given ShapedBuffer using the given executor. The shape
   // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
@@ -96,9 +106,13 @@ class TransferManager {
   //
   // This operation is performed asynchronously on the given stream. It returns
   // once the transfer is enqueued.
+  //
+  // The optional hint can allow implementations to optimize transfers. It is
+  // not mandatory for an implementation to obey the hint.
   virtual Status TransferLiteralToDeviceAsync(
       se::Stream* stream, const LiteralSlice& literal,
-      const ShapedBuffer& device_buffer) = 0;
+      const ShapedBuffer& device_buffer,
+      TransferToDeviceHint hint = kNoHint) = 0;
 
   // Convenience methods for transferring an array to or from the device at a
   // known address. This avoids having to construct a ShapedBuffer just to
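For reference, a backend-specific TransferManager picking up this change would override the widened TransferLiteralToDeviceAsync signature and branch on the hint roughly as sketched below. This is only an illustrative sketch, not part of the change: MyBackendTransferManager and EnqueueEagerCopy are hypothetical placeholders; only the signature, the kNoHint/kBufferUndefined values, and the GenericTransferManager fallback come from the diff above.

#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"

namespace se = ::stream_executor;

// Hypothetical subclass shown only to illustrate how the hint is consumed.
class MyBackendTransferManager : public xla::GenericTransferManager {
 public:
  using GenericTransferManager::GenericTransferManager;

  xla::Status TransferLiteralToDeviceAsync(
      se::Stream* stream, const xla::LiteralSlice& literal,
      const xla::ShapedBuffer& device_buffer,
      TransferToDeviceHint hint) override {
    if (hint == kBufferUndefined) {
      // The destination buffer was freshly allocated and holds no live data,
      // so the copy does not need to wait for work already queued on the
      // stream and can be enqueued eagerly.
      return EnqueueEagerCopy(literal, device_buffer);
    }
    // No hint: keep the conservative, stream-ordered behavior of the
    // generic implementation.
    return GenericTransferManager::TransferLiteralToDeviceAsync(
        stream, literal, device_buffer, hint);
  }

 private:
  // Placeholder for a backend-specific fast path.
  xla::Status EnqueueEagerCopy(const xla::LiteralSlice& literal,
                               const xla::ShapedBuffer& device_buffer);
};

Because the base declaration defaults hint to kNoHint, existing call sites that do not pass a hint keep compiling unchanged; in this change only XlaTransferManager::CopyCPUTensorToDevice opts in, passing kBufferUndefined when it has just allocated the shaped buffer.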