[pod-driver] Handle cross-host device to device transfer correctly

PiperOrigin-RevId: 334627211
Change-Id: I13c69527e92dbab4014eb63c2dcab6eec401b4e1
This commit is contained in:
Frank Chen 2020-09-30 10:23:57 -07:00 committed by TensorFlower Gardener
parent 801c87d900
commit bdf4d6dbf2

View File

@ -406,27 +406,44 @@ class PodTpuDriver : public TpuDriver {
std::shared_ptr<Event> TransferFromDeviceToDevice( std::shared_ptr<Event> TransferFromDeviceToDevice(
const BufferHandle* src, BufferHandle* dst, const BufferHandle* src, BufferHandle* dst,
absl::Span<Event* const> wait_for) override { absl::Span<Event* const> wait_for) override {
int64_t operation_id = GetOperationId(); auto src_core_id = static_cast<const PodBufferHandle*>(src)->core_id();
auto deps = GetDependencyOperationIds(wait_for); auto dst_core_id = static_cast<PodBufferHandle*>(dst)->core_id();
deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id(); auto src_driver_id = core_to_driver_id_[src_core_id];
auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id(); auto dst_driver_id = core_to_driver_id_[dst_core_id];
auto core_id = static_cast<PodBufferHandle*>(dst)->core_id();
ScheduleRequest( if (src_driver_id == dst_driver_id) {
operation_id, // They are in the same host, we can schedule it normally
[this, src_op_id, dst_op_id, core_id]() { int64_t operation_id = GetOperationId();
absl::MutexLock l(&mu_); auto deps = GetDependencyOperationIds(wait_for);
auto src_iter = underlying_buffers_.find(src_op_id); deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
auto dst_iter = underlying_buffers_.find(dst_op_id); deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
return core_to_driver_[core_id]->TransferFromDeviceToDevice(
src_iter->second.get(), dst_iter->second.get(), {});
},
deps);
return std::make_shared<PodEvent>(this, operation_id); auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
ScheduleRequest(
operation_id,
[this, src_op_id, dst_op_id, dst_core_id]() {
absl::MutexLock l(&mu_);
auto src_iter = underlying_buffers_.find(src_op_id);
auto dst_iter = underlying_buffers_.find(dst_op_id);
return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice(
src_iter->second.get(), dst_iter->second.get(), {});
},
deps);
return std::make_shared<PodEvent>(this, operation_id);
} else {
// src and dst are on different hosts, we have to bounce through us.
auto dst_size = dst->size_in_bytes();
char* host_buf = new char[dst_size];
auto src_event = TransferFromDevice(src, host_buf, wait_for);
auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()});
dst_event->AddCallback(
[src_event, host_buf](xla::Status status) { delete[] host_buf; });
return dst_event;
}
} }
std::unique_ptr<CompiledProgramHandle> CompileProgram( std::unique_ptr<CompiledProgramHandle> CompileProgram(