Fix a TransferFromDeviceToDevice race in the client.
PiperOrigin-RevId: 281640160 Change-Id: Ic5043ef5e2bfe66d554987bce5f931d47ce39cc7
This commit is contained in:
parent
162cd54855
commit
472804bd1f
@ -380,14 +380,17 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
|
||||
}
|
||||
|
||||
tpu_driver::TpuDriver* driver = client_->driver();
|
||||
tpu_driver::BufferHandle* src_handle = src_device_buffer->handle.get();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PyTpuBuffer> dst_buffer,
|
||||
CreateBuffer(
|
||||
on_host_shape_,
|
||||
[driver, src_handle](tpu_driver::BufferHandle* dst_handle) {
|
||||
return driver->TransferFromDeviceToDevice(src_handle, dst_handle,
|
||||
{});
|
||||
[driver, src_device_buffer](tpu_driver::BufferHandle* dst_handle) {
|
||||
std::vector<tpu_driver::Event*> src_wait_for_use;
|
||||
for (auto& event : src_device_buffer->wait_for_use) {
|
||||
src_wait_for_use.push_back(event.get());
|
||||
}
|
||||
return driver->TransferFromDeviceToDevice(
|
||||
src_device_buffer->handle.get(), dst_handle, src_wait_for_use);
|
||||
},
|
||||
client_, dst_device_ordinal));
|
||||
// TODO(jiawenhao): This may be too pessimistic: it prevents future readers
|
||||
|
Loading…
Reference in New Issue
Block a user