diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index d7f27d00f76..2057788c12d 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -652,6 +652,10 @@ StatusOr> PyLocalBuffer::CopyToDevice( output_buffer)); } + // We hold on to the `src_device_buffer` until the transfer is finished. + src_device.ThenRelease(src_device_to_device_stream, + std::move(src_device_buffer)); + // Write new tuple buffers. The destination buffers have different addresses, // so we must construct tuple buffers from scratch instead of copying them. if (dst_buffer.on_device_shape().IsTuple()) {