[XLA:Python] Add support for direct device-to-device copies.

PiperOrigin-RevId: 254191250
Peter Hawkins 2019-06-20 07:22:45 -07:00 committed by TensorFlower Gardener
parent af4eb9c864
commit 4b59a16a2b
3 changed files with 138 additions and 12 deletions


@@ -123,9 +123,16 @@ Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
host_to_device_stream_->Init();
device_to_host_stream_->Init();
callback_stream_->Init();
device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
auto stream = std::make_shared<se::Stream>(executor);
stream->Init();
device_to_device_streams_.push_back(std::move(stream));
}
} else {
callback_stream_ = host_to_device_stream_ = device_to_host_stream_ =
compute_stream_;
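// In the single-stream case the device-to-device "pool" is just the compute stream.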
device_to_device_streams_.push_back(compute_stream_);
}
worker_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_execute");
@@ -153,12 +160,31 @@ Status Device::SynchronizeAllActivity() {
return status;
}
Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream,
se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer,
se::DeviceMemoryBase dst_buffer) {
// The default implementation simply calls ThenMemcpyD2D, and assumes that
// the buffer addresses identify the devices. This does not work
// on all platforms; this method is virtual so it can be overridden.
src_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
return Status::OK();
}
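Because ThenMemcpyDeviceToDevice is virtual, a platform on which raw buffer addresses do not identify the owning device can subclass Device and route the copy through its own API. A minimal sketch under that assumption; MyPlatformDevice and MyPlatformCopyAsync are hypothetical names, not part of this commit:

class MyPlatformDevice : public Device {
 public:
  using Device::Device;

  Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
                                  se::Stream* dst_stream,
                                  se::DeviceMemoryBase src_buffer,
                                  se::DeviceMemoryBase dst_buffer) override {
    // Hypothetical platform call: use both streams' executors to pick the
    // right copy engine rather than inferring devices from addresses.
    return MyPlatformCopyAsync(src_stream->parent(), dst_stream->parent(),
                               src_buffer, dst_buffer, src_stream);
  }
};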
void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
std::function<void()> callback) const {
stream->ThenDoHostCallback(
[this, callback]() { worker_thread_->Schedule(std::move(callback)); });
}
se::Stream* Device::GetDeviceToDeviceStream() {
absl::MutexLock lock(&mu_);
int i = next_device_to_device_stream_;
next_device_to_device_stream_ =
(next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
return device_to_device_streams_.at(i).get();
}
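Round-robin selection means back-to-back copies land on different streams, so on hardware with multiple copy engines they can overlap. A hedged illustration, assuming a multistream `device` (setup elided):

se::Stream* s0 = device.GetDeviceToDeviceStream();
se::Stream* s1 = device.GetDeviceToDeviceStream();
// With kNumDeviceToDeviceStreams = 4, consecutive calls cycle through the
// pool, so s0 != s1; in the single-stream case both alias the compute stream.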
static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
se::Platform* platform, LocalClient* client, double memory_fraction,
bool preallocate) {
@@ -224,15 +250,30 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
allocator_config.preallocate));
allocator = std::move(bfc_allocator);
}
std::vector<std::unique_ptr<Device>> devices;
devices.reserve(client->device_count());
bool use_multiple_streams = (platform_name != "cpu");
bool synchronous_deallocation = !use_multiple_streams;
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutor* executor =
client->backend().stream_executor(i).ValueOrDie();
devices.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
synchronous_deallocation,
asynchronous));
}
return std::make_shared<PyLocalClient>(platform_name, client,
std::move(devices),
std::move(allocator), asynchronous);
}
PyLocalClient::PyLocalClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices,
std::unique_ptr<se::DeviceMemoryAllocator> allocator, bool asynchronous)
: platform_name_(std::move(platform_name)),
client_(client),
devices_(std::move(devices)),
owned_allocator_(std::move(allocator)),
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
client->device_count()) {
@@ -241,16 +282,6 @@ PyLocalClient::PyLocalClient(
} else {
allocator_ = client_->backend().memory_allocator();
}
- devices_.reserve(client->device_count());
- bool use_multiple_streams = (platform_name_ != "cpu");
- bool synchronous_deallocation = !use_multiple_streams;
- for (int i = 0; i < client->device_count(); ++i) {
- se::StreamExecutor* executor =
- client_->backend().stream_executor(i).ValueOrDie();
- devices_.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
- synchronous_deallocation,
- asynchronous));
- }
}
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
@@ -577,6 +608,79 @@ PyLocalBuffer::DestructureTuple() {
return results;
}
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
int dst_device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
client_->py_ref_manager().CollectGarbage();
py::gil_scoped_release gil_release;
std::shared_ptr<PySharedDeviceBuffer> src_device_buffer = DeviceBuffer();
if (dst_device_ordinal == device_ordinal_) {
return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
client_);
}
Device& src_device = client_->device(device_ordinal_);
const Device& dst_device = client_->device(dst_device_ordinal);
se::Stream* src_device_to_device_stream =
src_device.GetDeviceToDeviceStream();
TransferManager* transfer_manager =
client_->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape_, client_->allocator(), dst_device_ordinal));
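// The newly allocated destination memory may still be in use by work queued
// on the destination's compute stream; if so, make the copy stream wait.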
if (dst_device.use_multiple_streams() &&
!transfer_manager->CanShapedBufferBeAccessedNow(
dst_device.compute_stream()->parent(), dst_buffer)) {
src_device_to_device_stream->ThenWaitFor(dst_device.compute_stream());
}
TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
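// Make the device-to-device stream wait until the source buffer's contents
// have actually been defined before reading them.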
WaitForBufferDefinitionEventsOnStream(*src_device_buffer,
src_device_to_device_stream);
// Copy the leaf buffers.
for (const auto& leaf : src_buffer.buffers().leaves()) {
const xla::ShapeIndex& index = leaf.first;
const se::DeviceMemoryBase& input_buffer = leaf.second;
const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
TF_RET_CHECK(input_buffer.size() == output_buffer.size())
<< "input: " << input_buffer.size()
<< " output: " << output_buffer.size();
TF_RETURN_IF_ERROR(src_device.ThenMemcpyDeviceToDevice(
src_device_to_device_stream, dst_device.compute_stream(), input_buffer,
output_buffer));
}
// Write new tuple buffers. The destination buffers have different addresses,
// so we must construct tuple buffers from scratch instead of copying them.
if (dst_buffer.on_device_shape().IsTuple()) {
TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
dst_device.host_to_device_stream(), dst_buffer));
// We need a single definition event, so make the device-to-device stream
// wait for the stream that wrote the tuple index tables on the destination
// device.
src_device_to_device_stream->ThenWaitFor(
dst_device.host_to_device_stream());
}
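// Record a single definition event on the device-to-device stream; consumers
// of the destination buffer will wait on it before reading the copy (and, for
// tuples, the freshly written index tables).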
std::shared_ptr<BufferDefinitionEvent> definition_event;
if (dst_device.use_multiple_streams()) {
TF_ASSIGN_OR_RETURN(
definition_event,
BufferDefinitionEvent::Create(src_device_to_device_stream->parent()));
definition_event->RecordOnStream(src_device_to_device_stream);
}
std::shared_ptr<PySharedDeviceBuffer> dst_device_buffer =
PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
definition_event);
return absl::make_unique<PyLocalBuffer>(
on_host_shape_, std::move(dst_device_buffer), client_);
}
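A hedged usage sketch from the C++ side, assuming code running inside a function that returns Status, with `buffer` an existing PyLocalBuffer and a client that has at least two devices:

// Enqueue an asynchronous copy of `buffer` onto device 1.
TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> copy,
                    buffer->CopyToDevice(/*dst_device_ordinal=*/1));
// Optional: block only when the host actually needs the copy to be complete.
TF_RETURN_IF_ERROR(copy->BlockHostUntilReady());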
Status PyLocalBuffer::BlockHostUntilReady() {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady");
std::shared_ptr<PySharedDeviceBuffer> device_buffer = DeviceBuffer();


@@ -62,7 +62,7 @@ class Device {
// each execution or transfer. This is intended for debugging only.
Device(se::StreamExecutor* executor, bool use_multiple_streams,
bool synchronous_deallocation, bool asynchronous);
- ~Device();
+ virtual ~Device();
bool use_multiple_streams() const { return use_multiple_streams_; }
bool synchronous_deallocation() const { return synchronous_deallocation_; }
@@ -75,6 +75,16 @@
return device_to_host_stream_.get();
}
// Returns a device-to-device stream. Allocates streams in a round-robin
// fashion amongst the available streams.
se::Stream* GetDeviceToDeviceStream();
// Enqueues a copy of `src_buffer` to `dst_buffer` onto `src_stream`.
virtual Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer,
se::DeviceMemoryBase dst_buffer);
// A worker thread, used for replicated computation launches and callbacks.
WorkerThread* worker_thread() const { return worker_thread_.get(); }
@@ -132,6 +142,13 @@ class Device {
std::shared_ptr<se::Stream> compute_stream_;
std::shared_ptr<se::Stream> host_to_device_stream_;
std::shared_ptr<se::Stream> device_to_host_stream_;
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
// Number of device-to-device streams to create in the multistream case.
static constexpr int kNumDeviceToDeviceStreams = 4;
absl::Mutex mu_;
int next_device_to_device_stream_ GUARDED_BY(mu_) = 0;
// Callback stream is used for running short host-side callbacks after device
// side events, without preventing the device-side stream from doing useful
@@ -172,6 +189,7 @@ class PyLocalClient {
// `allocator` may be null, in which case the platform default allocator is used.
explicit PyLocalClient(std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
bool asynchronous);
virtual ~PyLocalClient() = default;
@@ -181,7 +199,7 @@
int device_ordinal);
int device_count() const { return client_->device_count(); }
- const Device& device(int device_ordinal) const {
+ Device& device(int device_ordinal) const {
return *devices_.at(device_ordinal);
}
LocalClient* client() const { return client_; }
@@ -267,6 +285,9 @@ class PyLocalBuffer {
// Destructures a tuple-valued PyLocalBuffer into its constituent elements.
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> DestructureTuple();
// Copies the buffer to device `dst_device_ordinal`.
StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(int dst_device_ordinal);
// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.
Status BlockHostUntilReady();


@@ -299,6 +299,7 @@ PYBIND11_MODULE(xla_extension, m) {
.def_static("from_python", &PyLocalBuffer::FromPython)
.def_static("from_python_values", &PyLocalBuffer::FromPythonValues)
.def_static("make_tuple", &PyLocalBuffer::MakeTuple)
.def("copy_to_device", &PyLocalBuffer::CopyToDevice)
.def("delete", &PyLocalBuffer::Delete)
.def("destructure", &PyLocalBuffer::DestructureTuple)
.def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady)