Catch the device-side execution stream error earlier, when the device-to-host transfer finishes, instead of when all subgraph ops have been scheduled (ExecutorState::Finish).

PiperOrigin-RevId: 245492101
This commit is contained in:
A. Unique TensorFlower 2019-04-26 14:52:16 -07:00 committed by TensorFlower Gardener
parent 38310b582d
commit 85c4d2348a
3 changed files with 28 additions and 6 deletions

View File

@ -243,16 +243,25 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
cpu_tensor, &literal));
TensorReference ref(*device_tensor);
const bool device_allows_sync_on_completion =
device->AllowsSyncOnCompletion();
// Explicitly capture device_to_host_stream to make sure the stream is alive
// before the transfer finishes.
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
[ref, xla_tensor, done, device_to_host_stream](xla::Status status) {
done([&]() -> Status {
VLOG(2) << "Transfer from device as literal: "
<< xla_tensor->shaped_buffer().ToString();
return status;
}());
[ref, xla_tensor, done, device_to_host_stream,
device_allows_sync_on_completion](xla::Status status) {
Status done_status = status;
VLOG(2) << "Transfer from device as literal: "
<< xla_tensor->shaped_buffer().ToString();
// For devices that don't allow sync on completion, the device execution is
// deferred. We check the execution stream status here to avoid wrong
// results from a failed stream being propagated to following
// host-side ops.
if (!device_allows_sync_on_completion) {
done_status.Update(xla_tensor->RefreshStatusOfStreams());
}
done(done_status);
ref.Unref();
});
}

View File

@ -97,6 +97,15 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
streams_defined_on_ = {stream};
}
// Polls every stream this tensor was defined on and folds their statuses
// together. Because Status::Update keeps the first non-OK status it sees,
// the result is the first failing stream's status, or OK when all streams
// are healthy. Holds mu_ for the duration so streams_defined_on_ cannot be
// mutated concurrently.
Status XlaTensor::RefreshStatusOfStreams() {
  mutex_lock lock(mu_);
  Status aggregate_status;
  for (se::Stream* defining_stream : streams_defined_on_) {
    // RefreshStatus queries the underlying device stream for errors.
    aggregate_status.Update(defining_stream->RefreshStatus());
  }
  return aggregate_status;
}
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.

View File

@ -102,6 +102,10 @@ class XlaTensor {
void ResetDefinitionEvent(std::shared_ptr<se::Event> event,
se::Stream* stream);
// Refresh the status of streams_defined_on_. Return the first not-OK stream's
// status or OK.
Status RefreshStatusOfStreams();
// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
static XlaTensor* FromOpaquePointer(void* ptr);
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.