Implement GetStream for all DeviceMemoryAllocator subclasses

The previous version had a bug where users of MultiDeviceAdapter would always get a
`nullptr` for a requested stream, potentially resulting in a race condition.

This version is safer, as `GetStream` always has to be implemented and cannot
return `nullptr`.

PiperOrigin-RevId: 274019013
George Karpenkov 2019-10-10 12:51:22 -07:00 committed by TensorFlower Gardener
parent 349e97ed6f
commit f616f0662c
7 changed files with 56 additions and 31 deletions
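
For readers skimming the message, the change at a typical call site (see the first two diffs below) looks roughly like this. A sketch only: `allocator` and `device_ordinal` stand in for whatever the surrounding code provides, and the enclosing function is assumed to return a `Status`/`StatusOr` so that `TF_ASSIGN_OR_RETURN` can propagate errors:

```cpp
// Old interface: a nullable pointer that MultiDeviceAdapter never populated.
//   se::Stream* stream = allocator->GetStream();   // could silently be nullptr
//
// New interface: a per-device lookup whose failure is a Status, not a nullptr.
TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                    allocator->GetStream(device_ordinal));
```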

View File

@@ -242,15 +242,8 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
   if (allocator == nullptr) {
     allocator = executor->GetAllocator();
   }
-  absl::optional<se::Stream> stream_opt;
-  se::Stream* stream = [&]() {
-    if (allocator->GetStream()) {
-      return allocator->GetStream();
-    }
-    stream_opt.emplace(executor);
-    stream_opt->Init();
-    return &stream_opt.value();
-  }();
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
+                      allocator->GetStream(executor->device_ordinal()));
   const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
   se::RedzoneAllocator input_output_allocator(

View File

@@ -290,16 +290,8 @@ StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
     allocator = &*se_allocator;
   }
-  absl::optional<se::Stream> stream_opt;
-  se::Stream* stream = [&] {
-    if (allocator->GetStream()) {
-      return allocator->GetStream();
-    }
-    stream_opt.emplace(stream_exec_);
-    stream_opt->Init();
-    return &stream_opt.value();
-  }();
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
+                      allocator->GetStream(stream_exec_->device_ordinal()));
   StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
   // Check StreamExecutor on which platform it is. ROCm and Cuda implementation
   // have diverged. Secifically, we need to make sure redzone allocator related

View File

@@ -88,6 +88,10 @@ class TestAllocator : public se::DeviceMemoryAllocator {
   bool AllowsAsynchronousDeallocation() const override { return false; }
 
+  StatusOr<se::Stream*> GetStream(int device_ordinal) override {
+    LOG(FATAL) << "Not implemented";
+  }
+
  private:
   std::set<std::pair</*device_ordinal*/ int64, void*>> allocations_;
 };
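
Since `GetStream` now returns a `StatusOr`, a test-only allocator like the one above could also report "no stream" through the error machinery rather than crashing. A hypothetical variant (not part of this change), reusing the `tensorflow::errors` helpers that appear elsewhere in the diff:

```cpp
// Hypothetical alternative to LOG(FATAL); not in the commit.
StatusOr<se::Stream*> GetStream(int device_ordinal) override {
  return tensorflow::errors::Unimplemented(
      "TestAllocator does not provide a stream for device ", device_ordinal);
}
```

Either choice satisfies the new pure-virtual interface; the commit keeps `LOG(FATAL)`, which is fine for a test double.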

View File

@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>
+#include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -211,15 +212,11 @@ class DeviceMemoryAllocator {
   // a stream, or do we have to wait for the computation to complete first?
   virtual bool AllowsAsynchronousDeallocation() const { return false; }
 
-  // Returns nullable stream pointer.
-  //
-  // If the pointer is non-null, then it is always safe to access the memory
-  // allocated by the allocator on the returned stream. This condition is not
-  // required though, as streams could be synchronized by other means.
-  //
-  // TODO(cheshire): clean up the interface, it might be cleaner to explicitly
-  // pass the stream to Compiler.
-  virtual Stream *GetStream() const { return nullptr; }
+  // Returns a stream pointer on which it is always safe to access memory
+  // allocated by this allocator. It is not necessary to use the returned stream
+  // though, as clients may have additional information letting them safely use
+  // a different stream.
+  virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;
 
  protected:
   const Platform* platform_;
@@ -251,12 +248,21 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
   bool AllowsAsynchronousDeallocation() const override;
 
+  // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
+  // stream executor.
+  port::StatusOr<Stream *> GetStream(int device_ordinal) override;
+
  private:
-  port::StatusOr<StreamExecutor*> GetStreamExecutor(int device_ordinal);
+  port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;
 
   // Available stream executors. Each stream executor has a different device
   // ordinal.
   std::vector<StreamExecutor *> stream_executors_;
+
+  absl::Mutex mutex_;
+
+  // Cache of streams for GetStream.
+  std::map<int, Stream> streams_ GUARDED_BY(mutex_);
 };
 
 template <typename ElemT>

View File

@@ -899,7 +899,7 @@ port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
 }
 
 port::StatusOr<StreamExecutor *>
-StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) {
+StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
   if (device_ordinal < 0) {
     return tensorflow::errors::InvalidArgument(absl::StrFormat(
         "device ordinal value (%d) must be non-negative", device_ordinal));
@@ -918,4 +918,24 @@ bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
   return false;
 }
 
+port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
+    int device_ordinal) {
+  CHECK(!AllowsAsynchronousDeallocation())
+      << "The logic below only works for synchronous allocators";
+  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
+                      GetStreamExecutor(device_ordinal));
+  Stream *out = [&] {
+    absl::MutexLock lock(&mutex_);
+    if (!streams_.count(device_ordinal)) {
+      auto p = streams_.emplace(std::piecewise_construct,
+                                std::forward_as_tuple(device_ordinal),
+                                std::forward_as_tuple(executor));
+      p.first->second.Init();
+      return &p.first->second;
+    }
+    return &streams_.at(device_ordinal);
+  }();
+  return out;
+}
+
 }  // namespace stream_executor
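
A note on the cache above: `se::Stream` cannot be copied, so each stream has to be constructed in place inside the `std::map`, and because `std::map` never relocates its nodes the cached `Stream*` stays valid across later insertions. The same get-or-create pattern in isolation, with a placeholder type standing in for `Stream` (a self-contained sketch; nothing here is from the commit):

```cpp
#include <iostream>
#include <map>
#include <tuple>
#include <utility>

// Placeholder for a type that, like se::Stream, cannot be copied or moved.
class Resource {
 public:
  explicit Resource(int device) : device_(device) {}
  Resource(const Resource&) = delete;
  Resource& operator=(const Resource&) = delete;
  int device() const { return device_; }

 private:
  int device_;
};

// Get-or-create: construct the value in place on first use and return a
// pointer that remains valid because std::map nodes are never relocated.
Resource* GetOrCreate(std::map<int, Resource>& cache, int key) {
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache
             .emplace(std::piecewise_construct, std::forward_as_tuple(key),
                      std::forward_as_tuple(key))
             .first;
  }
  return &it->second;
}

int main() {
  std::map<int, Resource> cache;
  Resource* a = GetOrCreate(cache, 0);
  Resource* b = GetOrCreate(cache, 0);
  std::cout << (a == b) << " " << a->device() << std::endl;  // prints "1 0"
  return 0;
}
```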

View File

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/stream_executor/lib/error.h"
 #include "tensorflow/stream_executor/stream.h"
@@ -58,4 +59,9 @@ port::Status TfAllocatorAdapter::Deallocate(int device_ordinal,
   return port::Status::OK();
 }
 
+port::StatusOr<Stream *> TfAllocatorAdapter::GetStream(int device_ordinal) {
+  CHECK_EQ(stream_->parent()->device_ordinal(), device_ordinal);
+  return stream_;
+}
+
 }  // namespace stream_executor

View File

@@ -54,7 +54,7 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator {
   // (This attribute has no effect on CPU.)
   bool AllowsAsynchronousDeallocation() const override { return true; }
 
-  Stream *GetStream() const override { return stream_; }
+  port::StatusOr<Stream *> GetStream(int device_ordinal) override;
 
  private:
   tensorflow::Allocator *wrapped_;
@@ -101,6 +101,10 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator {
   // (This attribute has no effect on CPU.)
   bool AllowsAsynchronousDeallocation() const override { return true; }
 
+  port::StatusOr<Stream *> GetStream(int device_ordinal) override {
+    return per_device_allocators_[device_ordinal].GetStream(device_ordinal);
+  }
+
  private:
   std::vector<TfAllocatorAdapter> per_device_allocators_;
   // The wrapped TF allocators backing per_device_allocators_