Catch the device-side execution stream error earlier, when the device-to-host transfer finishes, instead of when all subgraph ops have been scheduled (ExecutorState::Finish).
PiperOrigin-RevId: 245492101
This commit is contained in:
parent
38310b582d
commit
85c4d2348a
@ -243,16 +243,25 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
cpu_tensor, &literal));
|
||||
|
||||
TensorReference ref(*device_tensor);
|
||||
const bool device_allows_sync_on_completion =
|
||||
device->AllowsSyncOnCompletion();
|
||||
// Explicitly capture device_to_host_stream to make sure the stream is alive
|
||||
// before the transfer finishes.
|
||||
transfer_manager_->TransferLiteralFromDevice(
|
||||
device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
|
||||
[ref, xla_tensor, done, device_to_host_stream](xla::Status status) {
|
||||
done([&]() -> Status {
|
||||
VLOG(2) << "Transfer from device as literal: "
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
return status;
|
||||
}());
|
||||
[ref, xla_tensor, done, device_to_host_stream,
|
||||
device_allows_sync_on_completion](xla::Status status) {
|
||||
Status done_status = status;
|
||||
VLOG(2) << "Transfer from device as literal: "
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
// For devices that don't allow sync on completion, the device execution is
|
||||
// deferred. We check the execution stream status here to avoid wrong
|
||||
// results from a failed stream being propagated to following
|
||||
// host-side ops.
|
||||
if (!device_allows_sync_on_completion) {
|
||||
done_status.Update(xla_tensor->RefreshStatusOfStreams());
|
||||
}
|
||||
done(done_status);
|
||||
ref.Unref();
|
||||
});
|
||||
}
|
||||
|
@ -97,6 +97,15 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
|
||||
streams_defined_on_ = {stream};
|
||||
}
|
||||
|
||||
// Calls RefreshStatus() on every stream in streams_defined_on_ while holding
// mu_, folding each result into a single Status via Status::Update. Per the
// declaration's contract, the result is the first not-OK stream status, or OK
// when no stream reports an error. Every stream is visited; the loop does not
// stop early after a failure.
Status XlaTensor::RefreshStatusOfStreams() {
  mutex_lock guard(mu_);
  Status combined;
  for (se::Stream* defining_stream : streams_defined_on_) {
    const Status refreshed = defining_stream->RefreshStatus();
    combined.Update(refreshed);
  }
  return combined;
}
|
||||
|
||||
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
|
||||
// device-side tensors, which are either CPU or GPU memory pointers. This works
|
||||
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.
|
||||
|
@ -102,6 +102,10 @@ class XlaTensor {
|
||||
void ResetDefinitionEvent(std::shared_ptr<se::Event> event,
|
||||
se::Stream* stream);
|
||||
|
||||
// Refresh the status of streams_defined_on_. Return the first not-OK stream's
|
||||
// status or OK.
|
||||
Status RefreshStatusOfStreams();
|
||||
|
||||
// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
|
||||
static XlaTensor* FromOpaquePointer(void* ptr);
|
||||
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.
|
||||
|
Loading…
Reference in New Issue
Block a user