Catch the device-side execution stream error earlier, when the device-to-host transfer finishes, instead of when all subgraph ops have been scheduled (ExecutorState::Finish).
PiperOrigin-RevId: 245492101
parent 38310b582d
commit 85c4d2348a
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -243,16 +243,25 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
       cpu_tensor, &literal));
 
   TensorReference ref(*device_tensor);
+  const bool device_allows_sync_on_completion =
+      device->AllowsSyncOnCompletion();
   // Explicitly capture device_to_host_stream to make sure the stream is alive
   // before the transfer finishes.
   transfer_manager_->TransferLiteralFromDevice(
       device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
-      [ref, xla_tensor, done, device_to_host_stream](xla::Status status) {
-        done([&]() -> Status {
-          VLOG(2) << "Transfer from device as literal: "
-                  << xla_tensor->shaped_buffer().ToString();
-          return status;
-        }());
+      [ref, xla_tensor, done, device_to_host_stream,
+       device_allows_sync_on_completion](xla::Status status) {
+        Status done_status = status;
+        VLOG(2) << "Transfer from device as literal: "
+                << xla_tensor->shaped_buffer().ToString();
+        // For devices that don't allow sync on completion, device execution
+        // is deferred. We check the execution stream status here to avoid
+        // wrong results from a failed stream being propagated to subsequent
+        // host-side ops.
+        if (!device_allows_sync_on_completion) {
+          done_status.Update(xla_tensor->RefreshStatusOfStreams());
+        }
+        done(done_status);
         ref.Unref();
       });
 }
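The change above reroutes error reporting through a single done_status: instead of handing the transfer's status straight to done, the callback first folds in the health of the execution streams whenever the device defers completion. A minimal sketch of that pattern follows; the Status and Stream types here are simplified stand-ins, not TensorFlow's real classes.

#include <functional>
#include <string>

// Simplified stand-in for tensorflow::Status: Update() keeps the first error.
struct Status {
  bool ok = true;
  std::string msg;
  void Update(const Status& other) {
    if (ok && !other.ok) *this = other;
  }
};

// Simplified stand-in for a device stream whose health can be polled.
struct Stream {
  bool failed = false;
  Status RefreshStatus() const {
    Status s;
    if (failed) {
      s.ok = false;
      s.msg = "stream failed";
    }
    return s;
  }
};

// Mirrors the new callback: on devices that defer execution, poll the stream
// before invoking done, so a device-side failure surfaces when the transfer
// completes instead of after all subgraph ops have been scheduled.
void OnTransferDone(Status transfer_status, const Stream& stream,
                    bool device_allows_sync_on_completion,
                    const std::function<void(Status)>& done) {
  Status done_status = transfer_status;
  if (!device_allows_sync_on_completion) {
    done_status.Update(stream.RefreshStatus());
  }
  done(done_status);
}

If the device syncs on completion, the transfer status already reflects any device failure and the extra poll is skipped.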
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -97,6 +97,15 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
   streams_defined_on_ = {stream};
 }
 
+Status XlaTensor::RefreshStatusOfStreams() {
+  mutex_lock lock(mu_);
+  Status status;
+  for (se::Stream* stream : streams_defined_on_) {
+    status.Update(stream->RefreshStatus());
+  }
+  return status;
+}
+
 // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
 // device-side tensors, which are either CPU or GPU memory pointers. This works
 // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.
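Two details of RefreshStatusOfStreams are worth noting: the loop runs under mutex_lock because ResetDefinitionEvent can rewrite streams_defined_on_ concurrently, and Status::Update is a no-op once an error has been recorded, so the method reports the first failed stream. Reusing the stand-in types from the sketch above, the accumulation behaves like this:

#include <vector>

// Same accumulation as the loop above: the first non-OK status wins; later
// failures (and later OK statuses) leave it unchanged.
Status RefreshAll(const std::vector<Stream>& streams) {
  Status status;
  for (const Stream& stream : streams) {
    status.Update(stream.RefreshStatus());
  }
  return status;
}

If every stream is healthy, the default-constructed OK status is returned unchanged.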
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -102,6 +102,10 @@ class XlaTensor {
   void ResetDefinitionEvent(std::shared_ptr<se::Event> event,
                             se::Stream* stream);
 
+  // Refreshes the status of streams_defined_on_. Returns the first non-OK
+  // stream's status, or OK.
+  Status RefreshStatusOfStreams();
+
   // Convert from a raw pointer to an XlaTensor, removing the pointer tag.
   static XlaTensor* FromOpaquePointer(void* ptr);
   // Convert to a raw pointer from an XlaTensor, adding the pointer tag.
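The context comments around FromOpaquePointer describe a pointer-tagging trick: because CPU and GPU buffer pointers are aligned, the low bit of their addresses is always zero, so it can be OR-ed in to mark an address as an XlaTensor. A generic illustration of the idea follows, with an assumed tag constant and void* in place of the real XlaTensor* signatures; it is not TensorFlow's exact implementation.

#include <cstdint>

// Assumed tag: the low bit, which is free on any pointer aligned to >= 2.
constexpr std::uintptr_t kXlaTensorTag = 0x1;

// Tag an XlaTensor address so it is distinguishable from raw device pointers.
void* ToOpaquePointer(void* xla_tensor) {
  return reinterpret_cast<void*>(
      reinterpret_cast<std::uintptr_t>(xla_tensor) | kXlaTensorTag);
}

// Recover the XlaTensor address, or nullptr if this was a plain device pointer.
void* FromOpaquePointer(void* ptr) {
  const std::uintptr_t value = reinterpret_cast<std::uintptr_t>(ptr);
  if ((value & kXlaTensorTag) == 0) return nullptr;  // untagged device pointer
  return reinterpret_cast<void*>(value & ~kXlaTensorTag);
}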