diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
index 21a0ffab5e0..8dcbfbbc96a 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
@@ -242,15 +242,8 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
   if (allocator == nullptr) {
     allocator = executor->GetAllocator();
   }
-  absl::optional<se::Stream> stream_opt;
-  se::Stream* stream = [&]() {
-    if (allocator->GetStream()) {
-      return allocator->GetStream();
-    }
-    stream_opt.emplace(executor);
-    stream_opt->Init();
-    return &stream_opt.value();
-  }();
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
+                      allocator->GetStream(executor->device_ordinal()));
 
   const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
   se::RedzoneAllocator input_output_allocator(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
index 676982bfdbc..eab1003af01 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -290,16 +290,8 @@ StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
     allocator = &*se_allocator;
   }
 
-  absl::optional<se::Stream> stream_opt;
-  se::Stream* stream = [&] {
-    if (allocator->GetStream()) {
-      return allocator->GetStream();
-    }
-    stream_opt.emplace(stream_exec_);
-    stream_opt->Init();
-    return &stream_opt.value();
-  }();
-
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
+                      allocator->GetStream(stream_exec_->device_ordinal()));
   StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
   // Check StreamExecutor on which platform it is. ROCm and Cuda implementation
   // have diverged. Secifically, we need to make sure redzone allocator related
diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
index 301b1112598..07ef7fa1c31 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
@@ -88,6 +88,10 @@ class TestAllocator : public se::DeviceMemoryAllocator {
 
   bool AllowsAsynchronousDeallocation() const override { return false; }
 
+  StatusOr<se::Stream *> GetStream(int device_ordinal) override {
+    LOG(FATAL) << "Not implemented";
+  }
+
  private:
   std::set<std::pair<int64, void*>> allocations_;
 };
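Both callsite hunks above replace the same ad-hoc pattern, a lambda that reuses the allocator's stream when present and otherwise constructs and Init()s a fresh se::Stream, with one status-propagating call to the new allocator->GetStream(device_ordinal). For readers unfamiliar with TF_ASSIGN_OR_RETURN: it evaluates a StatusOr-returning expression, early-returns the error status from the enclosing function on failure, and binds the value on success. A minimal sketch of the equivalent hand-written control flow; absl::StatusOr stands in for port::StatusOr, and GetStream here is a hypothetical free function, not the real allocator method:

  #include <iostream>

  #include "absl/status/status.h"
  #include "absl/status/statusor.h"

  struct Stream {};  // Toy stand-in for se::Stream.

  // Hypothetical device lookup standing in for
  // DeviceMemoryAllocator::GetStream; only device 0 exists in this sketch.
  absl::StatusOr<Stream*> GetStream(int device_ordinal) {
    static Stream stream0;
    if (device_ordinal != 0) {
      return absl::InvalidArgumentError("no such device");
    }
    return &stream0;
  }

  absl::Status RunOnDevice(int device_ordinal) {
    // Hand-written equivalent of:
    //   TF_ASSIGN_OR_RETURN(Stream* const stream, GetStream(device_ordinal));
    // (the real macro hides the temporary behind a unique name).
    absl::StatusOr<Stream*> stream_or = GetStream(device_ordinal);
    if (!stream_or.ok()) return stream_or.status();
    Stream* const stream = stream_or.value();

    (void)stream;  // ... enqueue work on `stream` here ...
    return absl::OkStatus();
  }

  int main() {
    std::cout << RunOnDevice(0) << " / " << RunOnDevice(1) << "\n";
  }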
diff --git a/tensorflow/stream_executor/device_memory_allocator.h b/tensorflow/stream_executor/device_memory_allocator.h
index 68728ef3543..fa5c4d3abfc 100644
--- a/tensorflow/stream_executor/device_memory_allocator.h
+++ b/tensorflow/stream_executor/device_memory_allocator.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <vector>
 
+#include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -211,15 +212,11 @@ class DeviceMemoryAllocator {
   // a stream, or do we have to wait for the computation to complete first?
   virtual bool AllowsAsynchronousDeallocation() const { return false; }
 
-  // Returns nullable stream pointer.
-  //
-  // If the pointer is non-null, then it is always safe to access the memory
-  // allocated by the allocator on the returned stream. This condition is not
-  // required though, as streams could be synchronized by other means.
-  //
-  // TODO(cheshire): clean up the interface, it might be cleaner to explicitly
-  // pass the stream to Compiler.
-  virtual Stream *GetStream() const { return nullptr; }
+  // Returns a stream pointer on which it is always safe to access memory
+  // allocated by this allocator. It is not necessary to use the returned
+  // stream though, as clients may have additional information letting them
+  // safely use a different stream.
+  virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;
 
  protected:
   const Platform* platform_;
@@ -251,12 +248,21 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
 
   bool AllowsAsynchronousDeallocation() const override;
 
+  // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
+  // stream executor.
+  port::StatusOr<Stream *> GetStream(int device_ordinal) override;
+
  private:
-  port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal);
+  port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;
 
   // Available stream executors. Each stream executor has a different device
   // ordinal.
   std::vector<StreamExecutor *> stream_executors_;
+
+  absl::Mutex mutex_;
+
+  // Cache of streams for GetStream.
+  std::map<int, Stream> streams_ GUARDED_BY(mutex_);
 };
 
 template <typename ElemT>
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index dae2403915a..ded59d290c6 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -899,7 +899,7 @@ port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
 }
 
 port::StatusOr<StreamExecutor *>
-StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) {
+StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
   if (device_ordinal < 0) {
     return tensorflow::errors::InvalidArgument(absl::StrFormat(
         "device ordinal value (%d) must be non-negative", device_ordinal));
@@ -918,4 +918,24 @@ bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
   return false;
 }
 
+port::StatusOr<Stream *> StreamExecutorMemoryAllocator::GetStream(
+    int device_ordinal) {
+  CHECK(!AllowsAsynchronousDeallocation())
+      << "The logic below only works for synchronous allocators";
+  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
+                      GetStreamExecutor(device_ordinal));
+  Stream *out = [&] {
+    absl::MutexLock lock(&mutex_);
+    if (!streams_.count(device_ordinal)) {
+      auto p = streams_.emplace(std::piecewise_construct,
+                                std::forward_as_tuple(device_ordinal),
+                                std::forward_as_tuple(executor));
+      p.first->second.Init();
+      return &p.first->second;
+    }
+    return &streams_.at(device_ordinal);
+  }();
+  return out;
+}
+
 }  // namespace stream_executor
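One detail worth noting in StreamExecutorMemoryAllocator::GetStream above: se::Stream is neither copyable nor movable, which is why the new std::map<int, Stream> cache is populated via emplace with std::piecewise_construct instead of operator[] or insert. Below is a minimal standalone sketch of that get-or-create pattern under a lock; Resource is a hypothetical stand-in for Stream, and std::mutex stands in for absl::Mutex:

  #include <map>
  #include <mutex>
  #include <tuple>
  #include <utility>

  // Toy stand-in for se::Stream: constructed in place, never copied or moved.
  class Resource {
   public:
    explicit Resource(int device_ordinal) : device_ordinal_(device_ordinal) {}
    Resource(const Resource&) = delete;
    Resource& operator=(const Resource&) = delete;
    void Init() { initialized_ = true; }
    int device_ordinal() const { return device_ordinal_; }

   private:
    int device_ordinal_;
    bool initialized_ = false;
  };

  class ResourceCache {
   public:
    // Get-or-create under a lock, mirroring the shape of
    // StreamExecutorMemoryAllocator::GetStream above.
    Resource* GetOrCreate(int device_ordinal) {
      std::lock_guard<std::mutex> lock(mutex_);
      auto it = resources_.find(device_ordinal);
      if (it == resources_.end()) {
        // Resource is non-copyable and non-movable, so it must be built in
        // place inside the map node; piecewise_construct forwards the
        // constructor argument without creating a temporary Resource.
        it = resources_
                 .emplace(std::piecewise_construct,
                          std::forward_as_tuple(device_ordinal),
                          std::forward_as_tuple(device_ordinal))
                 .first;
        it->second.Init();
      }
      return &it->second;
    }

   private:
    std::mutex mutex_;
    std::map<int, Resource> resources_;
  };

A node-based map also keeps &it->second stable across later insertions, which matters here because raw pointers into the cache are handed back to callers.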
#include "tensorflow/stream_executor/tf_allocator_adapter.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/stream.h" @@ -58,4 +59,9 @@ port::Status TfAllocatorAdapter::Deallocate(int device_ordinal, return port::Status::OK(); } +port::StatusOr TfAllocatorAdapter::GetStream(int device_ordinal) { + CHECK_EQ(stream_->parent()->device_ordinal(), device_ordinal); + return stream_; +} + } // namespace stream_executor diff --git a/tensorflow/stream_executor/tf_allocator_adapter.h b/tensorflow/stream_executor/tf_allocator_adapter.h index d0db676065d..9b66f57d087 100644 --- a/tensorflow/stream_executor/tf_allocator_adapter.h +++ b/tensorflow/stream_executor/tf_allocator_adapter.h @@ -54,7 +54,7 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator { // (This attribute has no effect on CPU.) bool AllowsAsynchronousDeallocation() const override { return true; } - Stream *GetStream() const override { return stream_; } + port::StatusOr GetStream(int device_ordinal) override; private: tensorflow::Allocator *wrapped_; @@ -101,6 +101,10 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator { // (This attribute has no effect on CPU.) bool AllowsAsynchronousDeallocation() const override { return true; } + port::StatusOr GetStream(int device_ordinal) override { + return per_device_allocators_[device_ordinal].GetStream(device_ordinal); + } + private: std::vector per_device_allocators_; // The wrapped TF allocators backing per_device_allocators_