Avoid crashing in LocalDeviceState::ReturnStreamToPool and BorrowStreamFromPool when stream fails with "ABORTED: Bad connection".

PiperOrigin-RevId: 355324168
Change-Id: Id5c99a418733ad394af4f5a99aaca398615e986b
This commit is contained in:
A. Unique TensorFlower 2021-02-02 22:06:41 -08:00 committed by TensorFlower Gardener
parent ddb089b72f
commit bd14bb7cf4
2 changed files with 12 additions and 4 deletions
tensorflow/compiler/xla/pjrt

View File

@ -121,6 +121,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:lib",
"//tensorflow/core/platform:stream_executor",
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
"//tensorflow/stream_executor:event",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/stream.h"
namespace xla {
@ -127,15 +128,21 @@ std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
} else {
std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
usage_stream_pool_.pop();
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
QCHECK(stream->ok());
auto status = stream->RefreshStatus(); // Can return error::Unimplemented
// Stream may fail with "ABORTED: Bad connection".
if (status.code() != tensorflow::error::ABORTED) {
CHECK(stream->ok()) << status;
}
return stream;
}
}
void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
QCHECK(stream->ok());
auto status = stream->RefreshStatus(); // Can return error::Unimplemented
// Stream may fail with "ABORTED: Bad connection".
if (status.code() != tensorflow::error::ABORTED) {
CHECK(stream->ok()) << status;
}
absl::MutexLock lock(&mu_);
usage_stream_pool_.push(std::move(stream));
}