Implement GetStream for all DeviceMemoryAllocator subclasses
Previous version had a bug where uses of MultiDeviceAdapter would always get a `nullptr` for a requested stream, potentially resulting in a race condition. This version is safer, as `GetStream` always has to be implemented and cannot return `nullptr`. PiperOrigin-RevId: 274019013
This commit is contained in:
parent
349e97ed6f
commit
f616f0662c
tensorflow
@ -242,15 +242,8 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
||||
if (allocator == nullptr) {
|
||||
allocator = executor->GetAllocator();
|
||||
}
|
||||
absl::optional<se::Stream> stream_opt;
|
||||
se::Stream* stream = [&]() {
|
||||
if (allocator->GetStream()) {
|
||||
return allocator->GetStream();
|
||||
}
|
||||
stream_opt.emplace(executor);
|
||||
stream_opt->Init();
|
||||
return &stream_opt.value();
|
||||
}();
|
||||
TF_ASSIGN_OR_RETURN(se::Stream* const stream,
|
||||
allocator->GetStream(executor->device_ordinal()));
|
||||
|
||||
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
|
||||
se::RedzoneAllocator input_output_allocator(
|
||||
|
@ -290,16 +290,8 @@ StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
|
||||
allocator = &*se_allocator;
|
||||
}
|
||||
|
||||
absl::optional<se::Stream> stream_opt;
|
||||
se::Stream* stream = [&] {
|
||||
if (allocator->GetStream()) {
|
||||
return allocator->GetStream();
|
||||
}
|
||||
stream_opt.emplace(stream_exec_);
|
||||
stream_opt->Init();
|
||||
return &stream_opt.value();
|
||||
}();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(se::Stream* const stream,
|
||||
allocator->GetStream(stream_exec_->device_ordinal()));
|
||||
StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
|
||||
// Check StreamExecutor on which platform it is. ROCm and Cuda implementation
|
||||
// have diverged. Specifically, we need to make sure redzone allocator related
|
||||
|
@ -88,6 +88,10 @@ class TestAllocator : public se::DeviceMemoryAllocator {
|
||||
|
||||
bool AllowsAsynchronousDeallocation() const override { return false; }
|
||||
|
||||
// Test stub: this allocator is only used to track allocations in tests and
// never vends a stream, so any call to GetStream is a bug in the test
// itself and aborts the process immediately.
StatusOr<se::Stream*> GetStream(int device_ordinal) override {
  LOG(FATAL) << "Not implemented";
}
|
||||
|
||||
private:
|
||||
std::set<std::pair</*device_ordinal*/ int64, void*>> allocations_;
|
||||
};
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -211,15 +212,11 @@ class DeviceMemoryAllocator {
|
||||
// a stream, or do we have to wait for the computation to complete first?
|
||||
virtual bool AllowsAsynchronousDeallocation() const { return false; }
|
||||
|
||||
// Returns nullable stream pointer.
|
||||
//
|
||||
// If the pointer is non-null, then it is always safe to access the memory
|
||||
// allocated by the allocator on the returned stream. This condition is not
|
||||
// required though, as streams could be synchronized by other means.
|
||||
//
|
||||
// TODO(cheshire): clean up the interface, it might be cleaner to explicitly
|
||||
// pass the stream to Compiler.
|
||||
virtual Stream *GetStream() const { return nullptr; }
|
||||
// Returns a stream pointer on which it is always safe to access memory
|
||||
// allocated by this allocator. It is not necessary to use the returned stream
|
||||
// though, as clients may have additional information letting them safely use
|
||||
// a different stream.
|
||||
virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;
|
||||
|
||||
protected:
|
||||
const Platform* platform_;
|
||||
@ -251,12 +248,21 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
|
||||
|
||||
bool AllowsAsynchronousDeallocation() const override;
|
||||
|
||||
// Gets-or-creates a stream for a given `device_ordinal` from an appropriate
|
||||
// stream executor.
|
||||
port::StatusOr<Stream *> GetStream(int device_ordinal) override;
|
||||
|
||||
private:
|
||||
port::StatusOr<StreamExecutor*> GetStreamExecutor(int device_ordinal);
|
||||
port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;
|
||||
|
||||
// Available stream executors. Each stream executor has a different device
|
||||
// ordinal.
|
||||
std::vector<StreamExecutor *> stream_executors_;
|
||||
|
||||
absl::Mutex mutex_;
|
||||
|
||||
// Cache of streams for GetStream.
|
||||
std::map<int, Stream> streams_ GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
template <typename ElemT>
|
||||
|
@ -899,7 +899,7 @@ port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
|
||||
}
|
||||
|
||||
port::StatusOr<StreamExecutor *>
|
||||
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) {
|
||||
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
|
||||
if (device_ordinal < 0) {
|
||||
return tensorflow::errors::InvalidArgument(absl::StrFormat(
|
||||
"device ordinal value (%d) must be non-negative", device_ordinal));
|
||||
@ -918,4 +918,24 @@ bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns a stream on which memory allocated by this allocator may always be
// safely accessed, creating and caching one per device ordinal on first use.
//
// Thread-safe: the stream cache is guarded by mutex_.
port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  // Caching one long-lived stream per ordinal is only sound when
  // deallocation is synchronous; an asynchronous allocator would need to
  // track outstanding work per stream.
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  absl::MutexLock lock(&mutex_);
  auto found = streams_.find(device_ordinal);
  if (found != streams_.end()) {
    return &found->second;
  }
  // First request for this ordinal: construct the Stream in place inside the
  // map (Stream is not movable, hence piecewise construction) and initialize
  // it before handing it out.
  auto inserted = streams_.emplace(std::piecewise_construct,
                                   std::forward_as_tuple(device_ordinal),
                                   std::forward_as_tuple(executor));
  inserted.first->second.Init();
  return &inserted.first->second;
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/stream_executor/lib/error.h"
|
||||
#include "tensorflow/stream_executor/stream.h"
|
||||
@ -58,4 +59,9 @@ port::Status TfAllocatorAdapter::Deallocate(int device_ordinal,
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
// Returns the single stream this adapter wraps. The adapter is bound to one
// device at construction time, so the requested ordinal must match that
// stream's executor; a mismatch is a caller bug and aborts via CHECK.
port::StatusOr<Stream *> TfAllocatorAdapter::GetStream(int device_ordinal) {
  Stream *bound_stream = stream_;
  CHECK_EQ(bound_stream->parent()->device_ordinal(), device_ordinal);
  return bound_stream;
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
@ -54,7 +54,7 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator {
|
||||
// (This attribute has no effect on CPU.)
|
||||
bool AllowsAsynchronousDeallocation() const override { return true; }
|
||||
|
||||
Stream *GetStream() const override { return stream_; }
|
||||
port::StatusOr<Stream *> GetStream(int device_ordinal) override;
|
||||
|
||||
private:
|
||||
tensorflow::Allocator *wrapped_;
|
||||
@ -101,6 +101,10 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator {
|
||||
// (This attribute has no effect on CPU.)
|
||||
bool AllowsAsynchronousDeallocation() const override { return true; }
|
||||
|
||||
// Delegates to the per-device TfAllocatorAdapter owning `device_ordinal`;
// that adapter returns the stream on which its allocations are safe to use.
// NOTE(review): assumes device_ordinal indexes per_device_allocators_ in
// bounds — presumably guaranteed by construction; verify against callers.
port::StatusOr<Stream *> GetStream(int device_ordinal) override {
  TfAllocatorAdapter &device_allocator = per_device_allocators_[device_ordinal];
  return device_allocator.GetStream(device_ordinal);
}
|
||||
|
||||
private:
|
||||
std::vector<TfAllocatorAdapter> per_device_allocators_;
|
||||
// The wrapped TF allocators backing per_device_allocators_
|
||||
|
Loading…
Reference in New Issue
Block a user