[XLA:Python] Add support for direct device-to-device copies.
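This allows a PyLocalBuffer to be copied to another local device without
staging the data through host memory. Illustrative sketch only (assumes a
buffer resident on device 0, a client with at least two devices, and that the
declarations added by this change are in scope; the helper name is
hypothetical):

  // Hypothetical helper: enqueues the copy on one of the source device's
  // round-robin device-to-device streams; no host round trip is involved.
  StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToSecondDevice(
      PyLocalBuffer* buffer) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> copy,
                        buffer->CopyToDevice(/*dst_device_ordinal=*/1));
    return std::move(copy);
  }

From Python the same operation is exposed as PyLocalBuffer.copy_to_device.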
PiperOrigin-RevId: 254191250
commit 4b59a16a2b (parent af4eb9c864)
@@ -123,9 +123,16 @@ Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
     host_to_device_stream_->Init();
     device_to_host_stream_->Init();
     callback_stream_->Init();
+    device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
+    for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
+      auto stream = std::make_shared<se::Stream>(executor);
+      stream->Init();
+      device_to_device_streams_.push_back(std::move(stream));
+    }
   } else {
     callback_stream_ = host_to_device_stream_ = device_to_host_stream_ =
         compute_stream_;
+    device_to_device_streams_.push_back(compute_stream_);
   }
   worker_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                    "py_xla_execute");
@@ -153,12 +160,31 @@ Status Device::SynchronizeAllActivity() {
   return status;
 }
 
+Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream,
+                                        se::Stream* dst_stream,
+                                        se::DeviceMemoryBase src_buffer,
+                                        se::DeviceMemoryBase dst_buffer) {
+  // The default implementation simply calls ThenMemcpyD2D, and assumes that
+  // the buffer addresses identify the devices. This does not work
+  // on all platforms; this method is virtual so it can be overridden.
+  src_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
+  return Status::OK();
+}
+
 void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
                                        std::function<void()> callback) const {
   stream->ThenDoHostCallback(
       [this, callback]() { worker_thread_->Schedule(std::move(callback)); });
 }
 
+se::Stream* Device::GetDeviceToDeviceStream() {
+  absl::MutexLock lock(&mu_);
+  int i = next_device_to_device_stream_;
+  next_device_to_device_stream_ =
+      (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
+  return device_to_device_streams_.at(i).get();
+}
+
 static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
     se::Platform* platform, LocalClient* client, double memory_fraction,
     bool preallocate) {
@@ -224,15 +250,30 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
                         allocator_config.preallocate));
     allocator = std::move(bfc_allocator);
   }
+
+  std::vector<std::unique_ptr<Device>> devices;
+  devices.reserve(client->device_count());
+  bool use_multiple_streams = (platform_name != "cpu");
+  bool synchronous_deallocation = !use_multiple_streams;
+  for (int i = 0; i < client->device_count(); ++i) {
+    se::StreamExecutor* executor =
+        client->backend().stream_executor(i).ValueOrDie();
+    devices.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
+                                                synchronous_deallocation,
+                                                asynchronous));
+  }
   return std::make_shared<PyLocalClient>(platform_name, client,
+                                         std::move(devices),
                                          std::move(allocator), asynchronous);
 }
 
 PyLocalClient::PyLocalClient(
     std::string platform_name, LocalClient* client,
+    std::vector<std::unique_ptr<Device>> devices,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator, bool asynchronous)
     : platform_name_(std::move(platform_name)),
       client_(client),
+      devices_(std::move(devices)),
       owned_allocator_(std::move(allocator)),
       h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
                          client->device_count()) {
@@ -241,16 +282,6 @@ PyLocalClient::PyLocalClient(
   } else {
     allocator_ = client_->backend().memory_allocator();
   }
-  devices_.reserve(client->device_count());
-  bool use_multiple_streams = (platform_name_ != "cpu");
-  bool synchronous_deallocation = !use_multiple_streams;
-  for (int i = 0; i < client->device_count(); ++i) {
-    se::StreamExecutor* executor =
-        client_->backend().stream_executor(i).ValueOrDie();
-    devices_.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
-                                                 synchronous_deallocation,
-                                                 asynchronous));
-  }
 }
 
 Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
@@ -577,6 +608,79 @@ PyLocalBuffer::DestructureTuple() {
   return results;
 }
 
+StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
+    int dst_device_ordinal) {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
+  client_->py_ref_manager().CollectGarbage();
+  py::gil_scoped_release gil_release;
+  std::shared_ptr<PySharedDeviceBuffer> src_device_buffer = DeviceBuffer();
+  if (dst_device_ordinal == device_ordinal_) {
+    return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
+                                            client_);
+  }
+  Device& src_device = client_->device(device_ordinal_);
+  const Device& dst_device = client_->device(dst_device_ordinal);
+
+  se::Stream* src_device_to_device_stream =
+      src_device.GetDeviceToDeviceStream();
+
+  TransferManager* transfer_manager =
+      client_->client()->backend().transfer_manager();
+  TF_ASSIGN_OR_RETURN(
+      ScopedShapedBuffer dst_buffer,
+      transfer_manager->AllocateScopedShapedBuffer(
+          on_host_shape_, client_->allocator(), dst_device_ordinal));
+  if (dst_device.use_multiple_streams() &&
+      !transfer_manager->CanShapedBufferBeAccessedNow(
+          dst_device.compute_stream()->parent(), dst_buffer)) {
+    src_device_to_device_stream->ThenWaitFor(dst_device.compute_stream());
+  }
+  TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
+
+  WaitForBufferDefinitionEventsOnStream(*src_device_buffer,
+                                        src_device_to_device_stream);
+
+  // Copy the leaf buffers.
+  for (const auto& leaf : src_buffer.buffers().leaves()) {
+    const xla::ShapeIndex& index = leaf.first;
+    const se::DeviceMemoryBase& input_buffer = leaf.second;
+    const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
+    TF_RET_CHECK(input_buffer.size() == output_buffer.size())
+        << "input: " << input_buffer.size()
+        << " output: " << output_buffer.size();
+    TF_RETURN_IF_ERROR(src_device.ThenMemcpyDeviceToDevice(
+        src_device_to_device_stream, dst_device.compute_stream(), input_buffer,
+        output_buffer));
+  }
+
+  // Write new tuple buffers. The destination buffers have different addresses,
+  // so we must construct tuple buffers from scratch instead of copying them.
+  if (dst_buffer.on_device_shape().IsTuple()) {
+    TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
+        dst_device.host_to_device_stream(), dst_buffer));
+
+    // We need a single definition event, so make the device to device stream
+    // wait for the stream that wrote the tuple index tables on the destination
+    // device.
+    src_device_to_device_stream->ThenWaitFor(
+        dst_device.host_to_device_stream());
+  }
+
+  std::shared_ptr<BufferDefinitionEvent> definition_event;
+  if (dst_device.use_multiple_streams()) {
+    TF_ASSIGN_OR_RETURN(
+        definition_event,
+        BufferDefinitionEvent::Create(src_device_to_device_stream->parent()));
+    definition_event->RecordOnStream(src_device_to_device_stream);
+  }
+
+  std::shared_ptr<PySharedDeviceBuffer> dst_device_buffer =
+      PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
+                                                   definition_event);
+  return absl::make_unique<PyLocalBuffer>(
+      on_host_shape_, std::move(dst_device_buffer), client_);
+}
+
 Status PyLocalBuffer::BlockHostUntilReady() {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady");
   std::shared_ptr<PySharedDeviceBuffer> device_buffer = DeviceBuffer();
@@ -62,7 +62,7 @@ class Device {
   // each execution or transfer. This is intended for debugging only.
   Device(se::StreamExecutor* executor, bool use_multiple_streams,
          bool synchronous_deallocation, bool asynchronous);
-  ~Device();
+  virtual ~Device();
 
   bool use_multiple_streams() const { return use_multiple_streams_; }
   bool synchronous_deallocation() const { return synchronous_deallocation_; }
@@ -75,6 +75,16 @@ class Device {
     return device_to_host_stream_.get();
   }
 
+  // Returns a device to device stream. Allocates streams in a round-robin
+  // fashion amongst the available streams.
+  se::Stream* GetDeviceToDeviceStream();
+
+  // Enqueues a copy of `src_buffer` to `dst_buffer` onto `src_stream`.
+  virtual Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
+                                          se::Stream* dst_stream,
+                                          se::DeviceMemoryBase src_buffer,
+                                          se::DeviceMemoryBase dst_buffer);
+
   // A worker thread, used for replicated computation launches and callbacks.
   WorkerThread* worker_thread() const { return worker_thread_.get(); }
 
@@ -132,6 +142,13 @@ class Device {
   std::shared_ptr<se::Stream> compute_stream_;
   std::shared_ptr<se::Stream> host_to_device_stream_;
   std::shared_ptr<se::Stream> device_to_host_stream_;
+  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
+
+  // Number of device-to-device streams to create in the multistream case.
+  static constexpr int kNumDeviceToDeviceStreams = 4;
+
+  absl::Mutex mu_;
+  int next_device_to_device_stream_ GUARDED_BY(mu_) = 0;
 
   // Callback stream is used for running short host-side callbacks after device
   // side events, without preventing the device-side stream from doing useful
@@ -172,6 +189,7 @@ class PyLocalClient {
 
   // `allocator` may null, in which case the platform default allocator is used.
   explicit PyLocalClient(std::string platform_name, LocalClient* client,
+                         std::vector<std::unique_ptr<Device>> devices,
                          std::unique_ptr<se::DeviceMemoryAllocator> allocator,
                          bool asynchronous);
   virtual ~PyLocalClient() = default;
@@ -181,7 +199,7 @@ class PyLocalClient {
                                     int device_ordinal);
 
   int device_count() const { return client_->device_count(); }
-  const Device& device(int device_ordinal) const {
+  Device& device(int device_ordinal) const {
     return *devices_.at(device_ordinal);
   }
   LocalClient* client() const { return client_; }
@@ -267,6 +285,9 @@ class PyLocalBuffer {
   // Destructures a tuple-valued PyLocalBuffer into its constituent elements.
   StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> DestructureTuple();
 
+  // Copies the buffer to device `dst_device_ordinal`.
+  StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(int dst_device_ordinal);
+
   // Blocks the host until the buffer's value has been computed and is ready for
   // immediate use on the device. Useful in particular for timing benchmarks.
   Status BlockHostUntilReady();
@@ -299,6 +299,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def_static("from_python", &PyLocalBuffer::FromPython)
       .def_static("from_python_values", &PyLocalBuffer::FromPythonValues)
       .def_static("make_tuple", &PyLocalBuffer::MakeTuple)
+      .def("copy_to_device", &PyLocalBuffer::CopyToDevice)
       .def("delete", &PyLocalBuffer::Delete)
       .def("destructure", &PyLocalBuffer::DestructureTuple)
       .def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady)