[pod-driver] Handle cross-host device to device transfer correctly
PiperOrigin-RevId: 334627211 Change-Id: I13c69527e92dbab4014eb63c2dcab6eec401b4e1
This commit is contained in:
parent
801c87d900
commit
bdf4d6dbf2
@ -406,27 +406,44 @@ class PodTpuDriver : public TpuDriver {
|
||||
std::shared_ptr<Event> TransferFromDeviceToDevice(
|
||||
const BufferHandle* src, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
int64_t operation_id = GetOperationId();
|
||||
auto deps = GetDependencyOperationIds(wait_for);
|
||||
deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
|
||||
deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
|
||||
auto src_core_id = static_cast<const PodBufferHandle*>(src)->core_id();
|
||||
auto dst_core_id = static_cast<PodBufferHandle*>(dst)->core_id();
|
||||
|
||||
auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
|
||||
auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
|
||||
auto core_id = static_cast<PodBufferHandle*>(dst)->core_id();
|
||||
auto src_driver_id = core_to_driver_id_[src_core_id];
|
||||
auto dst_driver_id = core_to_driver_id_[dst_core_id];
|
||||
|
||||
ScheduleRequest(
|
||||
operation_id,
|
||||
[this, src_op_id, dst_op_id, core_id]() {
|
||||
absl::MutexLock l(&mu_);
|
||||
auto src_iter = underlying_buffers_.find(src_op_id);
|
||||
auto dst_iter = underlying_buffers_.find(dst_op_id);
|
||||
return core_to_driver_[core_id]->TransferFromDeviceToDevice(
|
||||
src_iter->second.get(), dst_iter->second.get(), {});
|
||||
},
|
||||
deps);
|
||||
if (src_driver_id == dst_driver_id) {
|
||||
// They are in the same host, we can schedule it normally
|
||||
int64_t operation_id = GetOperationId();
|
||||
auto deps = GetDependencyOperationIds(wait_for);
|
||||
deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
|
||||
deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
|
||||
|
||||
return std::make_shared<PodEvent>(this, operation_id);
|
||||
auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
|
||||
auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
|
||||
|
||||
ScheduleRequest(
|
||||
operation_id,
|
||||
[this, src_op_id, dst_op_id, dst_core_id]() {
|
||||
absl::MutexLock l(&mu_);
|
||||
auto src_iter = underlying_buffers_.find(src_op_id);
|
||||
auto dst_iter = underlying_buffers_.find(dst_op_id);
|
||||
return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice(
|
||||
src_iter->second.get(), dst_iter->second.get(), {});
|
||||
},
|
||||
deps);
|
||||
return std::make_shared<PodEvent>(this, operation_id);
|
||||
} else {
|
||||
// src and dst are on different hosts, we have to bounce through us.
|
||||
auto dst_size = dst->size_in_bytes();
|
||||
char* host_buf = new char[dst_size];
|
||||
|
||||
auto src_event = TransferFromDevice(src, host_buf, wait_for);
|
||||
auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()});
|
||||
dst_event->AddCallback(
|
||||
[src_event, host_buf](xla::Status status) { delete[] host_buf; });
|
||||
return dst_event;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<CompiledProgramHandle> CompileProgram(
|
||||
|
Loading…
Reference in New Issue
Block a user