From bdf4d6dbf24e6bf9bae7eca74f537a33826180e5 Mon Sep 17 00:00:00 2001
From: Frank Chen
Date: Wed, 30 Sep 2020 10:23:57 -0700
Subject: [PATCH] [pod-driver] Handle cross-host device to device transfer correctly

PiperOrigin-RevId: 334627211
Change-Id: I13c69527e92dbab4014eb63c2dcab6eec401b4e1
---
 .../xla/python/tpu_driver/pod_tpu_driver.cc | 53 ++++++++++++-------
 1 file changed, 35 insertions(+), 18 deletions(-)

diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
index ac54df39895..cb1647832f7 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
@@ -406,27 +406,44 @@ class PodTpuDriver : public TpuDriver {
   std::shared_ptr TransferFromDeviceToDevice(
       const BufferHandle* src, BufferHandle* dst,
       absl::Span wait_for) override {
-    int64_t operation_id = GetOperationId();
-    auto deps = GetDependencyOperationIds(wait_for);
-    deps.insert(static_cast(src)->operation_id());
-    deps.insert(static_cast(dst)->operation_id());
+    auto src_core_id = static_cast(src)->core_id();
+    auto dst_core_id = static_cast(dst)->core_id();
 
-    auto src_op_id = static_cast(src)->operation_id();
-    auto dst_op_id = static_cast(dst)->operation_id();
-    auto core_id = static_cast(dst)->core_id();
+    auto src_driver_id = core_to_driver_id_[src_core_id];
+    auto dst_driver_id = core_to_driver_id_[dst_core_id];
 
-    ScheduleRequest(
-        operation_id,
-        [this, src_op_id, dst_op_id, core_id]() {
-          absl::MutexLock l(&mu_);
-          auto src_iter = underlying_buffers_.find(src_op_id);
-          auto dst_iter = underlying_buffers_.find(dst_op_id);
-          return core_to_driver_[core_id]->TransferFromDeviceToDevice(
-              src_iter->second.get(), dst_iter->second.get(), {});
-        },
-        deps);
+    if (src_driver_id == dst_driver_id) {
+      // They are in the same host, we can schedule it normally
+      int64_t operation_id = GetOperationId();
+      auto deps = GetDependencyOperationIds(wait_for);
+      deps.insert(static_cast(src)->operation_id());
+      deps.insert(static_cast(dst)->operation_id());
 
-    return std::make_shared(this, operation_id);
+      auto src_op_id = static_cast(src)->operation_id();
+      auto dst_op_id = static_cast(dst)->operation_id();
+
+      ScheduleRequest(
+          operation_id,
+          [this, src_op_id, dst_op_id, dst_core_id]() {
+            absl::MutexLock l(&mu_);
+            auto src_iter = underlying_buffers_.find(src_op_id);
+            auto dst_iter = underlying_buffers_.find(dst_op_id);
+            return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice(
+                src_iter->second.get(), dst_iter->second.get(), {});
+          },
+          deps);
+      return std::make_shared(this, operation_id);
+    } else {
+      // src and dst are on different hosts, we have to bounce through us.
+      auto dst_size = dst->size_in_bytes();
+      char* host_buf = new char[dst_size];
+
+      auto src_event = TransferFromDevice(src, host_buf, wait_for);
+      auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()});
+      dst_event->AddCallback(
+          [src_event, host_buf](xla::Status status) { delete[] host_buf; });
+      return dst_event;
+    }
   }
 
   std::unique_ptr CompileProgram(