Drop failed sub-streams during both Get and Return.

The old code ensured that failed sub-streams would not be re-used, but
had two flaws:
1) It only checked for failed sub-streams during Return.
2) It didn't actually remove the failed sub-streams from our state.

The new code fixes these two flaws, and adds an extra test that
explains why (1) is insufficient.

PiperOrigin-RevId: 207333296
Todd Wang 2018-08-03 15:21:58 -07:00 committed by TensorFlower Gardener
parent 5de6d11b0b
commit 9cdcb0397c
4 changed files with 219 additions and 73 deletions
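
To see why the Return-only check of flaw (1) is insufficient, consider the usual sub-stream pattern: work is enqueued on a sub-stream, the host joins it back into the main stream, and the sub-stream is returned while its work may still be executing on the device. The sketch below is illustrative only (EnqueueWork and RunOnSubStream are hypothetical names, not part of this change); it shows how an asynchronous error can arrive after ReturnSubStream has already marked the sub-stream reusable.

#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;

// Hypothetical helper: stands in for any sequence of Then* calls whose
// device-side execution (and possible failure) happens asynchronously.
void EnqueueWork(se::Stream* stream);

// Hypothetical caller, sketched for illustration.
void RunOnSubStream(se::Stream* main_stream) {
  se::Stream* sub = main_stream->GetOrCreateSubStream();
  EnqueueWork(sub);
  main_stream->ThenWaitFor(sub);
  // sub->ok() is typically still true here, so a check done only at Return
  // time marks the sub-stream reusable and never looks at it again.
  main_stream->ReturnSubStream(sub);
  // If the enqueued work fails later, sub flips to !ok. With the old code the
  // next GetOrCreateSubStream() could hand that failed sub-stream back out;
  // the new code re-checks ok() at Get time and drops it instead.
}

This is exactly the scenario the new FailedSubStreamAfterReturnNotReused test below exercises.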

tensorflow/compiler/xla/service/stream_pool.cc

@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/stream_pool.h"
 
 #include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/core/platform/logging.h"
 
 namespace xla {
@@ -27,6 +28,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
       // Re-use an existing stream from the pool.
       stream = std::move(streams_.back());
       streams_.pop_back();
+      VLOG(1) << stream->DebugStreamPointers()
+              << " StreamPool reusing existing stream";
     }
   }
 
@@ -34,6 +37,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
     // Create a new stream.
     stream = MakeUnique<se::Stream>(executor);
     stream->Init();
+    VLOG(1) << stream->DebugStreamPointers()
+            << " StreamPool created new stream";
   }
 
   // Return the stream wrapped in Ptr, which has our special deleter semantics.
@@ -43,12 +48,16 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
 void StreamPool::ReturnStream(se::Stream* stream) {
   if (stream->ok()) {
+    VLOG(1) << stream->DebugStreamPointers()
+            << " StreamPool returning ok stream";
     tensorflow::mutex_lock lock(mu_);
    streams_.emplace_back(stream);
   } else {
-    // If the stream has encountered any errors, all subsequent
-    // operations on it will fail. So just delete the stream, and rely
-    // on new streams to be created in the future.
+    // If the stream has encountered any errors, all subsequent operations on it
+    // will fail. So just delete the stream, and rely on new streams to be
+    // created in the future.
+    VLOG(1) << stream->DebugStreamPointers()
+            << " StreamPool deleting !ok stream";
     delete stream;
   }
 }
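
For context on where this ok()/!ok decision takes effect: BorrowStream hands the stream out wrapped in StreamPool::Ptr, whose deleter calls ReturnStream when the caller is done. A minimal usage sketch, assuming an executor and pool obtained elsewhere (the helper function name is hypothetical):

#include "tensorflow/compiler/xla/service/stream_pool.h"

namespace se = stream_executor;

// Sketch only: RunSomething is a hypothetical caller; `executor` and `pool`
// are assumed to be valid objects obtained elsewhere.
void RunSomething(se::StreamExecutor* executor, xla::StreamPool* pool) {
  xla::StreamPool::Ptr stream = pool->BorrowStream(executor);
  // ... enqueue work on *stream ...
  // When `stream` goes out of scope, its deleter calls ReturnStream(): an
  // ok() stream goes back into the pool, a !ok stream is deleted, as above.
}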

tensorflow/stream_executor/stream.cc

@@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) {
 }
 
 string ToVlogString(const DeviceMemoryBase *memory) {
-  return ToVlogString(*memory);
+  return memory == nullptr ? "null" : ToVlogString(*memory);
 }
 
 string ToVlogString(const Eigen::half &h) {
@@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream,
   // constructing all the strings in params is expensive.
   CHECK(VLOG_IS_ON(1));
 
-  string str = port::StrCat("Called Stream::", function_name, "(");
+  string str = port::StrCat(stream->DebugStreamPointers(),
+                            " Called Stream::", function_name, "(");
   const char *separator = "";
   for (const auto &param : params) {
     port::StrAppend(&str, separator, param.first, "=", param.second);
     separator = ", ";
   }
-  port::StrAppend(&str, ") stream=", ToVlogString(stream));
+  port::StrAppend(&str, ")");
   if (VLOG_IS_ON(10)) {
     port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
   }
@@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer(
 Stream *Stream::GetOrCreateSubStream() {
   mutex_lock lock(mu_);
-  for (auto &stream : sub_streams_) {
-    if (stream.second) {
-      stream.second = false;
-      return stream.first.get();
+
+  // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
+  // we encounter along the way.
+  for (int64 index = 0; index < sub_streams_.size();) {
+    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+    if (pair.second) {
+      // The sub_stream is reusable.
+      Stream *sub_stream = pair.first.get();
+      if (sub_stream->ok()) {
+        VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
+                << sub_stream->DebugStreamPointers();
+        pair.second = false;
+        return sub_stream;
+      }
+
+      // The stream is reusable and not ok. Streams have a monotonic state
+      // machine; the stream will remain in !ok forever. Swap it with the last
+      // stream and pop it off.
+      const int64 last = sub_streams_.size() - 1;
+      if (index != last) {
+        std::swap(pair, sub_streams_[last]);
+      }
+      sub_streams_.pop_back();
+      VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
+              << sub_stream->DebugStreamPointers();
+    } else {
+      // The sub_stream is not reusable, move on to the next one.
+      ++index;
     }
   }
+
+  // No streams are reusable; create a new stream.
   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
                             false);
   Stream *sub_stream = sub_streams_.back().first.get();
   sub_stream->Init();
   CHECK(ok_) << "sub-stream failed to be initialized";
+  VLOG(1) << DebugStreamPointers() << " created new sub_stream "
+          << sub_stream->DebugStreamPointers();
 
   return sub_stream;
 }
 
 void Stream::ReturnSubStream(Stream *sub_stream) {
   mutex_lock lock(mu_);
-  for (auto &stream : sub_streams_) {
-    if (stream.first.get() == sub_stream) {
-      // Streams have a monotonic state machine; if a stream
-      // encounters an error, it will remain in an error state
-      // forever. Only allow re-use of ok streams.
-      //
-      // TODO(toddw): Improve this mechanism, if necessary, to drop
-      // failed streams completely.
-      const bool ready_to_reuse = sub_stream->ok();
-      stream.second = ready_to_reuse;
-      return;
-    }
-  }
-  LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
+
+  // Look for the sub-stream.
+  for (int64 index = 0; index < sub_streams_.size(); ++index) {
+    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+    if (pair.first.get() != sub_stream) {
+      continue;
+    }
+
+    // Found the sub_stream.
+    if (sub_stream->ok()) {
+      VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
+              << sub_stream->DebugStreamPointers();
+      pair.second = true;
+    } else {
+      // The returned stream is not ok. Streams have a monotonic state
+      // machine; the stream will remain in !ok forever. Swap it with the last
+      // stream and pop it off.
+      VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
+              << sub_stream->DebugStreamPointers();
+      const int64 last = sub_streams_.size() - 1;
+      if (index != last) {
+        std::swap(pair, sub_streams_[last]);
+      }
+      sub_streams_.pop_back();
+    }
+    return;
+  }
+
+  LOG(FATAL) << DebugStreamPointers()
+             << " did not create the returned sub-stream "
+             << sub_stream->DebugStreamPointers();
 }
 
 Stream &Stream::ThenStartTimer(Timer *t) {
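
Both GetOrCreateSubStream and ReturnSubStream now drop a failed entry with the same swap-with-last-and-pop idiom: O(1) removal from a vector at the cost of reordering, which is harmless because sub_streams_ is an unordered pool. A standalone sketch of just that idiom (independent of Stream; the helper name is illustrative):

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative helper (not part of the change): remove v[index] in O(1) by
// swapping it with the last element and popping the back. The remaining
// elements are reordered, which is fine for an unordered pool.
template <typename T>
void SwapAndPop(std::vector<T>* v, int64_t index) {
  const int64_t last = static_cast<int64_t>(v->size()) - 1;
  if (index != last) {
    std::swap((*v)[index], (*v)[last]);
  }
  v->pop_back();
}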
@@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) {
   if (ok()) {
     CheckError(parent_->StartTimer(this, t));
   } else {
-    LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
+    LOG(INFO) << DebugStreamPointers()
+              << " did not enqueue 'start timer': " << t;
   }
   return *this;
 }
@@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) {
   if (ok()) {
     CheckError(parent_->StopTimer(this, t));
   } else {
-    LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
+    LOG(INFO) << DebugStreamPointers()
+              << " did not enqueue 'stop timer': " << t;
   }
   return *this;
 }
@@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) {
     CheckError(parent_->CreateStreamDependency(this, other));
   } else {
     SetError();
-    LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
+    LOG(INFO) << DebugStreamPointers() << " did not wait for "
+              << other->DebugStreamPointers();
   }
   return *this;
 }
@@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) {
                  << "at fault. Monitor for further errors.";
     }
   } else {
-    LOG(INFO) << "stream " << this << " did not wait for an event.";
+    LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
   }
   return *this;
 }
@@ -4802,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
       CheckError(rng->SetSeed(this, seed, seed_bytes));
     } else {
       SetError();
-      LOG(INFO) << "stream " << this << " unable to initialize RNG";
+      LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
     }
   } else {
-    LOG(INFO) << "stream " << this
+    LOG(INFO) << DebugStreamPointers()
               << " did not set RNG seed: " << static_cast<const void *>(seed)
               << "; bytes: " << seed_bytes;
   }
@@ -4820,7 +4869,8 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
       CheckError(rng->DoPopulateRandUniform(this, values));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform RNG operation using StreamExecutor"
                    " without RNG support.";
     }
   }
@@ -4836,7 +4886,8 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform RNG operation using StreamExecutor"
                    " without RNG support.";
     }
   }
@@ -4852,7 +4903,8 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform RNG operation using StreamExecutor"
                    " without RNG support.";
     }
   }
@@ -4867,7 +4919,8 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
       CheckError(rng->DoPopulateRandUniform(this, values));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform RNG operation using StreamExecutor"
                    " without RNG support.";
     }
   }
@@ -4883,7 +4936,8 @@ Stream &Stream::ThenPopulateRandUniform(
       CheckError(rng->DoPopulateRandUniform(this, values));
     } else {
      SetError();
-      LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform RNG operation using StreamExecutor"
                    " without RNG support.";
     }
   }
@@ -4899,7 +4953,7 @@ Stream &Stream::ThenPopulateRandUniform(
       CheckError(rng->DoPopulateRandUniform(this, values));
     } else {
       SetError();
-      LOG(INFO) << "stream " << this
+      LOG(INFO) << DebugStreamPointers()
                 << " attempting to perform RNG operation using StreamExecutor"
                    " without RNG support.";
     }
@@ -4914,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
   if (ok()) {
     CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
   } else {
-    LOG(INFO) << "stream " << this
+    LOG(INFO) << DebugStreamPointers()
              << " did not memcpy device-to-host; source: " << gpu_src.opaque();
   }
   return *this;
@@ -4927,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
   if (ok()) {
     CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
   } else {
-    LOG(INFO) << "stream " << this
+    LOG(INFO) << DebugStreamPointers()
              << " did not memcpy host-to-device; source: " << host_src;
   }
   return *this;
@@ -4940,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
   if (ok()) {
     CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
   } else {
-    LOG(INFO) << "stream " << this
+    LOG(INFO) << DebugStreamPointers()
              << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
   }
   return *this;
@@ -4952,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
   if (ok()) {
     CheckError(parent_->MemZero(this, location, size));
   } else {
-    LOG(INFO) << "stream " << this
+    LOG(INFO) << DebugStreamPointers()
              << " did not memzero GPU location; source: " << location;
   }
   return *this;
@@ -4965,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
   if (ok()) {
     CheckError(parent_->Memset32(this, location, pattern, size));
   } else {
-    LOG(INFO) << "stream " << this
+    LOG(INFO) << DebugStreamPointers()
              << " did not memset GPU location; source: " << location
              << "; size: " << size << "; pattern: " << std::hex << pattern;
   }
@@ -5234,7 +5288,7 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
   if (ok()) {
     CheckError(parent_->HostCallback(this, callback));
   } else {
-    LOG(INFO) << "stream " << this
+    LOG(INFO) << DebugStreamPointers()
              << " was in error state before adding host callback";
   }
   return *this;
@@ -5250,7 +5304,8 @@ Stream &Stream::ThenFft(fft::Plan *plan,
       CheckError(fft->DoFft(this, plan, input, output));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform FFT operation using StreamExecutor"
                    " without FFT support";
     }
   }
@@ -5267,7 +5322,8 @@ Stream &Stream::ThenFft(fft::Plan *plan,
      CheckError(fft->DoFft(this, plan, input, output));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform FFT operation using StreamExecutor"
                    " without FFT support";
     }
   }
@@ -5283,7 +5339,8 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
      CheckError(fft->DoFft(this, plan, input, output));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform FFT operation using StreamExecutor"
                    " without FFT support";
     }
   }
@@ -5299,7 +5356,8 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
      CheckError(fft->DoFft(this, plan, input, output));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform FFT operation using StreamExecutor"
                    " without FFT support";
     }
   }
@@ -5316,7 +5374,8 @@ Stream &Stream::ThenFft(fft::Plan *plan,
      CheckError(fft->DoFft(this, plan, input, output));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform FFT operation using StreamExecutor"
                    " without FFT support";
     }
   }
@@ -5333,7 +5392,8 @@ Stream &Stream::ThenFft(fft::Plan *plan,
      CheckError(fft->DoFft(this, plan, input, output));
     } else {
       SetError();
-      LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+      LOG(INFO) << DebugStreamPointers()
+                << " attempting to perform FFT operation using StreamExecutor"
                    " without FFT support";
     }
   }
@@ -5361,7 +5421,7 @@ port::Status Stream::BlockHostUntilDone() {
     port::Status status = port::Status(
        port::error::INTERNAL,
        "stream did not block host until done; was already in an error state");
-    LOG(INFO) << status << " " << this;
+    LOG(INFO) << DebugStreamPointers() << " " << status;
     return status;
   }
@@ -5372,4 +5432,10 @@ port::Status Stream::BlockHostUntilDone() {
   return error;
 }
 
+string Stream::DebugStreamPointers() const {
+  // Relies on the ToVlogString(const void*) overload above.
+  return port::StrCat("[stream=", ToVlogString(this),
+                      ",impl=", ToVlogString(implementation_.get()), "]");
+}
+
 }  // namespace stream_executor
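
The new DebugStreamPointers method is what ties the logging changes above together: every message about a stream now begins with the same "[stream=0x...,impl=0x...]" tag, so one stream's lifecycle can be followed by filtering the log on its pointer. A rough standalone illustration of the tag format only (the real implementation uses port::StrCat and the ToVlogString overloads shown above; the function name here is hypothetical):

#include <sstream>
#include <string>

// Illustrative only: produce a "[stream=0x...,impl=0x...]" tag shaped like the
// one Stream::DebugStreamPointers() returns, using plain ostream formatting.
std::string DebugPointerTag(const void* stream, const void* impl) {
  std::ostringstream oss;
  oss << "[stream=" << stream << ",impl=" << impl << "]";
  return oss.str();
}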

tensorflow/stream_executor/stream.h

@@ -122,10 +122,14 @@ class Stream {
   // Get or create a sub-stream from this stream. If there is any sub-stream in
   // the pool that can be reused then just return this sub-stream. Otherwise
   // create a new sub-stream.
+  //
+  // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
   Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
 
   // Return the sub-stream back to the host stream so that it can be reused
   // later. Sub-streams that are !ok() will not be reused.
+  //
+  // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
   void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
 
   // Allocate temporary memories. The stream will deallocate them when blocked
@@ -2051,6 +2055,9 @@ class Stream {
   // with this stream.
   internal::TemporaryMemoryManager *temporary_memory_manager();
 
+  // Returns a debugging string "[stream=0x...,impl=0x...]".
+  string DebugStreamPointers() const;
+
  private:
   friend class host::HostBlas;  // for parent_.
   friend class host::HostFft;   // for parent_.

tensorflow/stream_executor/stream_test.cc

@@ -95,18 +95,18 @@ TEST_F(StreamTest, TwoSubStreams) {
   EXPECT_NE(sub_stream3, sub_stream4);
 }
 
-TEST_F(StreamTest, FailedSubStreamNotReused) {
+TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) {
   std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
   Stream stream(executor.get());
   stream.Init();
   EXPECT_TRUE(stream.ok());
 
-  // Get a sub-stream.
+  // Get sub_stream1.
   Stream* sub_stream1 = stream.GetOrCreateSubStream();
   EXPECT_TRUE(sub_stream1->ok());
 
-  // Force an error on the stream; here we call a method that requires
-  // DNN support, which we know the Host platform doesn't support.
+  // Force an error on sub_stream1; here we call a method that requires DNN
+  // support, which we know the Host platform doesn't support.
   sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
   EXPECT_FALSE(sub_stream1->ok());
 
@@ -115,20 +115,84 @@ TEST_F(StreamTest, FailedSubStreamNotReused) {
   Stream* sub_stream2 = stream.GetOrCreateSubStream();
   EXPECT_TRUE(sub_stream2->ok());
 
-  // The underlying streams should be different. They would have been
-  // the same, but since we forced an error on sub_stream1, it will
-  // not be re-used. Sadly we can't just check:
+  // The underlying sub_streams should be different. They would have been the
+  // same, but since we forced an error on sub_stream1, it will not be
+  // re-used. Sadly we can't just check:
   //   EXPECT_NE(sub_stream1, sub_stream2);
   //
-  // The above should hold logically, but it may fail if the new
-  // stream instance allocated for sub_stream2 happens to reside in
-  // the same memory address as sub_stream1.
+  // The above should hold logically, but it may fail if the new Stream instance
+  // allocated for sub_stream2 happens to reside in the same memory address as
+  // sub_stream1.
   //
   // The check that sub_stream2->ok() serves as a good-enough check.
 
-  // Return sub_stream2 and get sub_stream3. The previous error on
-  // sub_stream1 has no effect on these streams, and they are the
-  // same.
+  // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+  // has no effect on these streams, and they are the same.
+  stream.ReturnSubStream(sub_stream2);
+  Stream* sub_stream3 = stream.GetOrCreateSubStream();
+  EXPECT_TRUE(sub_stream3->ok());
+  EXPECT_EQ(sub_stream2, sub_stream3);
+}
+
+TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) {
+  std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+  Stream stream(executor.get());
+  stream.Init();
+  EXPECT_TRUE(stream.ok());
+
+  // Get and return sub_stream1.
+  Stream* sub_stream1 = stream.GetOrCreateSubStream();
+  EXPECT_TRUE(sub_stream1->ok());
+  stream.ReturnSubStream(sub_stream1);
+
+  // Force an error on sub_stream1; here we call a method that requires DNN
+  // support, which we know the Host platform doesn't support.
+  //
+  // It is a bit weird to use sub_stream1 after it has already been returned. By
+  // doing this, we're simulating an asynchronous error that occurs during
+  // execution of the sub_stream, that occurs after the sub_stream is returned.
+  //
+  // E.g. the following is a common pattern of usage, where the execution of the
+  // operations enqueued onto the sub streams may occur after the streams have
+  // already been returned.
+  //
+  //   void EnqueueOnSubStreams(Stream* stream) {
+  //     Stream* sub_stream1 = stream.GetOrCreateSubStream();
+  //     Stream* sub_stream2 = stream.GetOrCreateSubStream();
+  //     // ... enqueue some operations on the sub streams ...
+  //     stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2);
+  //     stream.ReturnSubStream(sub_stream1);
+  //     stream.ReturnSubStream(sub_stream2);
+  //   }
+  //
+  //   Stream* main_stream = ...;
+  //   EnqueueOnSubStreams(main_stream);
+  //   main_stream.BlockHostUntilDone();
+  //
+  // TODO(b/112196569): The semantics of failed sub-streams is error-prone;
+  // GetOrCreateSubStream can still return a sub-stream that has not encountered
+  // an error yet, but will encounter one in the future, based on previously
+  // enqueued operations.
+  sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
+  EXPECT_FALSE(sub_stream1->ok());
+
+  // Get and return sub_stream2.
+  Stream* sub_stream2 = stream.GetOrCreateSubStream();
+  EXPECT_TRUE(sub_stream2->ok());
+
+  // The underlying streams should be different. They would have been the same,
+  // but since we forced an error on sub_stream1, it will not be re-used. Sadly
+  // we can't just check:
+  //   EXPECT_NE(sub_stream1, sub_stream2);
+  //
+  // The above should hold logically, but it may fail if the new stream instance
+  // allocated for sub_stream2 happens to reside in the same memory address as
+  // sub_stream1.
+  //
+  // The check that sub_stream2->ok() serves as a good-enough check.
+
+  // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+  // has no effect on these streams, and they are the same.
   stream.ReturnSubStream(sub_stream2);
   Stream* sub_stream3 = stream.GetOrCreateSubStream();
   EXPECT_TRUE(sub_stream3->ok());