[pjrt] Refresh stream error status in strategic places to flush out silent failures.

PiperOrigin-RevId: 317204018
Change-Id: If75a3ad9ec846ce1621cdba92a2dc738b65b7001
This commit is contained in:
A. Unique TensorFlower 2020-06-18 16:24:14 -07:00 committed by TensorFlower Gardener
parent 4aea552e06
commit aab151356d
3 changed files with 19 additions and 4 deletions

View File

@ -127,11 +127,15 @@ std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
} else {
std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
usage_stream_pool_.pop();
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
QCHECK(stream->ok());
return stream;
}
}
// Hands a previously borrowed usage stream back to usage_stream_pool_.
// The stream's status is refreshed and checked first, so a stream that is not
// ok() aborts via QCHECK instead of being placed back in the pool.
void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
  se::Stream& returned_stream = *stream;

  // RefreshStatus() can return error::Unimplemented; that result is
  // deliberately dropped and only ok() is consulted below.
  returned_stream.RefreshStatus().IgnoreError();
  QCHECK(returned_stream.ok());

  // The pool is only touched while mu_ is held.
  absl::MutexLock lock(&mu_);
  usage_stream_pool_.push(std::move(stream));
}

View File

@ -751,16 +751,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
// memory that has already been allocated, and a possible Event
// allocation.
se::Stream* h2d_stream = local_device->host_to_device_stream();
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform());
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), literal, buffer));
h2d_stream, literal, buffer));
std::shared_ptr<BufferSequencingEvent> event =
device_buffer->definition_events()[0];
TF_CHECK_OK(AddDestinationBufferSynchronization(
local_device, std::move(device_buffer), event,
local_device->host_to_device_stream()));
local_device, std::move(device_buffer), event, h2d_stream));
// This can sometimes catch the case where the literal memory has been
// freed before the H2D transfer was issued.
h2d_stream->RefreshStatus()
.IgnoreError(); // Can return error::Unimplemented
QCHECK(h2d_stream->ok());
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
return py_buffer;

View File

@ -285,7 +285,12 @@ Stream::~Stream() {
// Fetches this stream's status from parent_ via GetStatus() and returns it
// to the caller unchanged.
//
// The fetched status is forwarded to CheckStatus() exactly once, and only when
// it is not the specific UNIMPLEMENTED "GetStatus is not supported on this
// executor." status. Callers that only care about the stream's health can
// therefore ignore the returned status and inspect the stream afterwards.
port::Status Stream::RefreshStatus() {
  port::Status status = parent_->GetStatus(this);
  // We should not put the stream in an error state, just because the GetStatus
  // method is unimplemented. CheckStatus() must therefore not be called
  // unconditionally before this guard, or the guard has no effect.
  if (status != port::Status(port::error::UNIMPLEMENTED,
                             "GetStatus is not supported on this executor.")) {
    CheckStatus(status);
  }
  return status;
}