From 85c4d2348af5ce87ceb17ff8f4f54d4eb6d3630a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Fri, 26 Apr 2019 14:52:16 -0700 Subject: [PATCH] Catch the device-side execution stream error earlier when device to host transfer finishes instead of when all subgraph ops have been scheduled (ExecutorState::Finish). PiperOrigin-RevId: 245492101 --- tensorflow/compiler/jit/xla_device_context.cc | 21 +++++++++++++------ tensorflow/compiler/jit/xla_tensor.cc | 9 ++++++++ tensorflow/compiler/jit/xla_tensor.h | 4 ++++ 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 80bbdea4dfe..4adf55260ea 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -243,16 +243,25 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, cpu_tensor, &literal)); TensorReference ref(*device_tensor); + const bool device_allows_sync_on_completion = + device->AllowsSyncOnCompletion(); // Explicitly capture device_to_host_stream to make sure the stream is alive // before the transfer finishes. transfer_manager_->TransferLiteralFromDevice( device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal, - [ref, xla_tensor, done, device_to_host_stream](xla::Status status) { - done([&]() -> Status { - VLOG(2) << "Transfer from device as literal: " - << xla_tensor->shaped_buffer().ToString(); - return status; - }()); + [ref, xla_tensor, done, device_to_host_stream, + device_allows_sync_on_completion](xla::Status status) { + Status done_status = status; + VLOG(2) << "Transfer from device as literal: " + << xla_tensor->shaped_buffer().ToString(); + // For devices that don't allow sync on completion, the device execution is + // deferred. We check the execution stream status here to avoid wrong + // results from a failed stream being propagated to following + // host-side ops. 
+ if (!device_allows_sync_on_completion) { + done_status.Update(xla_tensor->RefreshStatusOfStreams()); + } + done(done_status); ref.Unref(); }); } diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index d1f7f754c83..b92bd675378 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -97,6 +97,15 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event, streams_defined_on_ = {stream}; } +Status XlaTensor::RefreshStatusOfStreams() { + mutex_lock lock(mu_); + Status status; + for (se::Stream* stream : streams_defined_on_) { + status.Update(stream->RefreshStatus()); + } + return status; +} + // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from // device-side tensors, which are either CPU or GPU memory pointers. This works // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 77e80aa2527..8a4eb7493be 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -102,6 +102,10 @@ class XlaTensor { void ResetDefinitionEvent(std::shared_ptr<se::Event> event, se::Stream* stream); + // Refresh the status of streams_defined_on_. Return the first not-OK stream's + // status or OK. + Status RefreshStatusOfStreams(); + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. static XlaTensor* FromOpaquePointer(void* ptr); // Convert to a raw pointer from an XlaTensor, adding the pointer tag.