From 4b59a16a2baf7297053a4639ab3fe04a3b4cfe3e Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 20 Jun 2019 07:22:45 -0700
Subject: [PATCH] [XLA:Python] Add support for direct device-to-device copies.

PiperOrigin-RevId: 254191250
---
 .../compiler/xla/python/local_client.cc       | 124 ++++++++++++++++++--
 tensorflow/compiler/xla/python/local_client.h |  25 +++-
 tensorflow/compiler/xla/python/xla.cc         |   1 +
 3 files changed, 138 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc
index 125d97e38d4..aa4933baed4 100644
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -123,9 +123,16 @@ Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
     host_to_device_stream_->Init();
     device_to_host_stream_->Init();
     callback_stream_->Init();
+    device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
+    for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
+      auto stream = std::make_shared<se::Stream>(executor);
+      stream->Init();
+      device_to_device_streams_.push_back(std::move(stream));
+    }
   } else {
     callback_stream_ = host_to_device_stream_ = device_to_host_stream_ =
         compute_stream_;
+    device_to_device_streams_.push_back(compute_stream_);
   }
   worker_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                    "py_xla_execute");
@@ -153,12 +160,31 @@ Status Device::SynchronizeAllActivity() {
   return status;
 }
 
+Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream,
+                                        se::Stream* dst_stream,
+                                        se::DeviceMemoryBase src_buffer,
+                                        se::DeviceMemoryBase dst_buffer) {
+  // The default implementation simply calls ThenMemcpyD2D, and assumes that
+  // the buffer addresses identify the devices. This does not work on all
+  // platforms; this method is virtual so it can be overridden.
+  src_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
+  return Status::OK();
+}
+
 void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
                                        std::function<void()> callback) const {
   stream->ThenDoHostCallback(
       [this, callback]() { worker_thread_->Schedule(std::move(callback)); });
 }
 
+se::Stream* Device::GetDeviceToDeviceStream() {
+  absl::MutexLock lock(&mu_);
+  int i = next_device_to_device_stream_;
+  next_device_to_device_stream_ =
+      (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
+  return device_to_device_streams_.at(i).get();
+}
+
 static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
     se::Platform* platform, LocalClient* client, double memory_fraction,
     bool preallocate) {
@@ -224,15 +250,30 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
         allocator_config.preallocate));
     allocator = std::move(bfc_allocator);
   }
+
+  std::vector<std::unique_ptr<Device>> devices;
+  devices.reserve(client->device_count());
+  bool use_multiple_streams = (platform_name != "cpu");
+  bool synchronous_deallocation = !use_multiple_streams;
+  for (int i = 0; i < client->device_count(); ++i) {
+    se::StreamExecutor* executor =
+        client->backend().stream_executor(i).ValueOrDie();
+    devices.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
+                                                synchronous_deallocation,
+                                                asynchronous));
+  }
   return std::make_shared<PyLocalClient>(platform_name, client,
+                                         std::move(devices),
                                          std::move(allocator), asynchronous);
 }
 
 PyLocalClient::PyLocalClient(
     std::string platform_name, LocalClient* client,
+    std::vector<std::unique_ptr<Device>> devices,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator, bool asynchronous)
     : platform_name_(std::move(platform_name)),
       client_(client),
+      devices_(std::move(devices)),
       owned_allocator_(std::move(allocator)),
       h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
                          client->device_count()) {
@@ -241,16 +282,6 @@ PyLocalClient::PyLocalClient(
   } else {
     allocator_ = client_->backend().memory_allocator();
   }
-  devices_.reserve(client->device_count());
-  bool use_multiple_streams = (platform_name_ != "cpu");
-  bool synchronous_deallocation = !use_multiple_streams;
-  for (int i = 0; i < client->device_count(); ++i) {
-    se::StreamExecutor* executor =
-        client_->backend().stream_executor(i).ValueOrDie();
-    devices_.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
-                                                 synchronous_deallocation,
-                                                 asynchronous));
-  }
 }
 
 Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
@@ -577,6 +608,79 @@ PyLocalBuffer::DestructureTuple() {
   return results;
 }
 
+StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
+    int dst_device_ordinal) {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
+  client_->py_ref_manager().CollectGarbage();
+  py::gil_scoped_release gil_release;
+  std::shared_ptr<PySharedDeviceBuffer> src_device_buffer = DeviceBuffer();
+  if (dst_device_ordinal == device_ordinal_) {
+    return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
+                                            client_);
+  }
+  Device& src_device = client_->device(device_ordinal_);
+  const Device& dst_device = client_->device(dst_device_ordinal);
+
+  se::Stream* src_device_to_device_stream =
+      src_device.GetDeviceToDeviceStream();
+
+  TransferManager* transfer_manager =
+      client_->client()->backend().transfer_manager();
+  TF_ASSIGN_OR_RETURN(
+      ScopedShapedBuffer dst_buffer,
+      transfer_manager->AllocateScopedShapedBuffer(
+          on_host_shape_, client_->allocator(), dst_device_ordinal));
+  if (dst_device.use_multiple_streams() &&
+      !transfer_manager->CanShapedBufferBeAccessedNow(
+          dst_device.compute_stream()->parent(), dst_buffer)) {
+    src_device_to_device_stream->ThenWaitFor(dst_device.compute_stream());
+  }
+  TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
+
+  WaitForBufferDefinitionEventsOnStream(*src_device_buffer,
+                                        src_device_to_device_stream);
+
+  // Copy the leaf buffers.
+  for (const auto& leaf : src_buffer.buffers().leaves()) {
+    const xla::ShapeIndex& index = leaf.first;
+    const se::DeviceMemoryBase& input_buffer = leaf.second;
+    const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
+    TF_RET_CHECK(input_buffer.size() == output_buffer.size())
+        << "input: " << input_buffer.size()
+        << " output: " << output_buffer.size();
+    TF_RETURN_IF_ERROR(src_device.ThenMemcpyDeviceToDevice(
+        src_device_to_device_stream, dst_device.compute_stream(), input_buffer,
+        output_buffer));
+  }
+
+  // Write new tuple buffers. The destination buffers have different addresses,
+  // so we must construct tuple buffers from scratch instead of copying them.
+  if (dst_buffer.on_device_shape().IsTuple()) {
+    TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
+        dst_device.host_to_device_stream(), dst_buffer));
+
+    // We need a single definition event, so make the device-to-device stream
+    // wait for the stream that wrote the tuple index tables on the destination
+    // device.
+    src_device_to_device_stream->ThenWaitFor(
+        dst_device.host_to_device_stream());
+  }
+
+  std::shared_ptr<BufferDefinitionEvent> definition_event;
+  if (dst_device.use_multiple_streams()) {
+    TF_ASSIGN_OR_RETURN(
+        definition_event,
+        BufferDefinitionEvent::Create(src_device_to_device_stream->parent()));
+    definition_event->RecordOnStream(src_device_to_device_stream);
+  }
+
+  std::shared_ptr<PySharedDeviceBuffer> dst_device_buffer =
+      PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
+                                                   definition_event);
+  return absl::make_unique<PyLocalBuffer>(
+      on_host_shape_, std::move(dst_device_buffer), client_);
+}
+
 Status PyLocalBuffer::BlockHostUntilReady() {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady");
   std::shared_ptr<PySharedDeviceBuffer> device_buffer = DeviceBuffer();
diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h
index 8070d360074..efed57c1f28 100644
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -62,7 +62,7 @@ class Device {
   // each execution or transfer. This is intended for debugging only.
   Device(se::StreamExecutor* executor, bool use_multiple_streams,
          bool synchronous_deallocation, bool asynchronous);
-  ~Device();
+  virtual ~Device();
 
   bool use_multiple_streams() const { return use_multiple_streams_; }
   bool synchronous_deallocation() const { return synchronous_deallocation_; }
@@ -75,6 +75,16 @@
     return device_to_host_stream_.get();
   }
 
+  // Returns a device-to-device stream. Allocates streams in a round-robin
+  // fashion amongst the available streams.
+  se::Stream* GetDeviceToDeviceStream();
+
+  // Enqueues a copy of `src_buffer` to `dst_buffer` onto `src_stream`.
+  virtual Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
+                                          se::Stream* dst_stream,
+                                          se::DeviceMemoryBase src_buffer,
+                                          se::DeviceMemoryBase dst_buffer);
+
   // A worker thread, used for replicated computation launches and callbacks.
   WorkerThread* worker_thread() const { return worker_thread_.get(); }
 
@@ -132,6 +142,13 @@
   std::shared_ptr<se::Stream> compute_stream_;
   std::shared_ptr<se::Stream> host_to_device_stream_;
   std::shared_ptr<se::Stream> device_to_host_stream_;
+  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
+
+  // Number of device-to-device streams to create in the multistream case.
+  static constexpr int kNumDeviceToDeviceStreams = 4;
+
+  absl::Mutex mu_;
+  int next_device_to_device_stream_ GUARDED_BY(mu_) = 0;
 
   // Callback stream is used for running short host-side callbacks after device
   // side events, without preventing the device-side stream from doing useful
@@ -172,6 +189,7 @@ class PyLocalClient {
 
   // `allocator` may null, in which case the platform default allocator is used.
   explicit PyLocalClient(std::string platform_name, LocalClient* client,
+                         std::vector<std::unique_ptr<Device>> devices,
                          std::unique_ptr<se::DeviceMemoryAllocator> allocator,
                          bool asynchronous);
   virtual ~PyLocalClient() = default;
@@ -181,7 +199,7 @@
 
   int device_count() const { return client_->device_count(); }
-  const Device& device(int device_ordinal) const {
+  Device& device(int device_ordinal) const {
     return *devices_.at(device_ordinal);
   }
   LocalClient* client() const { return client_; }
@@ -267,6 +285,9 @@ class PyLocalBuffer {
   // Destructures a tuple-valued PyLocalBuffer into its constituent elements.
   StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> DestructureTuple();
 
+  // Copies the buffer to device `dst_device_ordinal`.
+  StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(int dst_device_ordinal);
+
   // Blocks the host until the buffer's value has been computed and is ready for
   // immediate use on the device. Useful in particular for timing benchmarks.
   Status BlockHostUntilReady();
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 4f3de836bf4..40bf429d91f 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -299,6 +299,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def_static("from_python", &PyLocalBuffer::FromPython)
       .def_static("from_python_values", &PyLocalBuffer::FromPythonValues)
       .def_static("make_tuple", &PyLocalBuffer::MakeTuple)
+      .def("copy_to_device", &PyLocalBuffer::CopyToDevice)
      .def("delete", &PyLocalBuffer::Delete)
       .def("destructure", &PyLocalBuffer::DestructureTuple)
       .def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady)
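Device::ThenMemcpyDeviceToDevice is declared virtual above so that platforms
where the default ThenMemcpyD2D path is not appropriate can substitute their
own copy mechanism. The sketch below is a hypothetical illustration of such an
override, not something added by this patch: MyPlatformDevice is an invented
name, and the only calls it uses are ones that already appear in the diff.

class MyPlatformDevice : public Device {
 public:
  using Device::Device;

  // Hypothetical platform where the receiving device must issue the copy, so
  // the transfer is enqueued on dst_stream instead of src_stream; src_stream
  // is deliberately unused in this sketch.
  Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
                                  se::Stream* dst_stream,
                                  se::DeviceMemoryBase src_buffer,
                                  se::DeviceMemoryBase dst_buffer) override {
    dst_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
    return Status::OK();
  }
};

On the Python side, the new copy_to_device binding forwards to
PyLocalBuffer::CopyToDevice, which takes a destination device ordinal and
returns a new buffer on that device, with the copy enqueued on one of the
source device's round-robin device-to-device streams.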