Drop failed sub-streams during both Get and Return.
The old code ensured that failed sub-streams would not be re-used, but had two flaws:
1) It only checked for failed sub-streams during Return.
2) It didn't actually remove the failed sub-streams from our state.
The new code fixes these two flaws, and adds an extra test that explains why checking only during Return (flaw 1) is insufficient.
PiperOrigin-RevId: 207333296
This commit is contained in:
parent
5de6d11b0b
commit
9cdcb0397c
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/stream_pool.h"
|
#include "tensorflow/compiler/xla/service/stream_pool.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -27,6 +28,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
|
|||||||
// Re-use an existing stream from the pool.
|
// Re-use an existing stream from the pool.
|
||||||
stream = std::move(streams_.back());
|
stream = std::move(streams_.back());
|
||||||
streams_.pop_back();
|
streams_.pop_back();
|
||||||
|
VLOG(1) << stream->DebugStreamPointers()
|
||||||
|
<< " StreamPool reusing existing stream";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,6 +37,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
|
|||||||
// Create a new stream.
|
// Create a new stream.
|
||||||
stream = MakeUnique<se::Stream>(executor);
|
stream = MakeUnique<se::Stream>(executor);
|
||||||
stream->Init();
|
stream->Init();
|
||||||
|
VLOG(1) << stream->DebugStreamPointers()
|
||||||
|
<< " StreamPool created new stream";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the stream wrapped in Ptr, which has our special deleter semantics.
|
// Return the stream wrapped in Ptr, which has our special deleter semantics.
|
||||||
@ -43,12 +48,16 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
|
|||||||
|
|
||||||
void StreamPool::ReturnStream(se::Stream* stream) {
|
void StreamPool::ReturnStream(se::Stream* stream) {
|
||||||
if (stream->ok()) {
|
if (stream->ok()) {
|
||||||
|
VLOG(1) << stream->DebugStreamPointers()
|
||||||
|
<< " StreamPool returning ok stream";
|
||||||
tensorflow::mutex_lock lock(mu_);
|
tensorflow::mutex_lock lock(mu_);
|
||||||
streams_.emplace_back(stream);
|
streams_.emplace_back(stream);
|
||||||
} else {
|
} else {
|
||||||
// If the stream has encountered any errors, all subsequent
|
// If the stream has encountered any errors, all subsequent operations on it
|
||||||
// operations on it will fail. So just delete the stream, and rely
|
// will fail. So just delete the stream, and rely on new streams to be
|
||||||
// on new streams to be created in the future.
|
// created in the future.
|
||||||
|
VLOG(1) << stream->DebugStreamPointers()
|
||||||
|
<< " StreamPool deleting !ok stream";
|
||||||
delete stream;
|
delete stream;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const DeviceMemoryBase *memory) {
|
string ToVlogString(const DeviceMemoryBase *memory) {
|
||||||
return ToVlogString(*memory);
|
return memory == nullptr ? "null" : ToVlogString(*memory);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const Eigen::half &h) {
|
string ToVlogString(const Eigen::half &h) {
|
||||||
@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream,
|
|||||||
// constructing all the strings in params is expensive.
|
// constructing all the strings in params is expensive.
|
||||||
CHECK(VLOG_IS_ON(1));
|
CHECK(VLOG_IS_ON(1));
|
||||||
|
|
||||||
string str = port::StrCat("Called Stream::", function_name, "(");
|
string str = port::StrCat(stream->DebugStreamPointers(),
|
||||||
|
" Called Stream::", function_name, "(");
|
||||||
const char *separator = "";
|
const char *separator = "";
|
||||||
for (const auto &param : params) {
|
for (const auto &param : params) {
|
||||||
port::StrAppend(&str, separator, param.first, "=", param.second);
|
port::StrAppend(&str, separator, param.first, "=", param.second);
|
||||||
separator = ", ";
|
separator = ", ";
|
||||||
}
|
}
|
||||||
port::StrAppend(&str, ") stream=", ToVlogString(stream));
|
port::StrAppend(&str, ")");
|
||||||
if (VLOG_IS_ON(10)) {
|
if (VLOG_IS_ON(10)) {
|
||||||
port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
|
port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
|
||||||
}
|
}
|
||||||
@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer(
|
|||||||
|
|
||||||
Stream *Stream::GetOrCreateSubStream() {
|
Stream *Stream::GetOrCreateSubStream() {
|
||||||
mutex_lock lock(mu_);
|
mutex_lock lock(mu_);
|
||||||
for (auto &stream : sub_streams_) {
|
|
||||||
if (stream.second) {
|
// Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
|
||||||
stream.second = false;
|
// we encounter along the way.
|
||||||
return stream.first.get();
|
for (int64 index = 0; index < sub_streams_.size();) {
|
||||||
|
std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
|
||||||
|
if (pair.second) {
|
||||||
|
// The sub_stream is reusable.
|
||||||
|
Stream *sub_stream = pair.first.get();
|
||||||
|
if (sub_stream->ok()) {
|
||||||
|
VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
|
||||||
|
<< sub_stream->DebugStreamPointers();
|
||||||
|
pair.second = false;
|
||||||
|
return sub_stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The stream is reusable and not ok. Streams have a monotonic state
|
||||||
|
// machine; the stream will remain in !ok forever. Swap it with the last
|
||||||
|
// stream and pop it off.
|
||||||
|
const int64 last = sub_streams_.size() - 1;
|
||||||
|
if (index != last) {
|
||||||
|
std::swap(pair, sub_streams_[last]);
|
||||||
|
}
|
||||||
|
sub_streams_.pop_back();
|
||||||
|
VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
|
||||||
|
<< sub_stream->DebugStreamPointers();
|
||||||
|
} else {
|
||||||
|
// The sub_stream is not reusable, move on to the next one.
|
||||||
|
++index;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// No streams are reusable; create a new stream.
|
||||||
sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
|
sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
|
||||||
false);
|
false);
|
||||||
Stream *sub_stream = sub_streams_.back().first.get();
|
Stream *sub_stream = sub_streams_.back().first.get();
|
||||||
sub_stream->Init();
|
sub_stream->Init();
|
||||||
CHECK(ok_) << "sub-stream failed to be initialized";
|
CHECK(ok_) << "sub-stream failed to be initialized";
|
||||||
|
VLOG(1) << DebugStreamPointers() << " created new sub_stream "
|
||||||
|
<< sub_stream->DebugStreamPointers();
|
||||||
|
|
||||||
return sub_stream;
|
return sub_stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Stream::ReturnSubStream(Stream *sub_stream) {
|
void Stream::ReturnSubStream(Stream *sub_stream) {
|
||||||
mutex_lock lock(mu_);
|
mutex_lock lock(mu_);
|
||||||
for (auto &stream : sub_streams_) {
|
|
||||||
if (stream.first.get() == sub_stream) {
|
// Look for the sub-stream.
|
||||||
// Streams have a monotonic state machine; if a stream
|
for (int64 index = 0; index < sub_streams_.size(); ++index) {
|
||||||
// encounters an error, it will remain in an error state
|
std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
|
||||||
// forever. Only allow re-use of ok streams.
|
if (pair.first.get() != sub_stream) {
|
||||||
//
|
continue;
|
||||||
// TODO(toddw): Improve this mechanism, if necessary, to drop
|
|
||||||
// failed streams completely.
|
|
||||||
const bool ready_to_reuse = sub_stream->ok();
|
|
||||||
stream.second = ready_to_reuse;
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Found the sub_stream.
|
||||||
|
if (sub_stream->ok()) {
|
||||||
|
VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
|
||||||
|
<< sub_stream->DebugStreamPointers();
|
||||||
|
pair.second = true;
|
||||||
|
} else {
|
||||||
|
// The returned stream is not ok. Streams have a monotonic state
|
||||||
|
// machine; the stream will remain in !ok forever. Swap it with the last
|
||||||
|
// stream and pop it off.
|
||||||
|
VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
|
||||||
|
<< sub_stream->DebugStreamPointers();
|
||||||
|
const int64 last = sub_streams_.size() - 1;
|
||||||
|
if (index != last) {
|
||||||
|
std::swap(pair, sub_streams_[last]);
|
||||||
|
}
|
||||||
|
sub_streams_.pop_back();
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
|
|
||||||
|
LOG(FATAL) << DebugStreamPointers()
|
||||||
|
<< " did not create the returned sub-stream "
|
||||||
|
<< sub_stream->DebugStreamPointers();
|
||||||
}
|
}
|
||||||
|
|
||||||
Stream &Stream::ThenStartTimer(Timer *t) {
|
Stream &Stream::ThenStartTimer(Timer *t) {
|
||||||
@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) {
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->StartTimer(this, t));
|
CheckError(parent_->StartTimer(this, t));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
|
LOG(INFO) << DebugStreamPointers()
|
||||||
|
<< " did not enqueue 'start timer': " << t;
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) {
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->StopTimer(this, t));
|
CheckError(parent_->StopTimer(this, t));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
|
LOG(INFO) << DebugStreamPointers()
|
||||||
|
<< " did not enqueue 'stop timer': " << t;
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) {
|
|||||||
CheckError(parent_->CreateStreamDependency(this, other));
|
CheckError(parent_->CreateStreamDependency(this, other));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
|
LOG(INFO) << DebugStreamPointers() << " did not wait for "
|
||||||
|
<< other->DebugStreamPointers();
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) {
|
|||||||
<< "at fault. Monitor for further errors.";
|
<< "at fault. Monitor for further errors.";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this << " did not wait for an event.";
|
LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -4802,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
|
|||||||
CheckError(rng->SetSeed(this, seed, seed_bytes));
|
CheckError(rng->SetSeed(this, seed, seed_bytes));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "stream " << this << " unable to initialize RNG";
|
LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " did not set RNG seed: " << static_cast<const void *>(seed)
|
<< " did not set RNG seed: " << static_cast<const void *>(seed)
|
||||||
<< "; bytes: " << seed_bytes;
|
<< "; bytes: " << seed_bytes;
|
||||||
}
|
}
|
||||||
@ -4820,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
|
|||||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without RNG support.";
|
<< " attempting to perform RNG operation using StreamExecutor"
|
||||||
|
" without RNG support.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4836,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
|
|||||||
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
|
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without RNG support.";
|
<< " attempting to perform RNG operation using StreamExecutor"
|
||||||
|
" without RNG support.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4852,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
|
|||||||
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
|
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without RNG support.";
|
<< " attempting to perform RNG operation using StreamExecutor"
|
||||||
|
" without RNG support.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4867,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
|
|||||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without RNG support.";
|
<< " attempting to perform RNG operation using StreamExecutor"
|
||||||
|
" without RNG support.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4883,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform(
|
|||||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without RNG support.";
|
<< " attempting to perform RNG operation using StreamExecutor"
|
||||||
|
" without RNG support.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4899,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform(
|
|||||||
CheckError(rng->DoPopulateRandUniform(this, values));
|
CheckError(rng->DoPopulateRandUniform(this, values));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " attempting to perform RNG operation using StreamExecutor "
|
<< " attempting to perform RNG operation using StreamExecutor"
|
||||||
"without RNG support.";
|
" without RNG support.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4914,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
|
CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " did not memcpy device-to-host; source: " << gpu_src.opaque();
|
<< " did not memcpy device-to-host; source: " << gpu_src.opaque();
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4927,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
|
CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " did not memcpy host-to-device; source: " << host_src;
|
<< " did not memcpy host-to-device; source: " << host_src;
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4940,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
|
CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " did not memcpy gpu-to-gpu; source: " << &gpu_src;
|
<< " did not memcpy gpu-to-gpu; source: " << &gpu_src;
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4952,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->MemZero(this, location, size));
|
CheckError(parent_->MemZero(this, location, size));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " did not memzero GPU location; source: " << location;
|
<< " did not memzero GPU location; source: " << location;
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -4965,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->Memset32(this, location, pattern, size));
|
CheckError(parent_->Memset32(this, location, pattern, size));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " did not memset GPU location; source: " << location
|
<< " did not memset GPU location; source: " << location
|
||||||
<< "; size: " << size << "; pattern: " << std::hex << pattern;
|
<< "; size: " << size << "; pattern: " << std::hex << pattern;
|
||||||
}
|
}
|
||||||
@ -5234,7 +5288,7 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
|
|||||||
if (ok()) {
|
if (ok()) {
|
||||||
CheckError(parent_->HostCallback(this, callback));
|
CheckError(parent_->HostCallback(this, callback));
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "stream " << this
|
LOG(INFO) << DebugStreamPointers()
|
||||||
<< " was in error state before adding host callback";
|
<< " was in error state before adding host callback";
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -5250,8 +5304,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
|||||||
CheckError(fft->DoFft(this, plan, input, output));
|
CheckError(fft->DoFft(this, plan, input, output));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without FFT support";
|
<< " attempting to perform FFT operation using StreamExecutor"
|
||||||
|
" without FFT support";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -5267,8 +5322,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
|||||||
CheckError(fft->DoFft(this, plan, input, output));
|
CheckError(fft->DoFft(this, plan, input, output));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without FFT support";
|
<< " attempting to perform FFT operation using StreamExecutor"
|
||||||
|
" without FFT support";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -5283,8 +5339,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
|
|||||||
CheckError(fft->DoFft(this, plan, input, output));
|
CheckError(fft->DoFft(this, plan, input, output));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without FFT support";
|
<< " attempting to perform FFT operation using StreamExecutor"
|
||||||
|
" without FFT support";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -5299,8 +5356,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
|
|||||||
CheckError(fft->DoFft(this, plan, input, output));
|
CheckError(fft->DoFft(this, plan, input, output));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without FFT support";
|
<< " attempting to perform FFT operation using StreamExecutor"
|
||||||
|
" without FFT support";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -5316,8 +5374,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
|||||||
CheckError(fft->DoFft(this, plan, input, output));
|
CheckError(fft->DoFft(this, plan, input, output));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without FFT support";
|
<< " attempting to perform FFT operation using StreamExecutor"
|
||||||
|
" without FFT support";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -5333,8 +5392,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
|
|||||||
CheckError(fft->DoFft(this, plan, input, output));
|
CheckError(fft->DoFft(this, plan, input, output));
|
||||||
} else {
|
} else {
|
||||||
SetError();
|
SetError();
|
||||||
LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
|
LOG(INFO) << DebugStreamPointers()
|
||||||
"without FFT support";
|
<< " attempting to perform FFT operation using StreamExecutor"
|
||||||
|
" without FFT support";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
@ -5361,7 +5421,7 @@ port::Status Stream::BlockHostUntilDone() {
|
|||||||
port::Status status = port::Status(
|
port::Status status = port::Status(
|
||||||
port::error::INTERNAL,
|
port::error::INTERNAL,
|
||||||
"stream did not block host until done; was already in an error state");
|
"stream did not block host until done; was already in an error state");
|
||||||
LOG(INFO) << status << " " << this;
|
LOG(INFO) << DebugStreamPointers() << " " << status;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5372,4 +5432,10 @@ port::Status Stream::BlockHostUntilDone() {
|
|||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string Stream::DebugStreamPointers() const {
|
||||||
|
// Relies on the ToVlogString(const void*) overload above.
|
||||||
|
return port::StrCat("[stream=", ToVlogString(this),
|
||||||
|
",impl=", ToVlogString(implementation_.get()), "]");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -122,10 +122,14 @@ class Stream {
|
|||||||
// Get or create a sub-stream from this stream. If there is any sub-stream in
|
// Get or create a sub-stream from this stream. If there is any sub-stream in
|
||||||
// the pool that can be reused then just return this sub-stream. Otherwise
|
// the pool that can be reused then just return this sub-stream. Otherwise
|
||||||
// create a new sub-stream.
|
// create a new sub-stream.
|
||||||
|
//
|
||||||
|
// TODO(b/112196569): The semantics of failed sub-streams is error-prone.
|
||||||
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
|
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// Return the sub-stream back to the host stream so that it can be reused
|
// Return the sub-stream back to the host stream so that it can be reused
|
||||||
// later. Sub-streams that are !ok() will not be reused.
|
// later. Sub-streams that are !ok() will not be reused.
|
||||||
|
//
|
||||||
|
// TODO(b/112196569): The semantics of failed sub-streams is error-prone.
|
||||||
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
|
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// Allocate temporary memories. The stream will deallocate them when blocked
|
// Allocate temporary memories. The stream will deallocate them when blocked
|
||||||
@ -2051,6 +2055,9 @@ class Stream {
|
|||||||
// with this stream.
|
// with this stream.
|
||||||
internal::TemporaryMemoryManager *temporary_memory_manager();
|
internal::TemporaryMemoryManager *temporary_memory_manager();
|
||||||
|
|
||||||
|
// Returns a debugging string "[stream=0x...,impl=0x...]".
|
||||||
|
string DebugStreamPointers() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class host::HostBlas; // for parent_.
|
friend class host::HostBlas; // for parent_.
|
||||||
friend class host::HostFft; // for parent_.
|
friend class host::HostFft; // for parent_.
|
||||||
|
@ -95,18 +95,18 @@ TEST_F(StreamTest, TwoSubStreams) {
|
|||||||
EXPECT_NE(sub_stream3, sub_stream4);
|
EXPECT_NE(sub_stream3, sub_stream4);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(StreamTest, FailedSubStreamNotReused) {
|
TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) {
|
||||||
std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
|
std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
|
||||||
Stream stream(executor.get());
|
Stream stream(executor.get());
|
||||||
stream.Init();
|
stream.Init();
|
||||||
EXPECT_TRUE(stream.ok());
|
EXPECT_TRUE(stream.ok());
|
||||||
|
|
||||||
// Get a sub-stream.
|
// Get sub_stream1.
|
||||||
Stream* sub_stream1 = stream.GetOrCreateSubStream();
|
Stream* sub_stream1 = stream.GetOrCreateSubStream();
|
||||||
EXPECT_TRUE(sub_stream1->ok());
|
EXPECT_TRUE(sub_stream1->ok());
|
||||||
|
|
||||||
// Force an error on the stream; here we call a method that requires
|
// Force an error on sub_stream1; here we call a method that requires DNN
|
||||||
// DNN support, which we know the Host platform doesn't support.
|
// support, which we know the Host platform doesn't support.
|
||||||
sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
|
sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
|
||||||
EXPECT_FALSE(sub_stream1->ok());
|
EXPECT_FALSE(sub_stream1->ok());
|
||||||
|
|
||||||
@ -115,20 +115,84 @@ TEST_F(StreamTest, FailedSubStreamNotReused) {
|
|||||||
Stream* sub_stream2 = stream.GetOrCreateSubStream();
|
Stream* sub_stream2 = stream.GetOrCreateSubStream();
|
||||||
EXPECT_TRUE(sub_stream2->ok());
|
EXPECT_TRUE(sub_stream2->ok());
|
||||||
|
|
||||||
// The underlying streams should be different. They would have been
|
// The underlying sub_streams should be different. They would have been the
|
||||||
// the same, but since we forced an error on sub_stream1, it will
|
// same, but since we forced an error on sub_stream1, it will not be
|
||||||
// not be re-used. Sadly we can't just check:
|
// re-used. Sadly we can't just check:
|
||||||
// EXPECT_NE(sub_stream1, sub_stream2);
|
// EXPECT_NE(sub_stream1, sub_stream2);
|
||||||
//
|
//
|
||||||
// The above should hold logically, but it may fail if the new
|
// The above should hold logically, but it may fail if the new Stream instance
|
||||||
// stream instance allocated for sub_stream2 happens to reside in
|
// allocated for sub_stream2 happens to reside in the same memory address as
|
||||||
// the same memory address as sub_stream1.
|
// sub_stream1.
|
||||||
//
|
//
|
||||||
// The check that sub_stream2->ok() serves as a good-enough check.
|
// The check that sub_stream2->ok() serves as a good-enough check.
|
||||||
|
|
||||||
// Return sub_stream2 and get sub_stream3. The previous error on
|
// Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
|
||||||
// sub_stream1 has no effect on these streams, and they are the
|
// has no effect on these streams, and they are the same.
|
||||||
// same.
|
stream.ReturnSubStream(sub_stream2);
|
||||||
|
Stream* sub_stream3 = stream.GetOrCreateSubStream();
|
||||||
|
EXPECT_TRUE(sub_stream3->ok());
|
||||||
|
EXPECT_EQ(sub_stream2, sub_stream3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) {
|
||||||
|
std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
|
||||||
|
Stream stream(executor.get());
|
||||||
|
stream.Init();
|
||||||
|
EXPECT_TRUE(stream.ok());
|
||||||
|
|
||||||
|
// Get and return sub_stream1.
|
||||||
|
Stream* sub_stream1 = stream.GetOrCreateSubStream();
|
||||||
|
EXPECT_TRUE(sub_stream1->ok());
|
||||||
|
stream.ReturnSubStream(sub_stream1);
|
||||||
|
|
||||||
|
// Force an error on sub_stream1; here we call a method that requires DNN
|
||||||
|
// support, which we know the Host platform doesn't support.
|
||||||
|
//
|
||||||
|
// It is a bit weird to use sub_stream1 after it has already been returned. By
|
||||||
|
// doing this, we're simulating an asynchronous error that occurs during
|
||||||
|
// execution of the sub_stream, that occurs after the sub_stream is returned.
|
||||||
|
//
|
||||||
|
// E.g. the following is a common pattern of usage, where the execution of the
|
||||||
|
// operations enqueued onto the sub streams may occur after the streams have
|
||||||
|
// already been returned.
|
||||||
|
//
|
||||||
|
// void EnqueueOnSubStreams(Stream* stream) {
|
||||||
|
// Stream* sub_stream1 = stream.GetOrCreateSubStream();
|
||||||
|
// Stream* sub_stream2 = stream.GetOrCreateSubStream();
|
||||||
|
// // ... enqueue some operations on the sub streams ...
|
||||||
|
// stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2);
|
||||||
|
// stream.ReturnSubStream(sub_stream1);
|
||||||
|
// stream.ReturnSubStream(sub_stream2);
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Stream* main_stream = ...;
|
||||||
|
// EnqueueOnSubStreams(main_stream);
|
||||||
|
// main_stream.BlockHostUntilDone();
|
||||||
|
//
|
||||||
|
// TODO(b/112196569): The semantics of failed sub-streams is error-prone;
|
||||||
|
// GetOrCreateSubStream can still return a sub-stream that has not encountered
|
||||||
|
// an error yet, but will encounter one in the future, based on previously
|
||||||
|
// enqueued operations.
|
||||||
|
sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
|
||||||
|
EXPECT_FALSE(sub_stream1->ok());
|
||||||
|
|
||||||
|
// Get and return sub_stream2.
|
||||||
|
Stream* sub_stream2 = stream.GetOrCreateSubStream();
|
||||||
|
EXPECT_TRUE(sub_stream2->ok());
|
||||||
|
|
||||||
|
// The underlying streams should be different. They would have been the same,
|
||||||
|
// but since we forced an error on sub_stream1, it will not be re-used. Sadly
|
||||||
|
// we can't just check:
|
||||||
|
// EXPECT_NE(sub_stream1, sub_stream2);
|
||||||
|
//
|
||||||
|
// The above should hold logically, but it may fail if the new stream instance
|
||||||
|
// allocated for sub_stream2 happens to reside in the same memory address as
|
||||||
|
// sub_stream1.
|
||||||
|
//
|
||||||
|
// The check that sub_stream2->ok() serves as a good-enough check.
|
||||||
|
|
||||||
|
// Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
|
||||||
|
// has no effect on these streams, and they are the same.
|
||||||
stream.ReturnSubStream(sub_stream2);
|
stream.ReturnSubStream(sub_stream2);
|
||||||
Stream* sub_stream3 = stream.GetOrCreateSubStream();
|
Stream* sub_stream3 = stream.GetOrCreateSubStream();
|
||||||
EXPECT_TRUE(sub_stream3->ok());
|
EXPECT_TRUE(sub_stream3->ok());
|
||||||
|
Loading…
Reference in New Issue
Block a user