From 481748b5c92d217dd1b10725ce259ba8c7086df0 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 29 Apr 2019 16:51:22 -0700 Subject: [PATCH] [SE] Use absl mutex instead of TF mutex. Fix layering violation. PiperOrigin-RevId: 245851033 --- tensorflow/stream_executor/BUILD | 13 +++++- tensorflow/stream_executor/cuda/BUILD | 10 ++++- tensorflow/stream_executor/cuda/cuda_blas.cc | 2 +- tensorflow/stream_executor/cuda/cuda_blas.h | 6 +-- tensorflow/stream_executor/cuda/cuda_dnn.cc | 13 +++--- tensorflow/stream_executor/cuda/cuda_dnn.h | 1 - .../stream_executor/cuda/cuda_driver.cc | 28 +++++-------- .../stream_executor/cuda/cuda_gpu_executor.cc | 16 +++---- .../stream_executor/cuda/cuda_platform.cc | 42 +++++++++---------- .../stream_executor/cuda/cuda_platform.h | 1 - tensorflow/stream_executor/cuda/cuda_rng.cc | 8 ++-- tensorflow/stream_executor/executor_cache.cc | 13 +++--- tensorflow/stream_executor/executor_cache.h | 8 ++-- tensorflow/stream_executor/gpu/BUILD | 2 + tensorflow/stream_executor/gpu/gpu_executor.h | 8 ++-- tensorflow/stream_executor/gpu/gpu_rng.h | 9 ++-- tensorflow/stream_executor/host/BUILD | 2 + .../stream_executor/host/host_platform.h | 1 - .../stream_executor/host/host_stream.cc | 10 ++--- tensorflow/stream_executor/host/host_stream.h | 5 ++- tensorflow/stream_executor/kernel_spec.cc | 4 +- tensorflow/stream_executor/kernel_spec.h | 6 +-- tensorflow/stream_executor/platform/BUILD | 1 - .../stream_executor/platform/default/BUILD | 1 - .../platform/default/dso_loader.h | 4 +- .../stream_executor/platform/default/mutex.h | 40 ------------------ tensorflow/stream_executor/platform/mutex.h | 27 ------------ tensorflow/stream_executor/platform/port.h | 4 -- tensorflow/stream_executor/plugin_registry.cc | 12 +++--- tensorflow/stream_executor/plugin_registry.h | 1 - tensorflow/stream_executor/rocm/rocm_blas.cc | 2 +- tensorflow/stream_executor/rocm/rocm_blas.h | 4 +- tensorflow/stream_executor/rocm/rocm_dnn.cc | 25 +++++------ tensorflow/stream_executor/rocm/rocm_dnn.h | 2 +- .../stream_executor/rocm/rocm_driver.cc | 17 +++----- .../stream_executor/rocm/rocm_gpu_executor.cc | 20 ++++----- .../stream_executor/rocm/rocm_platform.cc | 41 ++++++++---------- .../stream_executor/rocm/rocm_platform.h | 4 +- tensorflow/stream_executor/rocm/rocm_rng.cc | 8 ++-- tensorflow/stream_executor/stream.cc | 8 ++-- tensorflow/stream_executor/stream.h | 8 ++-- .../stream_executor/stream_executor_pimpl.cc | 27 ++++++------ .../stream_executor/stream_executor_pimpl.h | 6 +-- .../temporary_memory_manager.cc | 12 +++--- .../temporary_memory_manager.h | 4 +- 45 files changed, 207 insertions(+), 279 deletions(-) delete mode 100644 tensorflow/stream_executor/platform/default/mutex.h delete mode 100644 tensorflow/stream_executor/platform/mutex.h diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index 67840a1c10a..b12a661cd3c 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -93,6 +93,7 @@ cc_library( "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -140,6 +141,7 @@ cc_library( "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -152,6 +154,7 @@ cc_library( deps = [ "//tensorflow/stream_executor/platform", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", ], ) @@ -199,6 +202,7 @@ cc_library( "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -231,6 +235,7 @@ cc_library( "//tensorflow/stream_executor/platform", "//third_party/eigen3", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", ], ) @@ -259,6 +264,7 @@ cc_library( "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", ], ) @@ -295,6 +301,7 @@ cc_library( ":stream_header", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", + "@com_google_absl//absl/synchronization", ], ) @@ -311,6 +318,7 @@ cc_library( "//tensorflow/stream_executor/platform", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", ], ) @@ -399,6 +407,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -440,10 +449,10 @@ cc_library( ":plugin", ":stream_executor_headers", "//tensorflow/stream_executor/lib", - "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", ], ) @@ -536,6 +545,7 @@ cc_library( "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -612,6 +622,7 @@ cc_library( "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index 8bad0786fc6..fa456bf76d3 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -59,9 +59,11 @@ cc_library( "//tensorflow/stream_executor/platform", ], ) + tf_additional_cuda_platform_deps() + [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", ], alwayslink = True, # Registers itself with the MultiPlatformManager. ) @@ -112,6 +114,7 @@ cc_library( "//tensorflow:using_cuda_clang_with_dynamic_build": ["cudart_stub"], "//conditions:default": ["//tensorflow/core:cuda"], }) + [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -223,7 +226,10 @@ cc_library( ] + if_static( ["@local_config_cuda//cuda:cublas"], [":cublas_stub"], - )) + ["@com_google_absl//absl/strings:str_format"], + )) + [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + ], alwayslink = True, ) @@ -306,7 +312,7 @@ cc_library( ]) + tf_additional_cudnn_plugin_deps() + if_cuda_is_configured(if_static( ["@local_config_cuda//cuda:cudnn"], [":cudnn_stub"], - )), + )) + ["@com_google_absl//absl/synchronization"], alwayslink = True, ) diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 34bbd2f0d8f..661d84512a2 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -402,7 +402,7 @@ template bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, bool pointer_mode_host, bool err_on_failure, bool use_tensor_op_math, Args... args) { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); CHECK(blas_ != nullptr); if (!SetStream(stream)) { diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 63d03056d91..d0cd15fe4d2 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -20,9 +20,9 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/host_or_device_scalar.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -142,8 +142,8 @@ class CUDABlas : public blas::BlasSupport { const T &beta, DeviceMemory *y, int incy, blas::ProfileResult *output_profile_result); - // mutex that guards the cuBLAS handle for this device. - mutex mu_; + // Guards the cuBLAS handle for this device. + absl::Mutex mu_; // GpuExecutor which instantiated this CUDABlas. // Immutable post-initialization. diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 14f4be25352..12da285bb6a 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -133,8 +133,8 @@ class CudnnHandle { public: // Takes ownership of the executor context and the lock to access cuDNN // using handle. - CudnnHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock, - cudnnHandle_t handle) + CudnnHandle(gpu::ScopedActivateExecutorContext context, + std::unique_ptr lock, cudnnHandle_t handle) : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {} // Returns cuDNN handle. To be passed directly to cuDNN APIs, don't keep @@ -143,7 +143,7 @@ class CudnnHandle { private: gpu::ScopedActivateExecutorContext context_; - mutex_lock lock_; + std::unique_ptr lock_; cudnnHandle_t handle_; // Not owned. }; @@ -160,7 +160,7 @@ class CudnnAccess { explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {} ~CudnnAccess() { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); cudnnDestroy(handle_); } @@ -181,7 +181,8 @@ class CudnnAccess { // therefore a bad idea (performance wise) to call any cuDNN APIs that // enqueue work in the stream. CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) { - mutex_lock lock(mutex_); + auto lock = absl::make_unique(&mutex_); + mutex_.AssertHeld(); gpu::ScopedActivateExecutorContext context(executor); CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy; const auto status = cudnnSetStream(handle_, cu_stream); @@ -191,7 +192,7 @@ class CudnnAccess { private: // Guards the enqueueing of cuDNN operations via the handle_ below. - mutex mutex_; + absl::Mutex mutex_; // cuDNN library handle. cudnnHandle_t handle_ GUARDED_BY(mutex_); // Owned. diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 3a49469651c..80fc1ae1794 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_activation.h" #include "tensorflow/stream_executor/dnn.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/temporary_device_memory.h" diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index eb6e124176f..e4c5c72854e 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -23,11 +23,13 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/base/const_init.h" #include "absl/container/inlined_vector.h" #include "absl/debugging/leak_check.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "cuda/include/cuda_runtime_api.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" @@ -38,7 +40,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/static_threadlocal.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/logging.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" bool FLAGS_gpuexec_cuda_driver_inject_init_error = false; @@ -65,14 +66,14 @@ class CreatedContexts { public: // Returns whether context is a member of the live set. static bool Has(CUcontext context) { - tf_shared_lock lock(mu_); + absl::ReaderMutexLock lock(&mu_); return Live()->find(context) != Live()->end(); } // Adds context to the live set, or returns it if it's already present. static GpuContext* Add(CUcontext context) { CHECK(context != nullptr); - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); auto insert_result = Live()->insert(std::make_pair(context, nullptr)); auto it = insert_result.first; if (insert_result.second) { @@ -85,7 +86,7 @@ class CreatedContexts { // Removes context from the live set. static void Remove(CUcontext context) { CHECK(context != nullptr); - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); auto it = Live()->find(context); CHECK(it != Live()->end()) << context; Live()->erase(it); @@ -100,11 +101,11 @@ class CreatedContexts { } // Lock that guards access-to/mutation-of the live set. - static mutex mu_; + static absl::Mutex mu_; static int64 next_id_; }; -/* static */ mutex CreatedContexts::mu_{LINKER_INITIALIZED}; +/* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit}; /* static */ int64 CreatedContexts::next_id_ = 1; // 0 means "no context" // Formats CUresult to output prettified values into a log stream. @@ -322,17 +323,10 @@ static port::Status InternalInit() { /* static */ port::Status GpuDriver::Init() { // Cached return value from calling InternalInit(), as cuInit need only be // called once, but GpuDriver::Init may be called many times. - static port::Status init_retval; - static bool set = false; - static mutex* init_mu = new mutex; - - mutex_lock lock(*init_mu); - if (!set) { - init_retval = InternalInit(); - set = true; - } - - return init_retval; + static port::Status* init_retval = [] { + return new port::Status(InternalInit()); + }(); + return *init_retval; } /* static */ port::Status GpuDriver::GetDevice(int device_ordinal, diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 6056f6df1f0..cdd5ae7e57f 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -272,7 +272,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); if (spec.has_cuda_cubin_in_memory()) { - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; kernelname = &spec.cuda_cubin_in_memory().kernelname(); const char *cubin = spec.cuda_cubin_in_memory().bytes(); if (!LoadModuleFromCuBin(cubin, &module)) { @@ -295,7 +295,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, return false; } - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; if (!LoadModuleFromPtx(ptx, &module)) { return false; } @@ -343,7 +343,7 @@ bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { void GpuExecutor::UnloadKernel(const KernelBase* kernel) { VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name(); - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; auto gpu_binary_it = kernel_to_gpu_binary_.find(kernel); if (kernel_to_gpu_binary_.end() == gpu_binary_it) { VLOG(3) << "Kernel " << kernel << " : " << kernel->name() @@ -362,7 +362,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, // ModuleHandle::id(). CUmodule cu_module; if (spec.has_cuda_cubin_in_memory()) { - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; if (!LoadModuleFromCuBin( reinterpret_cast(spec.cuda_cubin_in_memory().data()), &cu_module)) { @@ -380,7 +380,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, return false; } - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; if (!LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)) { return false; } @@ -394,7 +394,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, bool GpuExecutor::UnloadModule(ModuleHandle module_handle) { const char *gpu_binary = reinterpret_cast(module_handle.id()); - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; return UnloadGpuBinary(gpu_binary); } @@ -428,7 +428,7 @@ bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, // whether we've done an occupancy check on this kernel before isn't free // (because we have to synchronize), so we only do this at -v 2+. if (VLOG_IS_ON(2)) { - mutex_lock lock(launched_kernels_mu_); + absl::MutexLock lock(&launched_kernels_mu_); if (!launched_kernels_.count(cufunc)) { VlogOccupancyInfo(kernel, thread_dims, block_dims); // TODO(rspringer): Remove elements from launched_kernels_...if we ever @@ -892,7 +892,7 @@ bool GpuExecutor::GetSymbol(const string& symbol_name, }; { // give limited scope to mutex_lock - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; if (static_cast(module_handle)) { auto it = gpu_binary_to_module_.find(module_handle.id()); CHECK(it != gpu_binary_to_module_.end()); diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc index 8a5cc48118c..8ca5185bcf0 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.cc +++ b/tensorflow/stream_executor/cuda/cuda_platform.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_platform.h" +#include "absl/base/const_init.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -75,30 +76,25 @@ CudaPlatform::~CudaPlatform() {} void CudaPlatform::InspectNumaNodes() { // To get NUMA node information, we need to create all executors, so we can // examine their device descriptions to see their bus assignments. - static bool initialized = false; - static mutex numa_mutex(LINKER_INITIALIZED); - mutex_lock lock(numa_mutex); - if (initialized) { - return; - } - - StreamExecutorConfig config; - for (int i = 0; i < VisibleDeviceCount(); i++) { - config.ordinal = i; - StreamExecutor* exec = GetExecutor(config).ValueOrDie(); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max(limit_numa_node_, - exec->GetDeviceDescription().numa_node() + 1); + static std::once_flag once; + std::call_once(once, [&] { + StreamExecutorConfig config; + for (int i = 0; i < VisibleDeviceCount(); i++) { + config.ordinal = i; + StreamExecutor* exec = GetExecutor(config).ValueOrDie(); + if (i == 0) { + // NUMA nodes may not start at 0, so set the minimum node based on the + // first executor we see. + min_numa_node_ = exec->GetDeviceDescription().numa_node(); + limit_numa_node_ = min_numa_node_ + 1; + } else { + min_numa_node_ = + std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); + limit_numa_node_ = std::max( + limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); + } } - } - initialized = true; + }); } int CudaPlatform::BusCount() { diff --git a/tensorflow/stream_executor/cuda/cuda_platform.h b/tensorflow/stream_executor/cuda/cuda_platform.h index b92f0d86597..b452fd38fb5 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.h +++ b/tensorflow/stream_executor/cuda/cuda_platform.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_internal.h" diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc index 36eef0a07ac..b86c3122fb1 100644 --- a/tensorflow/stream_executor/cuda/cuda_rng.cc +++ b/tensorflow/stream_executor/cuda/cuda_rng.cc @@ -71,7 +71,7 @@ GpuRng::~GpuRng() { } bool GpuRng::Init() { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); CHECK(rng_ == nullptr); cuda::ScopedActivateExecutorContext sac(parent_); @@ -106,7 +106,7 @@ constexpr bool ComplexIsConsecutiveFloats() { template bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory* v) { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); static_assert(ComplexIsConsecutiveFloats(), "std::complex values are not stored as consecutive values"); @@ -164,7 +164,7 @@ bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean, ElemT stddev, DeviceMemory* v, FuncT func) { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (!SetStream(stream)) { return false; @@ -197,7 +197,7 @@ bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev, } bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); CHECK(rng_ != nullptr); if (!CheckSeed(seed, seed_bytes)) { diff --git a/tensorflow/stream_executor/executor_cache.cc b/tensorflow/stream_executor/executor_cache.cc index c62a66453fc..2dd0303f42f 100644 --- a/tensorflow/stream_executor/executor_cache.cc +++ b/tensorflow/stream_executor/executor_cache.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/stream_executor/executor_cache.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" namespace stream_executor { @@ -32,7 +33,7 @@ port::StatusOr ExecutorCache::GetOrCreate( Entry* entry = nullptr; { - mutex_lock lock{mutex_}; + absl::MutexLock lock{&mutex_}; entry = &cache_[config.ordinal]; // Release the map lock; the address of 'entry' is stable because // std::map guarantees reference stability. @@ -41,7 +42,7 @@ port::StatusOr ExecutorCache::GetOrCreate( // Acquire the per-Entry mutex without holding the map mutex. Initializing // an Executor may be expensive, so we want to allow concurrent // initialization of different entries. - mutex_lock lock{entry->configurations_mutex}; + absl::MutexLock lock{&entry->configurations_mutex}; for (const auto& iter : entry->configurations) { if (iter.first.plugin_config == config.plugin_config && iter.first.device_options == config.device_options) { @@ -66,7 +67,7 @@ port::StatusOr ExecutorCache::Get( const StreamExecutorConfig& config) { Entry* entry = nullptr; { - tf_shared_lock lock{mutex_}; + absl::ReaderMutexLock lock{&mutex_}; auto it = cache_.find(config.ordinal); if (it != cache_.end()) { entry = &it->second; @@ -77,7 +78,7 @@ port::StatusOr ExecutorCache::Get( config.ordinal)); } } - tf_shared_lock lock{entry->configurations_mutex}; + absl::ReaderMutexLock lock{&entry->configurations_mutex}; if (entry->configurations.empty()) { return port::Status( port::error::NOT_FOUND, @@ -96,12 +97,12 @@ port::StatusOr ExecutorCache::Get( } void ExecutorCache::DestroyAllExecutors() { - mutex_lock lock{mutex_}; + absl::MutexLock lock{&mutex_}; cache_.clear(); } ExecutorCache::Entry::~Entry() { - mutex_lock lock{configurations_mutex}; + absl::MutexLock lock{&configurations_mutex}; configurations.clear(); } diff --git a/tensorflow/stream_executor/executor_cache.h b/tensorflow/stream_executor/executor_cache.h index bbeeaed787c..0100f6d9ac0 100644 --- a/tensorflow/stream_executor/executor_cache.h +++ b/tensorflow/stream_executor/executor_cache.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" namespace stream_executor { @@ -54,11 +54,11 @@ class ExecutorCache { struct Entry { ~Entry(); - // Mutex that locks the contents of each entry. The 'mutex_' of the + // Mutex that guards the contents of each entry. The 'mutex_' of the // ExecutorCache class protects both the 'cache_' and the existence of each // Entry, but not the Entry's contents. 'configurations_mutex' protects the // contents of the entry after 'mutex_' has been dropped. - mutex configurations_mutex; + absl::Mutex configurations_mutex; // Vector of cached {config, executor} pairs. std::vector< @@ -69,7 +69,7 @@ class ExecutorCache { // Maps ordinal number to a list of cached executors for that ordinal. // We key off of ordinal (instead of just looking up all fields in the // StreamExecutorConfig) for a slight improvement in lookup time. - mutex mutex_; + absl::Mutex mutex_; std::map cache_ GUARDED_BY(mutex_); SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache); diff --git a/tensorflow/stream_executor/gpu/BUILD b/tensorflow/stream_executor/gpu/BUILD index e6812389e3e..2dd21f49ff1 100644 --- a/tensorflow/stream_executor/gpu/BUILD +++ b/tensorflow/stream_executor/gpu/BUILD @@ -98,6 +98,7 @@ cc_library( "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", ], ) @@ -126,6 +127,7 @@ cc_library( "//tensorflow/stream_executor:plugin_registry", "//tensorflow/stream_executor:rng", "//tensorflow/stream_executor/platform", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/stream_executor/gpu/gpu_executor.h b/tensorflow/stream_executor/gpu/gpu_executor.h index 23df3923aaa..2149f136877 100644 --- a/tensorflow/stream_executor/gpu/gpu_executor.h +++ b/tensorflow/stream_executor/gpu/gpu_executor.h @@ -26,12 +26,12 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/event.h" #include "tensorflow/stream_executor/gpu/gpu_kernel.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -287,7 +287,7 @@ class GpuExecutor : public internal::StreamExecutorInterface { EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); // Guards the on-disk-module mapping. - mutex disk_modules_mu_; + absl::Mutex disk_modules_mu_; // Mapping from filename to GPUModuleHandle, if it was already retrieved. // Multiple GPUFunctionHandle are usually obtained from a single @@ -296,7 +296,7 @@ class GpuExecutor : public internal::StreamExecutorInterface { std::map disk_modules_ GUARDED_BY(disk_modules_mu_); // Guards the in-memory-module mapping. - mutex in_memory_modules_mu_; + absl::Mutex in_memory_modules_mu_; std::map in_memory_modules_ GUARDED_BY(in_memory_modules_mu_); @@ -309,7 +309,7 @@ class GpuExecutor : public internal::StreamExecutorInterface { gpu_binary_to_module_ GUARDED_BY(in_memory_modules_mu_); // Guards the launched kernel set. - mutex launched_kernels_mu_; + absl::Mutex launched_kernels_mu_; // Keeps track of the set of launched kernels. Currently used to suppress the // occupancy check on subsequent launches. diff --git a/tensorflow/stream_executor/gpu/gpu_rng.h b/tensorflow/stream_executor/gpu/gpu_rng.h index d4bf1e19630..f256e54b673 100644 --- a/tensorflow/stream_executor/gpu/gpu_rng.h +++ b/tensorflow/stream_executor/gpu/gpu_rng.h @@ -16,14 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_ #define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_ -#include "tensorflow/stream_executor/platform/mutex.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/stream_executor/gpu/gpu_types.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/rng.h" -#include "tensorflow/stream_executor/gpu/gpu_types.h" - namespace stream_executor { class Stream; @@ -83,8 +82,8 @@ class GpuRng : public rng::RngSupport { // with random number generation. bool SetStream(Stream* stream) EXCLUSIVE_LOCKS_REQUIRED(mu_); - // mutex that guards the gpu rng library handle for this device. - mutex mu_; + // Guards the gpu rng library handle for this device. + absl::Mutex mu_; // GpuExecutor which instantiated this GpuRng. // Immutable post-initialization. diff --git a/tensorflow/stream_executor/host/BUILD b/tensorflow/stream_executor/host/BUILD index 19902d86c2f..3abcd348e9b 100644 --- a/tensorflow/stream_executor/host/BUILD +++ b/tensorflow/stream_executor/host/BUILD @@ -52,6 +52,7 @@ cc_library( "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", ], alwayslink = True, # Registers itself with the MultiPlatformManager. ) @@ -67,6 +68,7 @@ cc_library( deps = [ "//tensorflow/stream_executor:kernel", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/stream_executor/host/host_platform.h b/tensorflow/stream_executor/host/host_platform.h index 61ed4c01fda..1a5c276ce67 100644 --- a/tensorflow/stream_executor/host/host_platform.h +++ b/tensorflow/stream_executor/host/host_platform.h @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc index bfbfb56cd79..7bb23f2ffa1 100644 --- a/tensorflow/stream_executor/host/host_stream.cc +++ b/tensorflow/stream_executor/host/host_stream.cc @@ -38,15 +38,15 @@ bool HostStream::EnqueueTask(std::function task) { // should guarantee that all tasks are destroyed. task = std::function(); { - mutex_lock lock(stream->mu_); + absl::MutexLock lock(&stream->mu_); --stream->pending_tasks_; } - stream->completion_condition_.notify_all(); + stream->completion_condition_.SignalAll(); } }; { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); ++pending_tasks_; } host_executor_->Schedule(NotifiedTask{this, std::move(task)}); @@ -54,9 +54,9 @@ bool HostStream::EnqueueTask(std::function task) { } void HostStream::BlockUntilDone() { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); while (pending_tasks_ != 0) { - completion_condition_.wait(lock); + completion_condition_.Wait(&mu_); } } diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h index be88f074cf6..a15173e5408 100644 --- a/tensorflow/stream_executor/host/host_stream.h +++ b/tensorflow/stream_executor/host/host_stream.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -45,9 +46,9 @@ class HostStream : public internal::StreamInterface { static const int kExecutorThreads = 1; std::unique_ptr host_executor_; - mutex mu_; + absl::Mutex mu_; int pending_tasks_ GUARDED_BY(mu_) = 0; - condition_variable completion_condition_; + absl::CondVar completion_condition_; }; } // namespace host diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc index 2e090af7169..d7e00205103 100644 --- a/tensorflow/stream_executor/kernel_spec.cc +++ b/tensorflow/stream_executor/kernel_spec.cc @@ -94,7 +94,7 @@ const char *CudaPtxInMemory::default_text() const { return nullptr; } - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); auto ptx = ptx_by_compute_capability_.begin()->second; // Check if there is an entry in decompressed ptx table. @@ -128,7 +128,7 @@ const char *CudaPtxInMemory::text(int compute_capability_major, return nullptr; } - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); // Check if there is an entry in decompressed ptx table. auto decompressed_ptx_iter = decompressed_ptx_.find(ptx_iter->second); diff --git a/tensorflow/stream_executor/kernel_spec.h b/tensorflow/stream_executor/kernel_spec.h index 04b2eab084c..7199f60e4ca 100644 --- a/tensorflow/stream_executor/kernel_spec.h +++ b/tensorflow/stream_executor/kernel_spec.h @@ -47,13 +47,13 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_ #include + #include #include -#include "tensorflow/stream_executor/platform/port.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/platform/logging.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" namespace stream_executor { @@ -206,7 +206,7 @@ class CudaPtxInMemory : public KernelLoaderSpec { // Stores all decompressed ptx strings, with original ptx string as keys. // It is marked as mutable for lazy decompression. mutable std::map decompressed_ptx_; - mutable mutex mu_; + mutable absl::Mutex mu_; // Defines the minimum compute capability possible. Used when PTX has no // compute capability specified (in the single-PTX constructor). diff --git a/tensorflow/stream_executor/platform/BUILD b/tensorflow/stream_executor/platform/BUILD index 702b2cdfe0d..ee093b4ac58 100644 --- a/tensorflow/stream_executor/platform/BUILD +++ b/tensorflow/stream_executor/platform/BUILD @@ -16,7 +16,6 @@ cc_library( name = "platform", textual_hdrs = [ "logging.h", - "mutex.h", "platform.h", "port.h", "thread_annotations.h", diff --git a/tensorflow/stream_executor/platform/default/BUILD b/tensorflow/stream_executor/platform/default/BUILD index f1ae7d86ff7..1be09d8b143 100644 --- a/tensorflow/stream_executor/platform/default/BUILD +++ b/tensorflow/stream_executor/platform/default/BUILD @@ -6,7 +6,6 @@ cc_library( name = "platform", textual_hdrs = [ "initialize.h", - "mutex.h", ], deps = ["//tensorflow/core:lib"], ) diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h index 4c868226ae5..d8d0af0aba2 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.h +++ b/tensorflow/stream_executor/platform/default/dso_loader.h @@ -20,13 +20,13 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_ #include -#include "tensorflow/stream_executor/platform/port.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform.h" -#include "tensorflow/stream_executor/platform/mutex.h" +#include "tensorflow/stream_executor/platform/port.h" namespace stream_executor { namespace internal { diff --git a/tensorflow/stream_executor/platform/default/mutex.h b/tensorflow/stream_executor/platform/default/mutex.h deleted file mode 100644 index 2f8f0636ba7..00000000000 --- a/tensorflow/stream_executor/platform/default/mutex.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_ -#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_ - -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" - -namespace stream_executor { - -#undef mutex_lock -#undef tf_shared_lock - -using tensorflow::ConditionResult; -using tensorflow::WaitForMilliseconds; -using tensorflow::condition_variable; -using tensorflow::mutex; -using tensorflow::mutex_lock; -using tensorflow::tf_shared_lock; - -#define mutex_lock(x) static_assert(0, "mutex_lock_decl_missing_var_name"); -#define tf_shared_lock(x) \ - static_assert(0, "tf_shared_lock_decl_missing_var_name"); - -} // namespace stream_executor - -#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_ diff --git a/tensorflow/stream_executor/platform/mutex.h b/tensorflow/stream_executor/platform/mutex.h deleted file mode 100644 index fa6c8c017c3..00000000000 --- a/tensorflow/stream_executor/platform/mutex.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_ -#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_ - -#include "tensorflow/stream_executor/platform/platform.h" - -#if defined(PLATFORM_GOOGLE) -#include "tensorflow/stream_executor/platform/google/mutex.h" -#else -#include "tensorflow/stream_executor/platform/default/mutex.h" -#endif - -#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_ diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h index 57ad965ef11..26c14bdf14f 100644 --- a/tensorflow/stream_executor/platform/port.h +++ b/tensorflow/stream_executor/platform/port.h @@ -19,7 +19,6 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_ #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace stream_executor { @@ -38,9 +37,6 @@ using tensorflow::uint64; using std::string; #endif -using tensorflow::LinkerInitialized; -using tensorflow::LINKER_INITIALIZED; - #define SE_FALLTHROUGH_INTENDED TF_FALLTHROUGH_INTENDED } // namespace stream_executor diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc index 53cbd22af13..1e6a2d4f2a9 100644 --- a/tensorflow/stream_executor/plugin_registry.cc +++ b/tensorflow/stream_executor/plugin_registry.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/stream_executor/plugin_registry.h" +#include "absl/base/const_init.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/multi_platform_manager.h" @@ -44,9 +46,9 @@ string PluginKindString(PluginKind plugin_kind) { PluginRegistry::DefaultFactories::DefaultFactories() : blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { } -static mutex& GetPluginRegistryMutex() { - static mutex* mu = new mutex; - return *mu; +static absl::Mutex& GetPluginRegistryMutex() { + static absl::Mutex mu(absl::kConstInit); + return mu; } /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr; @@ -54,7 +56,7 @@ static mutex& GetPluginRegistryMutex() { PluginRegistry::PluginRegistry() {} /* static */ PluginRegistry* PluginRegistry::Instance() { - mutex_lock lock{GetPluginRegistryMutex()}; + absl::MutexLock lock{&GetPluginRegistryMutex()}; if (instance_ == nullptr) { instance_ = new PluginRegistry(); } @@ -70,7 +72,7 @@ template port::Status PluginRegistry::RegisterFactoryInternal( PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory, std::map* factories) { - mutex_lock lock{GetPluginRegistryMutex()}; + absl::MutexLock lock{&GetPluginRegistryMutex()}; if (factories->find(plugin_id) != factories->end()) { return port::Status( diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h index 3065b5cb770..e0321169815 100644 --- a/tensorflow/stream_executor/plugin_registry.h +++ b/tensorflow/stream_executor/plugin_registry.h @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/rng.h" diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc index 5d5fba9d424..f3003e0a987 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -382,7 +382,7 @@ template bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, bool pointer_mode_host, bool err_on_failure, Args... args) { - mutex_lock lock{mu_}; + absl::MutexLock lock{&mu_}; CHECK(blas_ != nullptr); if (!SetStream(stream)) { diff --git a/tensorflow/stream_executor/rocm/rocm_blas.h b/tensorflow/stream_executor/rocm/rocm_blas.h index 75c68481cee..a9743314aa8 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.h +++ b/tensorflow/stream_executor/rocm/rocm_blas.h @@ -20,8 +20,8 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/blas.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -141,7 +141,7 @@ class ROCMBlas : public blas::BlasSupport { blas::ProfileResult *output_profile_result); // mutex that guards the rocBLAS handle for this device. - mutex mu_; + absl::Mutex mu_; // GpuExecutor which instantiated this ROCMBlas. // Immutable post-initialization. diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc index e1a2e453152..25a68810d23 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc @@ -98,8 +98,8 @@ class MIOpenHandle { public: // Takes ownership of the executor context and the lock to access MIOpen // using handle. - MIOpenHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock, - miopenHandle_t handle) + MIOpenHandle(gpu::ScopedActivateExecutorContext context, + std::unique_ptr lock, miopenHandle_t handle) : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {} // Returns MIOpen handle. To be passed directly to MIOpen APIs, don't keep @@ -108,7 +108,7 @@ class MIOpenHandle { private: gpu::ScopedActivateExecutorContext context_; - mutex_lock lock_; + std::unique_ptr lock_; miopenHandle_t handle_; // Not owned. }; @@ -318,7 +318,7 @@ class CachedFusionPlans { miopenFusionPlanDescriptor_t* fusion_plan, miopenFusionDirection_t fusion_direction, miopenTensorDescriptor_t input_descriptor) { - mutex_lock lock{cached_plans_mutex}; + absl::MutexLock lock{&cached_plans_mutex}; bool found_cached_plan = false; @@ -342,7 +342,7 @@ class CachedFusionPlans { // Need to figure out the right place to call this routine. static void Clear() { - mutex_lock lock{cached_plans_mutex}; + absl::MutexLock lock{&cached_plans_mutex}; for (auto it : cached_plans) { auto status = wrap::miopenDestroyFusionPlan(it.second); @@ -359,19 +359,19 @@ class CachedFusionPlans { // Is the Fusion plan corresponding to this hash unsupported. static bool IsUnsupportedFusionPlan(uint64 hash) { - mutex_lock lock{cached_plans_mutex}; + absl::MutexLock lock{&cached_plans_mutex}; return unsupported_plans.count(hash) > 0; } // Mark the given hash value as corresponding to an unsupported fusion plan. static void MarkFusionPlanUnsupported(uint64 hash) { - mutex_lock lock{cached_plans_mutex}; + absl::MutexLock lock{&cached_plans_mutex}; unsupported_plans.insert(hash); } private: // Mutex to guard access to all data within this class. - static mutex cached_plans_mutex; + static absl::Mutex cached_plans_mutex; // Map of hash-value to MIOpen Fusion plan descriptors. // Need to be able share this across more than one stream and hence static. @@ -382,7 +382,7 @@ class CachedFusionPlans { static std::set unsupported_plans; }; -mutex CachedFusionPlans::cached_plans_mutex; +absl::Mutex CachedFusionPlans::cached_plans_mutex; std::map CachedFusionPlans::cached_plans; std::set CachedFusionPlans::unsupported_plans; @@ -449,7 +449,7 @@ class MIOpenAccess { explicit MIOpenAccess(miopenHandle_t handle) : handle_(handle) {} ~MIOpenAccess() { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); wrap::miopenDestroy(handle_); } @@ -468,7 +468,8 @@ class MIOpenAccess { // therefore a bad idea (performance wise) to call any MIOpen APIs that // enqueue work in the stream. MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) { - mutex_lock lock(mutex_); + auto lock = absl::make_unique(&mutex_); + mutex_.AssertHeld(); gpu::ScopedActivateExecutorContext context(executor); hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr; auto status = wrap::miopenSetStream(handle_, hip_stream); @@ -478,7 +479,7 @@ class MIOpenAccess { private: // Guards the enqueueing of MIOpen operations via the handle_ below. - mutex mutex_; + absl::Mutex mutex_; // MIOpen library handle. miopenHandle_t handle_ GUARDED_BY(mutex_); // Owned. diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h index 643f9b3a73c..125acf7b474 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/stream_executor/rocm/rocm_dnn.h @@ -19,9 +19,9 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_ #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_ +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/dnn.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/temporary_device_memory.h" diff --git a/tensorflow/stream_executor/rocm/rocm_driver.cc b/tensorflow/stream_executor/rocm/rocm_driver.cc index 5f958e619db..8547c9f07f5 100644 --- a/tensorflow/stream_executor/rocm/rocm_driver.cc +++ b/tensorflow/stream_executor/rocm/rocm_driver.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "tensorflow/stream_executor/gpu/gpu_diagnostics.h" #include "tensorflow/stream_executor/gpu/gpu_driver.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/static_threadlocal.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/logging.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h" @@ -305,17 +305,10 @@ static port::Status InternalInit() { /* static */ port::Status GpuDriver::Init() { // Cached return value from calling InternalInit(), as hipInit need only be // called once, but GpuDriver::Init may be called many times. - static port::Status init_retval; - static bool set = false; - static mutex* init_mu = new mutex; - - mutex_lock lock(*init_mu); - if (!set) { - init_retval = InternalInit(); - set = true; - } - - return init_retval; + static port::Status* init_retval = [&] { + init_retval = new Status(InternalInit()); + }(); + return *init_retval; } /* static */ port::Status GpuDriver::GetDevice(int device_ordinal, diff --git a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc index b8940e04a5b..61eb19003ab 100644 --- a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc +++ b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc @@ -114,7 +114,7 @@ GpuExecutor::~GpuExecutor() { } bool GpuExecutor::UnloadModule(ModuleHandle module_handle) { const char* gpu_binary = reinterpret_cast(module_handle.id()); - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; return UnloadGpuBinary(gpu_binary); } @@ -236,7 +236,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, kernelname = &spec.cuda_cubin_in_memory().kernelname(); const char* hsaco = spec.cuda_cubin_in_memory().bytes(); - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; module = in_memory_modules_[hsaco]; if (module == nullptr) { @@ -294,7 +294,7 @@ bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, // whether we've done an occupancy check on this kernel before isn't free // (because we have to synchronize), so we only do this at -v 2+. if (VLOG_IS_ON(2)) { - mutex_lock lock(launched_kernels_mu_); + absl::MutexLock lock(&launched_kernels_mu_); if (!launched_kernels_.count(hipfunc)) { VlogOccupancyInfo(kernel, thread_dims, block_dims); // TODO(rspringer): Remove elements from launched_kernels_...if we ever @@ -366,7 +366,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, hipModule_t hip_module = nullptr; // TODO(ROCm): Need generic term instead of cubin/cuda/ptx if (spec.has_cuda_cubin_in_memory()) { - mutex_lock lock{in_memory_modules_mu_}; + absl::MutexLock lock{&in_memory_modules_mu_}; if (!LoadModuleFromHsaco( reinterpret_cast(spec.cuda_cubin_in_memory().data()), &hip_module)) { @@ -765,8 +765,8 @@ bool GpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const { bool GpuExecutor::GetSymbol(const string& symbol_name, ModuleHandle module_handle, void** mem, size_t* bytes) { - { // give limited scope to mutex_lock - mutex_lock lock{disk_modules_mu_}; + { // give limited scope to lock + absl::MutexLock lock{&disk_modules_mu_}; for (auto& it : disk_modules_) { if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(), reinterpret_cast(mem), @@ -776,8 +776,8 @@ bool GpuExecutor::GetSymbol(const string& symbol_name, } } - { // give limited scope to mutex_lock - mutex_lock lock{in_memory_modules_mu_}; + { // give limited scope to lock + absl::MutexLock lock{&in_memory_modules_mu_}; for (auto& it : in_memory_modules_) { if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(), reinterpret_cast(mem), @@ -787,8 +787,8 @@ bool GpuExecutor::GetSymbol(const string& symbol_name, } } - { // give limited scope to mutex_lock - mutex_lock lock{in_memory_modules_mu_}; + { // give limited scope to lock + absl::MutexLock lock{&in_memory_modules_mu_}; if (static_cast(module_handle)) { auto it = gpu_binary_to_module_.find(module_handle.id()); CHECK(it != gpu_binary_to_module_.end()); diff --git a/tensorflow/stream_executor/rocm/rocm_platform.cc b/tensorflow/stream_executor/rocm/rocm_platform.cc index fa286f971d5..156ec8de0fc 100644 --- a/tensorflow/stream_executor/rocm/rocm_platform.cc +++ b/tensorflow/stream_executor/rocm/rocm_platform.cc @@ -38,30 +38,25 @@ ROCmPlatform::~ROCmPlatform() {} void ROCmPlatform::InspectNumaNodes() { // To get NUMA node information, we need to create all executors, so we can // examine their device descriptions to see their bus assignments. - static bool initialized = false; - static mutex numa_mutex(LINKER_INITIALIZED); - mutex_lock lock(numa_mutex); - if (initialized) { - return; - } - - StreamExecutorConfig config; - for (int i = 0; i < VisibleDeviceCount(); i++) { - config.ordinal = i; - StreamExecutor* exec = GetExecutor(config).ValueOrDie(); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max(limit_numa_node_, - exec->GetDeviceDescription().numa_node() + 1); + std::once_flag once; + std::call_once(once, [&] { + StreamExecutorConfig config; + for (int i = 0; i < VisibleDeviceCount(); i++) { + config.ordinal = i; + StreamExecutor* exec = GetExecutor(config).ValueOrDie(); + if (i == 0) { + // NUMA nodes may not start at 0, so set the minimum node based on the + // first executor we see. + min_numa_node_ = exec->GetDeviceDescription().numa_node(); + limit_numa_node_ = min_numa_node_ + 1; + } else { + min_numa_node_ = + std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); + limit_numa_node_ = std::max( + limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); + } } - } - initialized = true; + }); } int ROCmPlatform::BusCount() { diff --git a/tensorflow/stream_executor/rocm/rocm_platform.h b/tensorflow/stream_executor/rocm/rocm_platform.h index ec9577cb301..c4838eebeb2 100644 --- a/tensorflow/stream_executor/rocm/rocm_platform.h +++ b/tensorflow/stream_executor/rocm/rocm_platform.h @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -90,7 +90,7 @@ class ROCmPlatform : public Platform { string name_; // mutex that guards internal state. - mutable mutex mu_; + mutable absl::Mutex mu_; // Cache of created executors. ExecutorCache executor_cache_; diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc index 99bfc49d10f..38f4f8bb0c6 100644 --- a/tensorflow/stream_executor/rocm/rocm_rng.cc +++ b/tensorflow/stream_executor/rocm/rocm_rng.cc @@ -126,7 +126,7 @@ GpuRng::~GpuRng() { } bool GpuRng::Init() { - mutex_lock lock{mu_}; + absl::MutexLock lock{&mu_}; CHECK(rng_ == nullptr); hiprandStatus_t ret = @@ -161,7 +161,7 @@ constexpr bool ComplexIsConsecutiveFloats() { template bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory* v) { - mutex_lock lock{mu_}; + absl::MutexLock lock{&mu_}; static_assert(ComplexIsConsecutiveFloats(), "std::complex values are not stored as consecutive values"); @@ -220,7 +220,7 @@ bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean, ElemT stddev, DeviceMemory* v, FuncT func) { - mutex_lock lock{mu_}; + absl::MutexLock lock{&mu_}; if (!SetStream(stream)) { return false; @@ -252,7 +252,7 @@ bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev, } bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) { - mutex_lock lock{mu_}; + absl::MutexLock lock{&mu_}; CHECK(rng_ != nullptr); if (!CheckSeed(seed, seed_bytes)) { diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 2577d3825fb..8f3021929bc 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -290,7 +290,7 @@ port::Status Stream::RefreshStatus() { Stream &Stream::Init() { VLOG_CALL(); - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); CHECK_EQ(false, allocated_) << "stream appears to already have been initialized"; CHECK(!ok_) << "stream should be in !ok() state pre-initialization"; @@ -1736,7 +1736,7 @@ Stream &Stream::ThenMemcpyH2DQuantized( } Stream *Stream::GetOrCreateSubStream() { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams // we encounter along the way. @@ -1783,7 +1783,7 @@ Stream *Stream::GetOrCreateSubStream() { } void Stream::ReturnSubStream(Stream *sub_stream) { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); // Look for the sub-stream. for (int64 index = 0; index < sub_streams_.size(); ++index) { @@ -5285,7 +5285,7 @@ void Stream::CheckStatus(port::Status status) { return; } LOG(ERROR) << status; - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); ok_ = false; } diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 3e67d55922d..cdd3464719b 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/synchronization/mutex.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/device_memory.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/stream_executor/kernel.h" #include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/lib/array_slice.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/temporary_memory_manager.h" @@ -1964,7 +1964,7 @@ class Stream { friend class ocl::CLBlas; // for parent_. bool InErrorState() const LOCKS_EXCLUDED(mu_) { - tf_shared_lock lock(mu_); + absl::ReaderMutexLock lock(&mu_); return !ok_; } @@ -1974,7 +1974,7 @@ class Stream { if (operation_retcode) { return; } - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); ok_ = false; } @@ -1998,7 +1998,7 @@ class Stream { // mutex that guards the allocation / error state flags. // Mutable so that it can be obtained via const reader lock. - mutable mutex mu_; + mutable absl::Mutex mu_; // Whether Init() was successfully called to allocate this stream on the // underlying platform. It simply flips from 0 to 1 with a sanity check. diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index e708643c9fe..97d77c04ea7 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/base/const_init.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/synchronization/notification.h" @@ -92,7 +93,7 @@ class ScopedTracer { void Trace(CallbackT callback, TraceArgsT... args) { { // Instance tracers held in a block to limit the lock lifetime. - tf_shared_lock lock{stream_exec_->mu_}; + absl::ReaderMutexLock lock{&stream_exec_->mu_}; for (TraceListener *listener : stream_exec_->listeners_) { (listener->*callback)(correlation_id_, std::forward(args)...); @@ -121,7 +122,7 @@ MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call, auto tracer = MakeScopedTracer(this, &LOC ## Begin, \ &LOC ## Complete, ## __VA_ARGS__); -/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED}; +/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit}; // Get per-device memory limit in bytes. Returns 0 if // TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set. @@ -217,7 +218,7 @@ void StreamExecutor::Deallocate(DeviceMemoryBase *mem) { } void StreamExecutor::GetMemAllocs(std::map *records_out) { - tf_shared_lock lock(mu_); + absl::ReaderMutexLock lock(&mu_); *records_out = mem_allocs_; } @@ -247,7 +248,7 @@ port::Status StreamExecutor::SetDeviceSharedMemoryConfig( } const DeviceDescription &StreamExecutor::GetDeviceDescription() const { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (device_description_ != nullptr) { return *device_description_; } @@ -393,7 +394,7 @@ StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size, } dnn::DnnSupport *StreamExecutor::AsDnn() { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (dnn_ != nullptr) { return dnn_.get(); } @@ -403,7 +404,7 @@ dnn::DnnSupport *StreamExecutor::AsDnn() { } blas::BlasSupport *StreamExecutor::AsBlas() { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (blas_ != nullptr) { return blas_.get(); } @@ -413,7 +414,7 @@ blas::BlasSupport *StreamExecutor::AsBlas() { } fft::FftSupport *StreamExecutor::AsFft() { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (fft_ != nullptr) { return fft_.get(); } @@ -423,7 +424,7 @@ fft::FftSupport *StreamExecutor::AsFft() { } rng::RngSupport *StreamExecutor::AsRng() { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (rng_ != nullptr) { return rng_.get(); } @@ -781,7 +782,7 @@ void StreamExecutor::EnqueueOnBackgroundThread(std::function task) { void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) { if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); mem_allocs_[opaque] = AllocRecord{ bytes, ""}; mem_alloc_bytes_ += bytes; @@ -790,7 +791,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) { void StreamExecutor::EraseAllocRecord(void *opaque) { if (FLAGS_check_device_leaks && opaque != nullptr) { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (mem_allocs_.find(opaque) == mem_allocs_.end()) { LOG(ERROR) << "Deallocating unknown pointer: " << opaque; } else { @@ -804,7 +805,7 @@ void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; } void StreamExecutor::RegisterTraceListener(TraceListener *listener) { { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (listeners_.find(listener) != listeners_.end()) { LOG(INFO) << "Attempt to register already-registered listener, " << listener; @@ -818,7 +819,7 @@ void StreamExecutor::RegisterTraceListener(TraceListener *listener) { bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) { { - mutex_lock lock(mu_); + absl::MutexLock lock(&mu_); if (listeners_.find(listener) == listeners_.end()) { LOG(INFO) << "Attempt to unregister unknown listener, " << listener; return false; @@ -839,7 +840,7 @@ void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) { if (tracing_enabled_) { { // instance tracers held in a block to limit the lock lifetime. - tf_shared_lock lock(mu_); + absl::ReaderMutexLock lock(&mu_); for (TraceListener *listener : listeners_) { (listener->*trace_call)(std::forward(args)...); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 1a3ac970f12..4c3aeda0db4 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -23,13 +23,13 @@ limitations under the License. #include #include "absl/base/macros.h" +#include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/logging.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/rng.h" @@ -619,13 +619,13 @@ class StreamExecutor { void SubmitTrace(TraceCallT trace_call, ArgsT&&... args); // Reader/writer lock for class-static StreamExecutor members. - static mutex static_mu_; + static absl::Mutex static_mu_; // Reader/writer lock for mutable data structures on this StreamExecutor. // // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.) // can acquire the lock on their first (mutating) call as well. - mutable mutex mu_; + mutable absl::Mutex mu_; // Reference to the platform that created this executor. const Platform *platform_; diff --git a/tensorflow/stream_executor/temporary_memory_manager.cc b/tensorflow/stream_executor/temporary_memory_manager.cc index f48e2b9c5e8..080e020499c 100644 --- a/tensorflow/stream_executor/temporary_memory_manager.cc +++ b/tensorflow/stream_executor/temporary_memory_manager.cc @@ -25,7 +25,7 @@ namespace stream_executor { namespace internal { void TemporaryMemoryManager::ForceDeallocateAll() { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); VLOG(1) << "force-deallocating " << records_.size() << " remaining records"; for (auto it = records_.begin(); it != records_.end(); ++it) { DeviceMemoryBase device_memory = it->first; @@ -35,7 +35,7 @@ void TemporaryMemoryManager::ForceDeallocateAll() { void TemporaryMemoryManager::MarkFinalized( const DeviceMemoryBase& device_memory, uint64 generation, bool must_exist) { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); auto it = records_.find(device_memory); if (it == records_.end()) { if (must_exist) { @@ -48,7 +48,7 @@ void TemporaryMemoryManager::MarkFinalized( } void TemporaryMemoryManager::DeallocateFinalizedTemporaries() { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); int deallocated_count = 0; for (auto it = records_.begin(); it != records_.end();) { if (it->second.finalized) { @@ -65,7 +65,7 @@ void TemporaryMemoryManager::DeallocateFinalizedTemporaries() { bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory, uint64 allocation_generation) const { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); auto it = records_.find(device_memory); if (it == records_.end()) { return true; // If there's no record present it's vacuously finalized. @@ -81,7 +81,7 @@ bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory, bool TemporaryMemoryManager::HasAllocated(const DeviceMemoryBase& device_memory, uint64 generation) const { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); auto it = records_.find(device_memory); if (it == records_.end()) { return false; @@ -106,7 +106,7 @@ TemporaryMemoryManager::AllocateArrayBase(uint64 element_count, // Add the record before instantiating the device memory instance so we can // check the allocation invariant at TemporaryDeviceMemory construction time. { - mutex_lock lock(mutex_); + absl::MutexLock lock(&mutex_); generation = ++generation_; DCHECK(records_.find(device_memory) == records_.end()); records_[device_memory] = {generation, diff --git a/tensorflow/stream_executor/temporary_memory_manager.h b/tensorflow/stream_executor/temporary_memory_manager.h index faf13380dc2..12d4d4b8de2 100644 --- a/tensorflow/stream_executor/temporary_memory_manager.h +++ b/tensorflow/stream_executor/temporary_memory_manager.h @@ -24,10 +24,10 @@ limitations under the License. #include #include +#include "absl/synchronization/mutex.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" -#include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/temporary_device_memory.h" @@ -108,7 +108,7 @@ class TemporaryMemoryManager { uint64 element_count, uint64 element_size); // Mutex to guard temporary record state. - mutable mutex mutex_; + mutable absl::Mutex mutex_; // Mapping from device memory to the current (live) temporary memory record. //