Drop failed sub-streams during both Get and Return.
The old code ensured that failed sub-streams would not be re-used, but had two flaws: 1) It only checked for failed sub-streams during Return. 2) It didn't actually remove the failed sub-streams from our state. The new code fixes these two flaws, and adds an extra test that explains why (1) is insufficient. PiperOrigin-RevId: 207333296
This commit is contained in:
parent
5de6d11b0b
commit
9cdcb0397c
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/stream_pool.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -27,6 +28,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
|
||||
// Re-use an existing stream from the pool.
|
||||
stream = std::move(streams_.back());
|
||||
streams_.pop_back();
|
||||
VLOG(1) << stream->DebugStreamPointers()
|
||||
<< " StreamPool reusing existing stream";
|
||||
}
|
||||
}
|
||||
|
||||
@ -34,6 +37,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
|
||||
// Create a new stream.
|
||||
stream = MakeUnique<se::Stream>(executor);
|
||||
stream->Init();
|
||||
VLOG(1) << stream->DebugStreamPointers()
|
||||
<< " StreamPool created new stream";
|
||||
}
|
||||
|
||||
// Return the stream wrapped in Ptr, which has our special deleter semantics.
|
||||
@ -43,12 +48,16 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
|
||||
|
||||
void StreamPool::ReturnStream(se::Stream* stream) {
|
||||
if (stream->ok()) {
|
||||
VLOG(1) << stream->DebugStreamPointers()
|
||||
<< " StreamPool returning ok stream";
|
||||
tensorflow::mutex_lock lock(mu_);
|
||||
streams_.emplace_back(stream);
|
||||
} else {
|
||||
// If the stream has encountered any errors, all subsequent
|
||||
// operations on it will fail. So just delete the stream, and rely
|
||||
// on new streams to be created in the future.
|
||||
// If the stream has encountered any errors, all subsequent operations on it
|
||||
// will fail. So just delete the stream, and rely on new streams to be
|
||||
// created in the future.
|
||||
VLOG(1) << stream->DebugStreamPointers()
|
||||
<< " StreamPool deleting !ok stream";
|
||||
delete stream;
|
||||
}
|
||||
}
|
||||
|
@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) {
|
||||
}
|
||||
|
||||
string ToVlogString(const DeviceMemoryBase *memory) {
|
||||
return ToVlogString(*memory);
|
||||
return memory == nullptr ? "null" : ToVlogString(*memory);
|
||||
}
|
||||
|
||||
string ToVlogString(const Eigen::half &h) {
|
||||
@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream,
|
||||
// constructing all the strings in params is expensive.
|
||||
CHECK(VLOG_IS_ON(1));
|
||||
|
||||
string str = port::StrCat("Called Stream::", function_name, "(");
|
||||
string str = port::StrCat(stream->DebugStreamPointers(),
|
||||
" Called Stream::", function_name, "(");
|
||||
const char *separator = "";
|
||||
for (const auto ¶m : params) {
|
||||
port::StrAppend(&str, separator, param.first, "=", param.second);
|
||||
separator = ", ";
|
||||
}
|
||||
port::StrAppend(&str, ") stream=", ToVlogString(stream));
|
||||
port::StrAppend(&str, ")");
|
||||
if (VLOG_IS_ON(10)) {
|
||||
port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
|
||||
}
|
||||
@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer(
|
||||
|
||||
Stream *Stream::GetOrCreateSubStream() {
|
||||
mutex_lock lock(mu_);
|
||||
for (auto &stream : sub_streams_) {
|
||||
if (stream.second) {
|
||||
stream.second = false;
|
||||
return stream.first.get();
|
||||
|
||||
// Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
|
||||
// we encounter along the way.
|
||||
for (int64 index = 0; index < sub_streams_.size();) {
|
||||
std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
|
||||
if (pair.second) {
|
||||
// The sub_stream is reusable.
|
||||
Stream *sub_stream = pair.first.get();
|
||||
if (sub_stream->ok()) {
|
||||
VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
|
||||
<< sub_stream->DebugStreamPointers();
|
||||
pair.second = false;
|
||||
return sub_stream;
|
||||
}
|
||||
|
||||
// The stream is reusable and not ok. Streams have a monotonic state
|
||||
// machine; the stream will remain in !ok forever. Swap it with the last
|
||||
// stream and pop it off.
|
||||
const int64 last = sub_streams_.size() - 1;
|
||||
if (index != last) {
|
||||
std::swap(pair, sub_streams_[last]);
|
||||
}
|
||||
sub_streams_.pop_back();
|
||||
VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
|
||||
<< sub_stream->DebugStreamPointers();
|
||||
} else {
|
||||
// The sub_stream is not reusable, move on to the next one.
|
||||
++index;
|
||||
}
|
||||
}
|
||||
|
||||
// No streams are reusable; create a new stream.
|
||||
sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
|
||||
false);
|
||||
Stream *sub_stream = sub_streams_.back().first.get();
|
||||
sub_stream->Init();
|
||||
CHECK(ok_) << "sub-stream failed to be initialized";
|
||||
VLOG(1) << DebugStreamPointers() << " created new sub_stream "
|
||||
<< sub_stream->DebugStreamPointers();
|
||||
|
||||
return sub_stream;
|
||||
}
|
||||
|
||||
void Stream::ReturnSubStream(Stream *sub_stream) {
|
||||
mutex_lock lock(mu_);
|
||||
for (auto &stream : sub_streams_) {
|
||||
if (stream.first.get() == sub_stream) {
|
||||
// Streams have a monotonic state machine; if a stream
|
||||
// encounters an error, it will remain in an error state
|
||||
// forever. Only allow re-use of ok streams.
|
||||
//
|
||||
// TODO(toddw): Improve this mechanism, if necessary, to drop
|
||||
// failed streams completely.
|
||||
const bool ready_to_reuse = sub_stream->ok();
|
||||
stream.second = ready_to_reuse;
|
||||
return;
|
||||
|
||||
// Look for the sub-stream.
|
||||
for (int64 index = 0; index < sub_streams_.size(); ++index) {
|
||||
std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
|
||||
if (pair.first.get() != sub_stream) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Found the sub_stream.
|
||||
if (sub_stream->ok()) {
|
||||
VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
|
||||
<< sub_stream->DebugStreamPointers();
|
||||
pair.second = true;
|
||||
} else {
|
||||
// The returned stream is not ok. Streams have a monotonic state
|
||||
// machine; the stream will remain in !ok forever. Swap it with the last
|
||||
// stream and pop it off.
|
||||
VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
|
||||
<< sub_stream->DebugStreamPointers();
|
||||
const int64 last = sub_streams_.size() - 1;
|
||||
if (index != last) {
|
||||
std::swap(pair, sub_streams_[last]);
|
||||
}
|
||||
sub_streams_.pop_back();
|
||||
}
|
||||
return;
|
||||
}
|
||||
LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
|
||||
|
||||
LOG(FATAL) << DebugStreamPointers()
|
||||
<< " did not create the returned sub-stream "
|
||||
<< sub_stream->DebugStreamPointers();
|
||||
}
|
||||
|
||||
Stream &Stream::ThenStartTimer(Timer *t) {
|
||||
@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) {
|
||||
if (ok()) {
|
||||
CheckError(parent_->StartTimer(this, t));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not enqueue 'start timer': " << t;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) {
|
||||
if (ok()) {
|
||||
CheckError(parent_->StopTimer(this, t));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not enqueue 'stop timer': " << t;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) {
|
||||
CheckError(parent_->CreateStreamDependency(this, other));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
|
||||
LOG(INFO) << DebugStreamPointers() << " did not wait for "
|
||||
<< other->DebugStreamPointers();
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) {
|
||||
<< "at fault. Monitor for further errors.";
|
||||
}
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this << " did not wait for an event.";
|
||||
LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@ -4802,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
|
||||
CheckError(rng->SetSeed(this, seed, seed_bytes));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "stream " << this << " unable to initialize RNG";
|
||||
LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
|
||||
}
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not set RNG seed: " << static_cast<const void *>(seed)
|
||||
<< "; bytes: " << seed_bytes;
|
||||
}
|
||||
@ -4820,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
|
||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
||||
"without RNG support.";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform RNG operation using StreamExecutor"
|
||||
" without RNG support.";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -4836,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
|
||||
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
||||
"without RNG support.";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform RNG operation using StreamExecutor"
|
||||
" without RNG support.";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -4852,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
|
||||
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
||||
"without RNG support.";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform RNG operation using StreamExecutor"
|
||||
" without RNG support.";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -4867,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
|
||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
||||
"without RNG support.";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform RNG operation using StreamExecutor"
|
||||
" without RNG support.";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -4883,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform(
|
||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
||||
"without RNG support.";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform RNG operation using StreamExecutor"
|
||||
" without RNG support.";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -4899,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform(
|
||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "stream " << this
|
||||
<< " attempting to perform RNG operation using StreamExecutor "
|
||||
"without RNG support.";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform RNG operation using StreamExecutor"
|
||||
" without RNG support.";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -4914,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
|
||||
if (ok()) {
|
||||
CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not memcpy device-to-host; source: " << gpu_src.opaque();
|
||||
}
|
||||
return *this;
|
||||
@ -4927,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
|
||||
if (ok()) {
|
||||
CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not memcpy host-to-device; source: " << host_src;
|
||||
}
|
||||
return *this;
|
||||
@ -4940,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
|
||||
if (ok()) {
|
||||
CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not memcpy gpu-to-gpu; source: " << &gpu_src;
|
||||
}
|
||||
return *this;
|
||||
@ -4952,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
|
||||
if (ok()) {
|
||||
CheckError(parent_->MemZero(this, location, size));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not memzero GPU location; source: " << location;
|
||||
}
|
||||
return *this;
|
||||
@ -4965,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
|
||||
if (ok()) {
|
||||
CheckError(parent_->Memset32(this, location, pattern, size));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " did not memset GPU location; source: " << location
|
||||
<< "; size: " << size << "; pattern: " << std::hex << pattern;
|
||||
}
|
||||
@ -5234,7 +5288,7 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
|
||||
if (ok()) {
|
||||
CheckError(parent_->HostCallback(this, callback));
|
||||
} else {
|
||||
LOG(INFO) << "stream " << this
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " was in error state before adding host callback";
|
||||
}
|
||||
return *this;
|
||||
@ -5250,8 +5304,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
||||
CheckError(fft->DoFft(this, plan, input, output));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
||||
"without FFT support";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform FFT operation using StreamExecutor"
|
||||
" without FFT support";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -5267,8 +5322,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
||||
CheckError(fft->DoFft(this, plan, input, output));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
||||
"without FFT support";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform FFT operation using StreamExecutor"
|
||||
" without FFT support";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -5283,8 +5339,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
|
||||
CheckError(fft->DoFft(this, plan, input, output));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
||||
"without FFT support";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform FFT operation using StreamExecutor"
|
||||
" without FFT support";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -5299,8 +5356,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
|
||||
CheckError(fft->DoFft(this, plan, input, output));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
||||
"without FFT support";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform FFT operation using StreamExecutor"
|
||||
" without FFT support";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -5316,8 +5374,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
||||
CheckError(fft->DoFft(this, plan, input, output));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
||||
"without FFT support";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform FFT operation using StreamExecutor"
|
||||
" without FFT support";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -5333,8 +5392,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
||||
CheckError(fft->DoFft(this, plan, input, output));
|
||||
} else {
|
||||
SetError();
|
||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
||||
"without FFT support";
|
||||
LOG(INFO) << DebugStreamPointers()
|
||||
<< " attempting to perform FFT operation using StreamExecutor"
|
||||
" without FFT support";
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
@ -5361,7 +5421,7 @@ port::Status Stream::BlockHostUntilDone() {
|
||||
port::Status status = port::Status(
|
||||
port::error::INTERNAL,
|
||||
"stream did not block host until done; was already in an error state");
|
||||
LOG(INFO) << status << " " << this;
|
||||
LOG(INFO) << DebugStreamPointers() << " " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
@ -5372,4 +5432,10 @@ port::Status Stream::BlockHostUntilDone() {
|
||||
return error;
|
||||
}
|
||||
|
||||
string Stream::DebugStreamPointers() const {
|
||||
// Relies on the ToVlogString(const void*) overload above.
|
||||
return port::StrCat("[stream=", ToVlogString(this),
|
||||
",impl=", ToVlogString(implementation_.get()), "]");
|
||||
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
@ -122,10 +122,14 @@ class Stream {
|
||||
// Get or create a sub-stream from this stream. If there is any sub-stream in
|
||||
// the pool that can be reused then just return this sub-stream. Otherwise
|
||||
// create a new sub-stream.
|
||||
//
|
||||
// TODO(b/112196569): The semantics of failed sub-streams is error-prone.
|
||||
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Return the sub-stream back to the host stream so that it can be reused
|
||||
// later. Sub-streams that are !ok() will not be reused.
|
||||
//
|
||||
// TODO(b/112196569): The semantics of failed sub-streams is error-prone.
|
||||
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Allocate temporary memories. The stream will deallocate them when blocked
|
||||
@ -2051,6 +2055,9 @@ class Stream {
|
||||
// with this stream.
|
||||
internal::TemporaryMemoryManager *temporary_memory_manager();
|
||||
|
||||
// Returns a debugging string "[stream=0x...,impl=0x...]".
|
||||
string DebugStreamPointers() const;
|
||||
|
||||
private:
|
||||
friend class host::HostBlas; // for parent_.
|
||||
friend class host::HostFft; // for parent_.
|
||||
|
@ -95,18 +95,18 @@ TEST_F(StreamTest, TwoSubStreams) {
|
||||
EXPECT_NE(sub_stream3, sub_stream4);
|
||||
}
|
||||
|
||||
TEST_F(StreamTest, FailedSubStreamNotReused) {
|
||||
TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) {
|
||||
std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
|
||||
Stream stream(executor.get());
|
||||
stream.Init();
|
||||
EXPECT_TRUE(stream.ok());
|
||||
|
||||
// Get a sub-stream.
|
||||
// Get sub_stream1.
|
||||
Stream* sub_stream1 = stream.GetOrCreateSubStream();
|
||||
EXPECT_TRUE(sub_stream1->ok());
|
||||
|
||||
// Force an error on the stream; here we call a method that requires
|
||||
// DNN support, which we know the Host platform doesn't support.
|
||||
// Force an error on sub_stream1; here we call a method that requires DNN
|
||||
// support, which we know the Host platform doesn't support.
|
||||
sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
|
||||
EXPECT_FALSE(sub_stream1->ok());
|
||||
|
||||
@ -115,20 +115,84 @@ TEST_F(StreamTest, FailedSubStreamNotReused) {
|
||||
Stream* sub_stream2 = stream.GetOrCreateSubStream();
|
||||
EXPECT_TRUE(sub_stream2->ok());
|
||||
|
||||
// The underlying streams should be different. They would have been
|
||||
// the same, but since we forced an error on sub_stream1, it will
|
||||
// not be re-used. Sadly we can't just check:
|
||||
// The underlying sub_streams should be different. They would have been the
|
||||
// same, but since we forced an error on sub_stream1, it will not be
|
||||
// re-used. Sadly we can't just check:
|
||||
// EXPECT_NE(sub_stream1, sub_stream2);
|
||||
//
|
||||
// The above should hold logically, but it may fail if the new
|
||||
// stream instance allocated for sub_stream2 happens to reside in
|
||||
// the same memory address as sub_stream1.
|
||||
// The above should hold logically, but it may fail if the new Stream instance
|
||||
// allocated for sub_stream2 happens to reside in the same memory address as
|
||||
// sub_stream1.
|
||||
//
|
||||
// The check that sub_stream2->ok() serves as a good-enough check.
|
||||
|
||||
// Return sub_stream2 and get sub_stream3. The previous error on
|
||||
// sub_stream1 has no effect on these streams, and they are the
|
||||
// same.
|
||||
// Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
|
||||
// has no effect on these streams, and they are the same.
|
||||
stream.ReturnSubStream(sub_stream2);
|
||||
Stream* sub_stream3 = stream.GetOrCreateSubStream();
|
||||
EXPECT_TRUE(sub_stream3->ok());
|
||||
EXPECT_EQ(sub_stream2, sub_stream3);
|
||||
}
|
||||
|
||||
TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) {
|
||||
std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
|
||||
Stream stream(executor.get());
|
||||
stream.Init();
|
||||
EXPECT_TRUE(stream.ok());
|
||||
|
||||
// Get and return sub_stream1.
|
||||
Stream* sub_stream1 = stream.GetOrCreateSubStream();
|
||||
EXPECT_TRUE(sub_stream1->ok());
|
||||
stream.ReturnSubStream(sub_stream1);
|
||||
|
||||
// Force an error on sub_stream1; here we call a method that requires DNN
|
||||
// support, which we know the Host platform doesn't support.
|
||||
//
|
||||
// It is a bit weird to use sub_stream1 after it has already been returned. By
|
||||
// doing this, we're simulating an asynchronous error that occurs during
|
||||
// execution of the sub_stream, that occurs after the sub_stream is returned.
|
||||
//
|
||||
// E.g. the following is a common pattern of usage, where the execution of the
|
||||
// operations enqueued onto the sub streams may occur after the streams have
|
||||
// already been returned.
|
||||
//
|
||||
// void EnqueueOnSubStreams(Stream* stream) {
|
||||
// Stream* sub_stream1 = stream.GetOrCreateSubStream();
|
||||
// Stream* sub_stream2 = stream.GetOrCreateSubStream();
|
||||
// // ... enqueue some operations on the sub streams ...
|
||||
// stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2);
|
||||
// stream.ReturnSubStream(sub_stream1);
|
||||
// stream.ReturnSubStream(sub_stream2);
|
||||
// }
|
||||
//
|
||||
// Stream* main_stream = ...;
|
||||
// EnqueueOnSubStreams(main_stream);
|
||||
// main_stream.BlockHostUntilDone();
|
||||
//
|
||||
// TODO(b/112196569): The semantics of failed sub-streams is error-prone;
|
||||
// GetOrCreateSubStream can still return a sub-stream that has not encountered
|
||||
// an error yet, but will encounter one in the future, based on previously
|
||||
// enqueued operations.
|
||||
sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
|
||||
EXPECT_FALSE(sub_stream1->ok());
|
||||
|
||||
// Get and return sub_stream2.
|
||||
Stream* sub_stream2 = stream.GetOrCreateSubStream();
|
||||
EXPECT_TRUE(sub_stream2->ok());
|
||||
|
||||
// The underlying streams should be different. They would have been the same,
|
||||
// but since we forced an error on sub_stream1, it will not be re-used. Sadly
|
||||
// we can't just check:
|
||||
// EXPECT_NE(sub_stream1, sub_stream2);
|
||||
//
|
||||
// The above should hold logically, but it may fail if the new stream instance
|
||||
// allocated for sub_stream2 happens to reside in the same memory address as
|
||||
// sub_stream1.
|
||||
//
|
||||
// The check that sub_stream2->ok() serves as a good-enough check.
|
||||
|
||||
// Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
|
||||
// has no effect on these streams, and they are the same.
|
||||
stream.ReturnSubStream(sub_stream2);
|
||||
Stream* sub_stream3 = stream.GetOrCreateSubStream();
|
||||
EXPECT_TRUE(sub_stream3->ok());
|
||||
|
Loading…
Reference in New Issue
Block a user