Avoid deleting a bad Stream from within a host callback enqueued on that same Stream, which can cause bad behavior on some platforms.

PiperOrigin-RevId: 302709536
Change-Id: I4d2ba96e2132a92f90af530ea97a82389788af28
This commit is contained in:
A. Unique TensorFlower 2020-03-24 11:29:52 -07:00 committed by TensorFlower Gardener
parent 2f3651df27
commit 7ae1992664

View File

@ -255,7 +255,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
// before the transfer finishes.
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
[ref, xla_tensor, done, device_to_host_stream,
[this, ref, xla_tensor, done, device_to_host_stream,
device_allows_sync_on_completion](xla::Status status) {
Status done_status = status;
VLOG(2) << "Transfer from device as literal: "
@ -269,6 +269,19 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
}
done(done_status);
ref.Unref();
// If a stream is in a bad state, it gets deleted when it's returned to
// the stream pool, i.e. when it leaves this scope. However, a stream
// deleting itself in a host callback on itself can cause bad behavior
// on some platforms. Release it in another stream to avoid that.
if (!device_allows_sync_on_completion &&
!device_to_host_stream->RefreshStatus().ok()) {
auto status_or_new_stream = client_->mutable_backend()->BorrowStream(
stream_->parent()->device_ordinal());
if (status_or_new_stream.ok()) {
status_or_new_stream.ValueOrDie()->ThenDoHostCallback(
[device_to_host_stream] {});
}
}
});
}