[pjrt] Refresh stream error status in strategic places to flush out silent failures.
PiperOrigin-RevId: 317204018 Change-Id: If75a3ad9ec846ce1621cdba92a2dc738b65b7001
This commit is contained in:
parent
4aea552e06
commit
aab151356d
@ -127,11 +127,15 @@ std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
|
||||
} else {
|
||||
std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
|
||||
usage_stream_pool_.pop();
|
||||
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
|
||||
QCHECK(stream->ok());
|
||||
return stream;
|
||||
}
|
||||
}
|
||||
|
||||
// Hands a borrowed usage stream back to the pool.
//
// Before the stream is made available again its status is refreshed and
// checked, so that a failure recorded on the stream is surfaced here instead
// of being carried silently into the next borrower. A stream that is not ok()
// trips the QCHECK.
//
// `stream` is consumed: ownership moves into usage_stream_pool_.
void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
  se::Stream& returned = *stream;

  // The refresh result itself is deliberately dropped: it can be
  // error::Unimplemented. The ok() check below is what matters.
  returned.RefreshStatus().IgnoreError();
  QCHECK(returned.ok());

  // The pool is only touched while holding mu_.
  absl::MutexLock lock(&mu_);
  usage_stream_pool_.push(std::move(stream));
}
|
||||
|
@ -751,16 +751,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
// memory that has already been allocated, and a possible Event
|
||||
// allocation.
|
||||
|
||||
se::Stream* h2d_stream = local_device->host_to_device_stream();
|
||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
|
||||
compact_shape, on_device_shape, client->client()->platform());
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
local_device->host_to_device_stream(), literal, buffer));
|
||||
h2d_stream, literal, buffer));
|
||||
|
||||
std::shared_ptr<BufferSequencingEvent> event =
|
||||
device_buffer->definition_events()[0];
|
||||
TF_CHECK_OK(AddDestinationBufferSynchronization(
|
||||
local_device, std::move(device_buffer), event,
|
||||
local_device->host_to_device_stream()));
|
||||
local_device, std::move(device_buffer), event, h2d_stream));
|
||||
|
||||
// This can sometimes catch the case where the literal memory has been
|
||||
// freed before the H2D transfer was issued.
|
||||
h2d_stream->RefreshStatus()
|
||||
.IgnoreError(); // Can return error::Unimplemented
|
||||
QCHECK(h2d_stream->ok());
|
||||
};
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
return py_buffer;
|
||||
|
@ -285,7 +285,12 @@ Stream::~Stream() {
|
||||
|
||||
// Asks the parent executor for this stream's current status and returns it.
//
// The status is forwarded to CheckStatus() unless it is exactly the
// "GetStatus is not supported on this executor." UNIMPLEMENTED status. In
// that case the status is still returned to the caller, but it is not passed
// to CheckStatus(), so an executor that merely lacks GetStatus() support does
// not put the stream in an error state.
//
// Returns: the status reported by parent_->GetStatus(this), unmodified.
port::Status Stream::RefreshStatus() {
  port::Status status = parent_->GetStatus(this);
  // We should not put the stream in an error state, just because the GetStatus
  // method is unimplemented. CheckStatus() must therefore only be reached
  // through this guard; calling it unconditionally as well would defeat it.
  if (status != port::Status(port::error::UNIMPLEMENTED,
                             "GetStatus is not supported on this executor.")) {
    CheckStatus(status);
  }
  return status;
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user