[SE] Use absl mutex instead of TF mutex.

Fixes a layering violation: stream_executor no longer depends on tensorflow/core/platform for its mutex type.

PiperOrigin-RevId: 245851033
Justin Lebar 2019-04-29 16:51:22 -07:00 committed by TensorFlower Gardener
parent 9904676451
commit 481748b5c9
45 changed files with 207 additions and 279 deletions

View File

@@ -93,6 +93,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -140,6 +141,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -152,6 +154,7 @@ cc_library(
     deps = [
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
@@ -199,6 +202,7 @@ cc_library(
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -231,6 +235,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "//third_party/eigen3",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
@@ -259,6 +264,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
@@ -295,6 +301,7 @@ cc_library(
         ":stream_header",
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
+        "@com_google_absl//absl/synchronization",
     ],
 )
@@ -311,6 +318,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
 )
@@ -399,6 +407,7 @@ cc_library(
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -440,10 +449,10 @@ cc_library(
         ":plugin",
         ":stream_executor_headers",
         "//tensorflow/stream_executor/lib",
-        "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
 )
@@ -536,6 +545,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -612,6 +622,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],

View File

@@ -59,9 +59,11 @@ cc_library(
         "//tensorflow/stream_executor/platform",
     ],
     ) + tf_additional_cuda_platform_deps() + [
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
     alwayslink = True,  # Registers itself with the MultiPlatformManager.
 )
@@ -112,6 +114,7 @@ cc_library(
         "//tensorflow:using_cuda_clang_with_dynamic_build": ["cudart_stub"],
         "//conditions:default": ["//tensorflow/core:cuda"],
     }) + [
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
@@ -223,7 +226,10 @@ cc_library(
     ] + if_static(
         ["@local_config_cuda//cuda:cublas"],
         [":cublas_stub"],
-    )) + ["@com_google_absl//absl/strings:str_format"],
+    )) + [
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+    ],
     alwayslink = True,
 )
@@ -306,7 +312,7 @@ cc_library(
     ]) + tf_additional_cudnn_plugin_deps() + if_cuda_is_configured(if_static(
         ["@local_config_cuda//cuda:cudnn"],
         [":cudnn_stub"],
-    )),
+    )) + ["@com_google_absl//absl/synchronization"],
     alwayslink = True,
 )

View File

@@ -402,7 +402,7 @@ template <typename FuncT, typename... Args>
 bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                   bool pointer_mode_host, bool err_on_failure,
                                   bool use_tensor_op_math, Args... args) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   CHECK(blas_ != nullptr);
   if (!SetStream(stream)) {

View File

@@ -20,9 +20,9 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/blas.h"
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
@@ -142,8 +142,8 @@ class CUDABlas : public blas::BlasSupport {
                   const T &beta, DeviceMemory<T> *y, int incy,
                   blas::ProfileResult *output_profile_result);
-  // mutex that guards the cuBLAS handle for this device.
-  mutex mu_;
+  // Guards the cuBLAS handle for this device.
+  absl::Mutex mu_;
   // GpuExecutor which instantiated this CUDABlas.
   // Immutable post-initialization.
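For context on the API swap above: absl::MutexLock takes a pointer to the mutex rather than binding it by reference the way tensorflow::mutex_lock does, which is why every lock site in this change gains an "&". Below is a minimal, hypothetical sketch (names invented, not from this commit) of the member-mutex pattern CUDABlas uses.

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

// Hypothetical class mirroring the CUDABlas locking pattern: a mutex member
// guards a handle-like resource, and every accessor locks it for its scope.
class CountedHandle {
 public:
  int Next() {
    absl::MutexLock lock(&mu_);  // Note: pointer, not reference.
    return ++value_;
  }

 private:
  // Guards value_ (ABSL_GUARDED_BY is absl's spelling of GUARDED_BY).
  absl::Mutex mu_;
  int value_ ABSL_GUARDED_BY(mu_) = 0;
};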

View File

@@ -133,8 +133,8 @@ class CudnnHandle {
  public:
   // Takes ownership of the executor context and the lock to access cuDNN
   // using handle.
-  CudnnHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock,
-              cudnnHandle_t handle)
+  CudnnHandle(gpu::ScopedActivateExecutorContext context,
+              std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
   // Returns cuDNN handle. To be passed directly to cuDNN APIs, don't keep
@@ -143,7 +143,7 @@ class CudnnHandle {
  private:
   gpu::ScopedActivateExecutorContext context_;
-  mutex_lock lock_;
+  std::unique_ptr<absl::MutexLock> lock_;
   cudnnHandle_t handle_;  // Not owned.
 };
@@ -160,7 +160,7 @@ class CudnnAccess {
   explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
   ~CudnnAccess() {
-    mutex_lock lock(mutex_);
+    absl::MutexLock lock(&mutex_);
     cudnnDestroy(handle_);
   }
@@ -181,7 +181,8 @@ class CudnnAccess {
   // therefore a bad idea (performance wise) to call any cuDNN APIs that
   // enqueue work in the stream.
   CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
-    mutex_lock lock(mutex_);
+    auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
+    mutex_.AssertHeld();
     gpu::ScopedActivateExecutorContext context(executor);
     CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
     const auto status = cudnnSetStream(handle_, cu_stream);
@@ -191,7 +192,7 @@ class CudnnAccess {
  private:
   // Guards the enqueueing of cuDNN operations via the handle_ below.
-  mutex mutex_;
+  absl::Mutex mutex_;
   // cuDNN library handle.
   cudnnHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
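A note on the hunk above: absl::MutexLock is neither copyable nor movable, so where the old code moved a tensorflow::mutex_lock into the handle object, the new code moves a std::unique_ptr<absl::MutexLock> instead, and calls AssertHeld() so Clang's thread-safety analysis still knows the mutex is held. A minimal sketch of the same idea, with invented names:

#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"

// Hypothetical sketch of the pattern used by CudnnHandle/MIOpenHandle above:
// the held lock travels with the returned object behind a unique_ptr and is
// released when that object is destroyed.
class LockedValue {
 public:
  LockedValue(std::unique_ptr<absl::MutexLock> lock, int* value)
      : lock_(std::move(lock)), value_(value) {}
  int* get() const { return value_; }

 private:
  std::unique_ptr<absl::MutexLock> lock_;  // Released when LockedValue dies.
  int* value_;                             // Not owned.
};

class ValueAccess {
 public:
  LockedValue Acquire() {
    auto lock = absl::make_unique<absl::MutexLock>(&mu_);
    mu_.AssertHeld();  // Tell thread-safety analysis the lock is really held.
    return LockedValue(std::move(lock), &value_);
  }

 private:
  absl::Mutex mu_;
  int value_ = 0;
};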

View File

@@ -22,7 +22,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #include "tensorflow/stream_executor/dnn.h"
 #include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/temporary_device_memory.h"

View File

@@ -23,11 +23,13 @@ limitations under the License.
 #include <utility>
 #include "absl/base/casts.h"
+#include "absl/base/const_init.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/debugging/leak_check.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
 #include "cuda/include/cuda_runtime_api.h"
 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@@ -38,7 +40,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/static_threadlocal.h"
 #include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/platform/logging.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
@@ -65,14 +66,14 @@ class CreatedContexts {
  public:
   // Returns whether context is a member of the live set.
   static bool Has(CUcontext context) {
-    tf_shared_lock lock(mu_);
+    absl::ReaderMutexLock lock(&mu_);
     return Live()->find(context) != Live()->end();
   }
   // Adds context to the live set, or returns it if it's already present.
   static GpuContext* Add(CUcontext context) {
     CHECK(context != nullptr);
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     auto insert_result = Live()->insert(std::make_pair(context, nullptr));
     auto it = insert_result.first;
     if (insert_result.second) {
@@ -85,7 +86,7 @@ class CreatedContexts {
   // Removes context from the live set.
   static void Remove(CUcontext context) {
     CHECK(context != nullptr);
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     auto it = Live()->find(context);
     CHECK(it != Live()->end()) << context;
     Live()->erase(it);
@@ -100,11 +101,11 @@ class CreatedContexts {
   }
   // Lock that guards access-to/mutation-of the live set.
-  static mutex mu_;
+  static absl::Mutex mu_;
   static int64 next_id_;
 };
-/* static */ mutex CreatedContexts::mu_{LINKER_INITIALIZED};
+/* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit};
 /* static */ int64 CreatedContexts::next_id_ = 1;  // 0 means "no context"
 // Formats CUresult to output prettified values into a log stream.
@@ -322,17 +323,10 @@ static port::Status InternalInit() {
 /* static */ port::Status GpuDriver::Init() {
   // Cached return value from calling InternalInit(), as cuInit need only be
   // called once, but GpuDriver::Init may be called many times.
-  static port::Status init_retval;
-  static bool set = false;
-  static mutex* init_mu = new mutex;
-  mutex_lock lock(*init_mu);
-  if (!set) {
-    init_retval = InternalInit();
-    set = true;
-  }
-  return init_retval;
+  static port::Status* init_retval = [] {
+    return new port::Status(InternalInit());
+  }();
+  return *init_retval;
 }
 /* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
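The GpuDriver::Init() rewrite above replaces a hand-rolled, mutex-guarded "run once" with a function-local static initialized by an immediately invoked lambda: C++11 guarantees that initialization is thread-safe, and heap-allocating the cached Status avoids running a destructor at program exit. A minimal sketch of the same idiom, with invented names:

#include <string>

// Stand-in for an expensive, run-once initialization (hypothetical).
std::string ExpensiveInit() { return "initialized"; }

// The function-local static is initialized exactly once, even under
// concurrent callers, so no explicit mutex is needed; `new` keeps the value
// alive (and destructor-free) for the rest of the program.
const std::string& CachedInit() {
  static std::string* result = [] { return new std::string(ExpensiveInit()); }();
  return *result;
}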

View File

@@ -272,7 +272,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
   VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();
   if (spec.has_cuda_cubin_in_memory()) {
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     kernelname = &spec.cuda_cubin_in_memory().kernelname();
     const char *cubin = spec.cuda_cubin_in_memory().bytes();
     if (!LoadModuleFromCuBin(cubin, &module)) {
@@ -295,7 +295,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
       return false;
     }
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (!LoadModuleFromPtx(ptx, &module)) {
       return false;
     }
@@ -343,7 +343,7 @@ bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) {
 void GpuExecutor::UnloadKernel(const KernelBase* kernel) {
   VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name();
-  mutex_lock lock{in_memory_modules_mu_};
+  absl::MutexLock lock{&in_memory_modules_mu_};
   auto gpu_binary_it = kernel_to_gpu_binary_.find(kernel);
   if (kernel_to_gpu_binary_.end() == gpu_binary_it) {
     VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
@@ -362,7 +362,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
   // ModuleHandle::id().
   CUmodule cu_module;
   if (spec.has_cuda_cubin_in_memory()) {
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (!LoadModuleFromCuBin(
             reinterpret_cast<const char *>(spec.cuda_cubin_in_memory().data()),
             &cu_module)) {
@@ -380,7 +380,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
     return false;
   }
-  mutex_lock lock{in_memory_modules_mu_};
+  absl::MutexLock lock{&in_memory_modules_mu_};
   if (!LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)) {
     return false;
   }
@@ -394,7 +394,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
 bool GpuExecutor::UnloadModule(ModuleHandle module_handle) {
   const char *gpu_binary = reinterpret_cast<const char *>(module_handle.id());
-  mutex_lock lock{in_memory_modules_mu_};
+  absl::MutexLock lock{&in_memory_modules_mu_};
   return UnloadGpuBinary(gpu_binary);
 }
@@ -428,7 +428,7 @@ bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
   // whether we've done an occupancy check on this kernel before isn't free
   // (because we have to synchronize), so we only do this at -v 2+.
   if (VLOG_IS_ON(2)) {
-    mutex_lock lock(launched_kernels_mu_);
+    absl::MutexLock lock(&launched_kernels_mu_);
     if (!launched_kernels_.count(cufunc)) {
       VlogOccupancyInfo(kernel, thread_dims, block_dims);
       // TODO(rspringer): Remove elements from launched_kernels_...if we ever
@@ -892,7 +892,7 @@ bool GpuExecutor::GetSymbol(const string& symbol_name,
   };
   {  // give limited scope to mutex_lock
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (static_cast<bool>(module_handle)) {
       auto it = gpu_binary_to_module_.find(module_handle.id());
       CHECK(it != gpu_binary_to_module_.end());
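Several call sites above keep the existing brace-scoped locking ("give limited scope to mutex_lock") so the mutex is released before slower work runs. An illustrative sketch of that idiom with invented names, not taken from the commit:

#include <map>
#include <string>

#include "absl/synchronization/mutex.h"

absl::Mutex table_mu;
std::map<std::string, int> table;  // Guarded by table_mu.

// Take the lock only around the map lookup; the braced block releases it
// before the (potentially expensive) computation, then relock to store.
int LookupOrCompute(const std::string& key) {
  {
    absl::MutexLock lock(&table_mu);  // Held only inside this block.
    auto it = table.find(key);
    if (it != table.end()) return it->second;
  }
  int value = static_cast<int>(key.size());  // Stand-in for expensive work.
  absl::MutexLock lock(&table_mu);
  table[key] = value;
  return value;
}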

View File

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "absl/base/const_init.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -75,13 +76,8 @@ CudaPlatform::~CudaPlatform() {}
 void CudaPlatform::InspectNumaNodes() {
   // To get NUMA node information, we need to create all executors, so we can
   // examine their device descriptions to see their bus assignments.
-  static bool initialized = false;
-  static mutex numa_mutex(LINKER_INITIALIZED);
-  mutex_lock lock(numa_mutex);
-  if (initialized) {
-    return;
-  }
+  static std::once_flag once;
+  std::call_once(once, [&] {
   StreamExecutorConfig config;
   for (int i = 0; i < VisibleDeviceCount(); i++) {
     config.ordinal = i;
@@ -94,11 +90,11 @@ void CudaPlatform::InspectNumaNodes() {
     } else {
       min_numa_node_ =
           std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
-      limit_numa_node_ = std::max(limit_numa_node_,
-                                  exec->GetDeviceDescription().numa_node() + 1);
+      limit_numa_node_ = std::max(
+          limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
     }
   }
-  initialized = true;
+  });
 }
 int CudaPlatform::BusCount() {
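InspectNumaNodes() now relies on std::call_once instead of a "static bool initialized" flag protected by a mutex. A minimal sketch of the idiom with an invented function name:

#include <mutex>

// The lambda body runs at most once, even if several threads call
// DetectTopologyOnce() concurrently; later callers block until it finishes.
void DetectTopologyOnce() {
  static std::once_flag once;
  std::call_once(once, [] {
    // One-time, possibly expensive detection work goes here.
  });
}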

View File

@@ -24,7 +24,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/multi_platform_manager.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"

View File

@@ -71,7 +71,7 @@ GpuRng::~GpuRng() {
 }
 bool GpuRng::Init() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   CHECK(rng_ == nullptr);
   cuda::ScopedActivateExecutorContext sac(parent_);
@@ -106,7 +106,7 @@ constexpr bool ComplexIsConsecutiveFloats() {
 template <typename T>
 bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   static_assert(ComplexIsConsecutiveFloats(),
                 "std::complex values are not stored as consecutive values");
@@ -164,7 +164,7 @@ bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean,
                                             ElemT stddev,
                                             DeviceMemory<ElemT>* v,
                                             FuncT func) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   if (!SetStream(stream)) {
     return false;
@@ -197,7 +197,7 @@ bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
 }
 bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   CHECK(rng_ != nullptr);
   if (!CheckSeed(seed, seed_bytes)) {

View File

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/executor_cache.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 namespace stream_executor {
@@ -32,7 +33,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
   Entry* entry = nullptr;
   {
-    mutex_lock lock{mutex_};
+    absl::MutexLock lock{&mutex_};
     entry = &cache_[config.ordinal];
     // Release the map lock; the address of 'entry' is stable because
     // std::map guarantees reference stability.
@@ -41,7 +42,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
   // Acquire the per-Entry mutex without holding the map mutex. Initializing
   // an Executor may be expensive, so we want to allow concurrent
   // initialization of different entries.
-  mutex_lock lock{entry->configurations_mutex};
+  absl::MutexLock lock{&entry->configurations_mutex};
   for (const auto& iter : entry->configurations) {
     if (iter.first.plugin_config == config.plugin_config &&
         iter.first.device_options == config.device_options) {
@@ -66,7 +67,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::Get(
     const StreamExecutorConfig& config) {
   Entry* entry = nullptr;
   {
-    tf_shared_lock lock{mutex_};
+    absl::ReaderMutexLock lock{&mutex_};
     auto it = cache_.find(config.ordinal);
     if (it != cache_.end()) {
       entry = &it->second;
@@ -77,7 +78,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::Get(
                            config.ordinal));
     }
   }
-  tf_shared_lock lock{entry->configurations_mutex};
+  absl::ReaderMutexLock lock{&entry->configurations_mutex};
   if (entry->configurations.empty()) {
     return port::Status(
         port::error::NOT_FOUND,
@@ -96,12 +97,12 @@ port::StatusOr<StreamExecutor*> ExecutorCache::Get(
 }
 void ExecutorCache::DestroyAllExecutors() {
-  mutex_lock lock{mutex_};
+  absl::MutexLock lock{&mutex_};
   cache_.clear();
 }
 ExecutorCache::Entry::~Entry() {
-  mutex_lock lock{configurations_mutex};
+  absl::MutexLock lock{&configurations_mutex};
   configurations.clear();
 }
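ExecutorCache keeps the old read/write distinction: tf_shared_lock becomes absl::ReaderMutexLock (shared, many concurrent readers) and mutex_lock becomes absl::MutexLock (exclusive writer). A small, hypothetical cache showing the two lock types side by side; names are invented for illustration:

#include <map>
#include <string>
#include <utility>

#include "absl/synchronization/mutex.h"

class StringCache {
 public:
  bool Lookup(const std::string& key, std::string* value) const {
    absl::ReaderMutexLock lock(&mu_);  // Shared: readers don't block readers.
    auto it = entries_.find(key);
    if (it == entries_.end()) return false;
    *value = it->second;
    return true;
  }

  void Insert(const std::string& key, std::string value) {
    absl::MutexLock lock(&mu_);  // Exclusive: blocks readers and writers.
    entries_[key] = std::move(value);
  }

 private:
  mutable absl::Mutex mu_;
  std::map<std::string, std::string> entries_;
};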

View File

@@ -19,9 +19,9 @@ limitations under the License.
 #include <functional>
 #include <map>
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 namespace stream_executor {
@@ -54,11 +54,11 @@ class ExecutorCache {
   struct Entry {
     ~Entry();
-    // Mutex that locks the contents of each entry. The 'mutex_' of the
+    // Mutex that guards the contents of each entry. The 'mutex_' of the
     // ExecutorCache class protects both the 'cache_' and the existence of each
     // Entry, but not the Entry's contents. 'configurations_mutex' protects the
     // contents of the entry after 'mutex_' has been dropped.
-    mutex configurations_mutex;
+    absl::Mutex configurations_mutex;
     // Vector of cached {config, executor} pairs.
     std::vector<
@@ -69,7 +69,7 @@ class ExecutorCache {
   // Maps ordinal number to a list of cached executors for that ordinal.
   // We key off of ordinal (instead of just looking up all fields in the
   // StreamExecutorConfig) for a slight improvement in lookup time.
-  mutex mutex_;
+  absl::Mutex mutex_;
   std::map<int, Entry> cache_ GUARDED_BY(mutex_);
   SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);

View File

@@ -98,6 +98,7 @@ cc_library(
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
@@ -126,6 +127,7 @@ cc_library(
         "//tensorflow/stream_executor:plugin_registry",
         "//tensorflow/stream_executor:rng",
         "//tensorflow/stream_executor/platform",
+        "@com_google_absl//absl/synchronization",
     ],
 )

View File

@@ -26,12 +26,12 @@ limitations under the License.
 #include <unordered_map>
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/event.h"
 #include "tensorflow/stream_executor/gpu/gpu_kernel.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
@@ -287,7 +287,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
       EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
   // Guards the on-disk-module mapping.
-  mutex disk_modules_mu_;
+  absl::Mutex disk_modules_mu_;
   // Mapping from filename to GPUModuleHandle, if it was already retrieved.
   // Multiple GPUFunctionHandle are usually obtained from a single
@@ -296,7 +296,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
   std::map<string, GpuModuleHandle> disk_modules_ GUARDED_BY(disk_modules_mu_);
   // Guards the in-memory-module mapping.
-  mutex in_memory_modules_mu_;
+  absl::Mutex in_memory_modules_mu_;
   std::map<const char*, GpuModuleHandle> in_memory_modules_
       GUARDED_BY(in_memory_modules_mu_);
@@ -309,7 +309,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
       gpu_binary_to_module_ GUARDED_BY(in_memory_modules_mu_);
   // Guards the launched kernel set.
-  mutex launched_kernels_mu_;
+  absl::Mutex launched_kernels_mu_;
   // Keeps track of the set of launched kernels. Currently used to suppress the
   // occupancy check on subsequent launches.
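The GUARDED_BY and EXCLUSIVE_LOCKS_REQUIRED annotations in these headers keep working after the swap because Clang's thread-safety analysis understands absl::Mutex natively. An illustrative sketch (class and names invented, not from the commit) using absl's own ABSL_-prefixed spellings of the same macros:

#include <map>
#include <string>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

class ModuleTable {
 public:
  void Register(const std::string& name, void* module) {
    absl::MutexLock lock(&mu_);
    RegisterLocked(name, module);
  }

 private:
  // Callers must already hold mu_; the analysis enforces this at compile time.
  void RegisterLocked(const std::string& name, void* module)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    modules_[name] = module;
  }

  absl::Mutex mu_;
  std::map<std::string, void*> modules_ ABSL_GUARDED_BY(mu_);
};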

View File

@@ -16,14 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_
 #define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_
-#include "tensorflow/stream_executor/platform/mutex.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/stream_executor/gpu/gpu_types.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/rng.h"
-#include "tensorflow/stream_executor/gpu/gpu_types.h"
 namespace stream_executor {
 class Stream;
@@ -83,8 +82,8 @@ class GpuRng : public rng::RngSupport {
   // with random number generation.
   bool SetStream(Stream* stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
-  // mutex that guards the gpu rng library handle for this device.
-  mutex mu_;
+  // Guards the gpu rng library handle for this device.
+  absl::Mutex mu_;
   // GpuExecutor which instantiated this GpuRng.
   // Immutable post-initialization.

View File

@@ -52,6 +52,7 @@ cc_library(
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
     alwayslink = True,  # Registers itself with the MultiPlatformManager.
 )
@@ -67,6 +68,7 @@ cc_library(
     deps = [
         "//tensorflow/stream_executor:kernel",
         "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/synchronization",
     ],
 )

View File

@@ -27,7 +27,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/multi_platform_manager.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"

View File

@@ -38,15 +38,15 @@ bool HostStream::EnqueueTask(std::function<void()> task) {
       // should guarantee that all tasks are destroyed.
       task = std::function<void()>();
       {
-        mutex_lock lock(stream->mu_);
+        absl::MutexLock lock(&stream->mu_);
         --stream->pending_tasks_;
       }
-      stream->completion_condition_.notify_all();
+      stream->completion_condition_.SignalAll();
     }
   };
   {
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     ++pending_tasks_;
   }
   host_executor_->Schedule(NotifiedTask{this, std::move(task)});
@@ -54,9 +54,9 @@ bool HostStream::EnqueueTask(std::function<void()> task) {
 }
 void HostStream::BlockUntilDone() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   while (pending_tasks_ != 0) {
-    completion_condition_.wait(lock);
+    completion_condition_.Wait(&mu_);
   }
 }
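In the hunk above, absl::CondVar::Wait takes the mutex pointer directly, releases it while blocked, and reacquires it before returning, so the usual predicate loop around the wait stays; Signal/SignalAll replace notify_one/notify_all. A minimal sketch (invented names) of the same counter-plus-condition-variable shape used by HostStream:

#include "absl/synchronization/mutex.h"

class PendingCounter {
 public:
  void Add() {
    absl::MutexLock lock(&mu_);
    ++pending_;
  }

  void Done() {
    {
      absl::MutexLock lock(&mu_);
      --pending_;
    }
    cv_.SignalAll();  // Equivalent of notify_all().
  }

  void WaitUntilZero() {
    absl::MutexLock lock(&mu_);
    while (pending_ != 0) {
      cv_.Wait(&mu_);  // Re-checks the predicate after each wakeup.
    }
  }

 private:
  absl::Mutex mu_;
  absl::CondVar cv_;
  int pending_ = 0;
};

absl::Mutex could also express this without a CondVar via mu_.Await(absl::Condition(...)), but the commit keeps the structure of the original code.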

View File

@@ -21,6 +21,7 @@ limitations under the License.
 #include <functional>
 #include <memory>
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
@@ -45,9 +46,9 @@ class HostStream : public internal::StreamInterface {
   static const int kExecutorThreads = 1;
   std::unique_ptr<port::ThreadPool> host_executor_;
-  mutex mu_;
+  absl::Mutex mu_;
   int pending_tasks_ GUARDED_BY(mu_) = 0;
-  condition_variable completion_condition_;
+  absl::CondVar completion_condition_;
 };
 }  // namespace host

View File

@@ -94,7 +94,7 @@ const char *CudaPtxInMemory::default_text() const {
     return nullptr;
   }
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   auto ptx = ptx_by_compute_capability_.begin()->second;
   // Check if there is an entry in decompressed ptx table.
@@ -128,7 +128,7 @@ const char *CudaPtxInMemory::text(int compute_capability_major,
     return nullptr;
   }
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   // Check if there is an entry in decompressed ptx table.
   auto decompressed_ptx_iter = decompressed_ptx_.find(ptx_iter->second);

View File

@@ -47,13 +47,13 @@ limitations under the License.
 #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_
 #include <stddef.h>
 #include <map>
 #include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/platform/logging.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 namespace stream_executor {
@@ -206,7 +206,7 @@ class CudaPtxInMemory : public KernelLoaderSpec {
   // Stores all decompressed ptx strings, with original ptx string as keys.
   // It is marked as mutable for lazy decompression.
   mutable std::map<const char *, string> decompressed_ptx_;
-  mutable mutex mu_;
+  mutable absl::Mutex mu_;
   // Defines the minimum compute capability possible. Used when PTX has no
   // compute capability specified (in the single-PTX constructor).
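The mutex above is declared mutable so that logically-const accessors such as CudaPtxInMemory::default_text() const can still lock it while lazily filling the decompression cache. A minimal sketch of that pattern with invented names:

#include <cctype>
#include <string>
#include <utility>

#include "absl/synchronization/mutex.h"

// A const accessor may lock mu_ and populate the cache because both the
// mutex and the cached value are declared mutable.
class LazyUpper {
 public:
  explicit LazyUpper(std::string text) : text_(std::move(text)) {}

  const std::string& upper() const {
    absl::MutexLock lock(&mu_);
    if (cached_.empty()) {
      cached_ = text_;
      for (char& c : cached_) {
        c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
      }
    }
    return cached_;
  }

 private:
  std::string text_;
  mutable absl::Mutex mu_;
  mutable std::string cached_;
};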

View File

@@ -16,7 +16,6 @@ cc_library(
     name = "platform",
     textual_hdrs = [
         "logging.h",
-        "mutex.h",
         "platform.h",
         "port.h",
         "thread_annotations.h",

View File

@@ -6,7 +6,6 @@ cc_library(
     name = "platform",
     textual_hdrs = [
         "initialize.h",
-        "mutex.h",
     ],
     deps = ["//tensorflow/core:lib"],
 )

View File

@@ -20,13 +20,13 @@ limitations under the License.
 #define TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_
 #include <vector>
-#include "tensorflow/stream_executor/platform/port.h"
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
 namespace stream_executor {
 namespace internal {

View File

@ -1,40 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace stream_executor {
#undef mutex_lock
#undef tf_shared_lock
using tensorflow::ConditionResult;
using tensorflow::WaitForMilliseconds;
using tensorflow::condition_variable;
using tensorflow::mutex;
using tensorflow::mutex_lock;
using tensorflow::tf_shared_lock;
#define mutex_lock(x) static_assert(0, "mutex_lock_decl_missing_var_name");
#define tf_shared_lock(x) \
static_assert(0, "tf_shared_lock_decl_missing_var_name");
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_

View File

@ -1,27 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
#include "tensorflow/stream_executor/platform/platform.h"
#if defined(PLATFORM_GOOGLE)
#include "tensorflow/stream_executor/platform/google/mutex.h"
#else
#include "tensorflow/stream_executor/platform/default/mutex.h"
#endif
#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_

View File

@@ -19,7 +19,6 @@ limitations under the License.
 #define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
 #include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
 namespace stream_executor {
@@ -38,9 +37,6 @@ using tensorflow::uint64;
 using std::string;
 #endif
-using tensorflow::LinkerInitialized;
-using tensorflow::LINKER_INITIALIZED;
 #define SE_FALLTHROUGH_INTENDED TF_FALLTHROUGH_INTENDED
 }  // namespace stream_executor

View File

@@ -15,8 +15,10 @@ limitations under the License.
 #include "tensorflow/stream_executor/plugin_registry.h"
+#include "absl/base/const_init.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/lib/error.h"
 #include "tensorflow/stream_executor/multi_platform_manager.h"
@@ -44,9 +46,9 @@ string PluginKindString(PluginKind plugin_kind) {
 PluginRegistry::DefaultFactories::DefaultFactories() :
     blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
-static mutex& GetPluginRegistryMutex() {
-  static mutex* mu = new mutex;
-  return *mu;
+static absl::Mutex& GetPluginRegistryMutex() {
+  static absl::Mutex mu(absl::kConstInit);
+  return mu;
 }
 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
@@ -54,7 +56,7 @@ static mutex& GetPluginRegistryMutex() {
 PluginRegistry::PluginRegistry() {}
 /* static */ PluginRegistry* PluginRegistry::Instance() {
-  mutex_lock lock{GetPluginRegistryMutex()};
+  absl::MutexLock lock{&GetPluginRegistryMutex()};
   if (instance_ == nullptr) {
     instance_ = new PluginRegistry();
   }
@@ -70,7 +72,7 @@ template <typename FACTORY_TYPE>
 port::Status PluginRegistry::RegisterFactoryInternal(
     PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
     std::map<PluginId, FACTORY_TYPE>* factories) {
-  mutex_lock lock{GetPluginRegistryMutex()};
+  absl::MutexLock lock{&GetPluginRegistryMutex()};
   if (factories->find(plugin_id) != factories->end()) {
     return port::Status(
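absl::kConstInit (used above and for CreatedContexts::mu_ in cuda_driver.cc) requests constant initialization: the mutex is usable before main(), during static initialization of other globals, and it intentionally never runs a destructor. A small sketch of the two global-mutex idioms adopted by the commit, with invented names:

#include "absl/base/const_init.h"
#include "absl/synchronization/mutex.h"

// Namespace-scope global: constant-initialized, never destroyed.
static absl::Mutex g_registry_mu(absl::kConstInit);

// Function-local alternative, as in GetPluginRegistryMutex() above.
static absl::Mutex& RegistryMutex() {
  static absl::Mutex mu(absl::kConstInit);
  return mu;
}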

View File

@@ -25,7 +25,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/plugin.h"
 #include "tensorflow/stream_executor/rng.h"

View File

@@ -382,7 +382,7 @@ template <typename FuncT, typename... Args>
 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
                                   bool pointer_mode_host, bool err_on_failure,
                                   Args... args) {
-  mutex_lock lock{mu_};
+  absl::MutexLock lock{&mu_};
   CHECK(blas_ != nullptr);
   if (!SetStream(stream)) {

View File

@@ -20,8 +20,8 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/blas.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
@@ -141,7 +141,7 @@ class ROCMBlas : public blas::BlasSupport {
                   blas::ProfileResult *output_profile_result);
   // mutex that guards the rocBLAS handle for this device.
-  mutex mu_;
+  absl::Mutex mu_;
   // GpuExecutor which instantiated this ROCMBlas.
   // Immutable post-initialization.

View File

@@ -98,8 +98,8 @@ class MIOpenHandle {
  public:
   // Takes ownership of the executor context and the lock to access MIOpen
   // using handle.
-  MIOpenHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock,
-               miopenHandle_t handle)
+  MIOpenHandle(gpu::ScopedActivateExecutorContext context,
+               std::unique_ptr<absl::MutexLock> lock, miopenHandle_t handle)
       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
   // Returns MIOpen handle. To be passed directly to MIOpen APIs, don't keep
@@ -108,7 +108,7 @@ class MIOpenHandle {
  private:
   gpu::ScopedActivateExecutorContext context_;
-  mutex_lock lock_;
+  std::unique_ptr<absl::MutexLock> lock_;
   miopenHandle_t handle_;  // Not owned.
 };
@@ -318,7 +318,7 @@ class CachedFusionPlans {
                              miopenFusionPlanDescriptor_t* fusion_plan,
                              miopenFusionDirection_t fusion_direction,
                              miopenTensorDescriptor_t input_descriptor) {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
     bool found_cached_plan = false;
@@ -342,7 +342,7 @@ class CachedFusionPlans {
   // Need to figure out the right place to call this routine.
   static void Clear() {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
     for (auto it : cached_plans) {
       auto status = wrap::miopenDestroyFusionPlan(it.second);
@@ -359,19 +359,19 @@ class CachedFusionPlans {
   // Is the Fusion plan corresponding to this hash unsupported.
   static bool IsUnsupportedFusionPlan(uint64 hash) {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
     return unsupported_plans.count(hash) > 0;
   }
   // Mark the given hash value as corresponding to an unsupported fusion plan.
   static void MarkFusionPlanUnsupported(uint64 hash) {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
     unsupported_plans.insert(hash);
   }
  private:
   // Mutex to guard access to all data within this class.
-  static mutex cached_plans_mutex;
+  static absl::Mutex cached_plans_mutex;
   // Map of hash-value to MIOpen Fusion plan descriptors.
   // Need to be able share this across more than one stream and hence static.
@@ -382,7 +382,7 @@ class CachedFusionPlans {
   static std::set<uint64> unsupported_plans;
 };
-mutex CachedFusionPlans::cached_plans_mutex;
+absl::Mutex CachedFusionPlans::cached_plans_mutex;
 std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
 std::set<uint64> CachedFusionPlans::unsupported_plans;
@@ -449,7 +449,7 @@ class MIOpenAccess {
   explicit MIOpenAccess(miopenHandle_t handle) : handle_(handle) {}
   ~MIOpenAccess() {
-    mutex_lock lock(mutex_);
+    absl::MutexLock lock(&mutex_);
     wrap::miopenDestroy(handle_);
   }
@@ -468,7 +468,8 @@ class MIOpenAccess {
   // therefore a bad idea (performance wise) to call any MIOpen APIs that
   // enqueue work in the stream.
   MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) {
-    mutex_lock lock(mutex_);
+    auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
+    mutex_.AssertHeld();
     gpu::ScopedActivateExecutorContext context(executor);
     hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr;
     auto status = wrap::miopenSetStream(handle_, hip_stream);
@@ -478,7 +479,7 @@ class MIOpenAccess {
  private:
   // Guards the enqueueing of MIOpen operations via the handle_ below.
-  mutex mutex_;
+  absl::Mutex mutex_;
   // MIOpen library handle.
   miopenHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
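CachedFusionPlans above keeps a single static absl::Mutex shared by all callers, with one out-of-class definition at namespace scope. A minimal sketch of that static-member-mutex pattern with invented names, not taken from the commit:

#include <set>

#include "absl/synchronization/mutex.h"

class SeenIds {
 public:
  // Returns true if id was newly inserted; all callers share mu_ and ids_.
  static bool Insert(int id) {
    absl::MutexLock lock(&mu_);
    return ids_.insert(id).second;
  }

 private:
  static absl::Mutex mu_;
  static std::set<int> ids_;
};

// Static members need exactly one definition outside the class.
absl::Mutex SeenIds::mu_;
std::set<int> SeenIds::ids_;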

View File

@@ -19,9 +19,9 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/dnn.h"
 #include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/temporary_device_memory.h"


@ -24,6 +24,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h" #include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h" #include "absl/synchronization/notification.h"
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h" #include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h" #include "tensorflow/stream_executor/gpu/gpu_driver.h"
@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/static_threadlocal.h" #include "tensorflow/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h" #include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h"
@ -305,17 +305,10 @@ static port::Status InternalInit() {
/* static */ port::Status GpuDriver::Init() { /* static */ port::Status GpuDriver::Init() {
// Cached return value from calling InternalInit(), as hipInit need only be // Cached return value from calling InternalInit(), as hipInit need only be
// called once, but GpuDriver::Init may be called many times. // called once, but GpuDriver::Init may be called many times.
static port::Status init_retval; static port::Status* init_retval = [&] {
static bool set = false; return new port::Status(InternalInit());
static mutex* init_mu = new mutex; }();
return *init_retval;
mutex_lock lock(*init_mu);
if (!set) {
init_retval = InternalInit();
set = true;
}
return init_retval;
} }
/* static */ port::Status GpuDriver::GetDevice(int device_ordinal, /* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
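The GpuDriver::Init() change above replaces a hand-rolled mutex-plus-bool guard with a function-local static initialized by an immediately invoked lambda; since C++11 the language already guarantees that initializer runs exactly once, even with concurrent callers. A minimal stand-alone sketch of the idiom, using illustrative stand-ins for port::Status and InternalInit():

#include <string>

// Illustrative stand-ins; not the real stream_executor types.
struct Status {
  std::string message;
};

Status InternalInit() { return Status{"ok"}; }

const Status& Init() {
  // Initialized exactly once, thread-safely; the pointer is leaked on
  // purpose so the cached value is never destroyed at shutdown.
  static Status* init_retval = [] { return new Status(InternalInit()); }();
  return *init_retval;
}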


@ -114,7 +114,7 @@ GpuExecutor::~GpuExecutor() {
} }
bool GpuExecutor::UnloadModule(ModuleHandle module_handle) { bool GpuExecutor::UnloadModule(ModuleHandle module_handle) {
const char* gpu_binary = reinterpret_cast<const char*>(module_handle.id()); const char* gpu_binary = reinterpret_cast<const char*>(module_handle.id());
mutex_lock lock{in_memory_modules_mu_}; absl::MutexLock lock{&in_memory_modules_mu_};
return UnloadGpuBinary(gpu_binary); return UnloadGpuBinary(gpu_binary);
} }
@ -236,7 +236,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
kernelname = &spec.cuda_cubin_in_memory().kernelname(); kernelname = &spec.cuda_cubin_in_memory().kernelname();
const char* hsaco = spec.cuda_cubin_in_memory().bytes(); const char* hsaco = spec.cuda_cubin_in_memory().bytes();
mutex_lock lock{in_memory_modules_mu_}; absl::MutexLock lock{&in_memory_modules_mu_};
module = in_memory_modules_[hsaco]; module = in_memory_modules_[hsaco];
if (module == nullptr) { if (module == nullptr) {
@ -294,7 +294,7 @@ bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
// whether we've done an occupancy check on this kernel before isn't free // whether we've done an occupancy check on this kernel before isn't free
// (because we have to synchronize), so we only do this at -v 2+. // (because we have to synchronize), so we only do this at -v 2+.
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
mutex_lock lock(launched_kernels_mu_); absl::MutexLock lock(&launched_kernels_mu_);
if (!launched_kernels_.count(hipfunc)) { if (!launched_kernels_.count(hipfunc)) {
VlogOccupancyInfo(kernel, thread_dims, block_dims); VlogOccupancyInfo(kernel, thread_dims, block_dims);
// TODO(rspringer): Remove elements from launched_kernels_...if we ever // TODO(rspringer): Remove elements from launched_kernels_...if we ever
@ -366,7 +366,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
hipModule_t hip_module = nullptr; hipModule_t hip_module = nullptr;
// TODO(ROCm): Need generic term instead of cubin/cuda/ptx // TODO(ROCm): Need generic term instead of cubin/cuda/ptx
if (spec.has_cuda_cubin_in_memory()) { if (spec.has_cuda_cubin_in_memory()) {
mutex_lock lock{in_memory_modules_mu_}; absl::MutexLock lock{&in_memory_modules_mu_};
if (!LoadModuleFromHsaco( if (!LoadModuleFromHsaco(
reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()), reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()),
&hip_module)) { &hip_module)) {
@ -765,8 +765,8 @@ bool GpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
bool GpuExecutor::GetSymbol(const string& symbol_name, bool GpuExecutor::GetSymbol(const string& symbol_name,
ModuleHandle module_handle, void** mem, ModuleHandle module_handle, void** mem,
size_t* bytes) { size_t* bytes) {
{ // give limited scope to mutex_lock { // give limited scope to lock
mutex_lock lock{disk_modules_mu_}; absl::MutexLock lock{&disk_modules_mu_};
for (auto& it : disk_modules_) { for (auto& it : disk_modules_) {
if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(), if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(),
reinterpret_cast<hipDeviceptr_t*>(mem), reinterpret_cast<hipDeviceptr_t*>(mem),
@ -776,8 +776,8 @@ bool GpuExecutor::GetSymbol(const string& symbol_name,
} }
} }
{ // give limited scope to mutex_lock { // give limited scope to lock
mutex_lock lock{in_memory_modules_mu_}; absl::MutexLock lock{&in_memory_modules_mu_};
for (auto& it : in_memory_modules_) { for (auto& it : in_memory_modules_) {
if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(), if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(),
reinterpret_cast<hipDeviceptr_t*>(mem), reinterpret_cast<hipDeviceptr_t*>(mem),
@ -787,8 +787,8 @@ bool GpuExecutor::GetSymbol(const string& symbol_name,
} }
} }
{ // give limited scope to mutex_lock { // give limited scope to lock
mutex_lock lock{in_memory_modules_mu_}; absl::MutexLock lock{&in_memory_modules_mu_};
if (static_cast<bool>(module_handle)) { if (static_cast<bool>(module_handle)) {
auto it = gpu_binary_to_module_.find(module_handle.id()); auto it = gpu_binary_to_module_.find(module_handle.id());
CHECK(it != gpu_binary_to_module_.end()); CHECK(it != gpu_binary_to_module_.end());
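The GetSymbol hunks above keep each absl::MutexLock inside its own brace block so one lock is released before the next lookup begins. A minimal sketch of that scoping pattern; the mutexes, maps, and lookup logic are illustrative, not the real module tables:

#include <map>
#include <string>

#include "absl/synchronization/mutex.h"

absl::Mutex disk_mu;    // guards disk_modules
absl::Mutex memory_mu;  // guards in_memory_modules
std::map<std::string, int> disk_modules;
std::map<std::string, int> in_memory_modules;

bool FindSymbol(const std::string& name, int* out) {
  {  // give limited scope to lock: released before the next block
    absl::MutexLock lock(&disk_mu);
    auto it = disk_modules.find(name);
    if (it != disk_modules.end()) {
      *out = it->second;
      return true;
    }
  }
  {  // separate scope, separate mutex
    absl::MutexLock lock(&memory_mu);
    auto it = in_memory_modules.find(name);
    if (it != in_memory_modules.end()) {
      *out = it->second;
      return true;
    }
  }
  return false;
}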


@ -38,13 +38,8 @@ ROCmPlatform::~ROCmPlatform() {}
void ROCmPlatform::InspectNumaNodes() { void ROCmPlatform::InspectNumaNodes() {
// To get NUMA node information, we need to create all executors, so we can // To get NUMA node information, we need to create all executors, so we can
// examine their device descriptions to see their bus assignments. // examine their device descriptions to see their bus assignments.
static bool initialized = false; std::once_flag once;
static mutex numa_mutex(LINKER_INITIALIZED); std::call_once(once, [&] {
mutex_lock lock(numa_mutex);
if (initialized) {
return;
}
StreamExecutorConfig config; StreamExecutorConfig config;
for (int i = 0; i < VisibleDeviceCount(); i++) { for (int i = 0; i < VisibleDeviceCount(); i++) {
config.ordinal = i; config.ordinal = i;
@ -57,11 +52,11 @@ void ROCmPlatform::InspectNumaNodes() {
} else { } else {
min_numa_node_ = min_numa_node_ =
std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
limit_numa_node_ = std::max(limit_numa_node_, limit_numa_node_ = std::max(
exec->GetDeviceDescription().numa_node() + 1); limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
} }
} }
initialized = true; });
} }
int ROCmPlatform::BusCount() { int ROCmPlatform::BusCount() {
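InspectNumaNodes above drops the static bool plus mutex in favor of std::call_once. A minimal sketch of that idiom follows; the flag is declared static here so the body genuinely runs at most once across all calls, and the probing work is a placeholder:

#include <mutex>  // std::once_flag, std::call_once

int visible_device_count = 0;  // illustrative stand-in for real probing state

void InspectOnce() {
  // A static flag outlives every call; a non-static flag would be
  // re-created (and the body re-run) on each invocation.
  static std::once_flag once;
  std::call_once(once, [] {
    // Expensive, idempotent setup, e.g. creating executors to read NUMA info.
    visible_device_count = 1;
  });
}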


@ -19,11 +19,11 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/synchronization/mutex.h"
#include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_internal.h"
@ -90,7 +90,7 @@ class ROCmPlatform : public Platform {
string name_; string name_;
// mutex that guards internal state. // mutex that guards internal state.
mutable mutex mu_; mutable absl::Mutex mu_;
// Cache of created executors. // Cache of created executors.
ExecutorCache executor_cache_; ExecutorCache executor_cache_;


@ -126,7 +126,7 @@ GpuRng::~GpuRng() {
} }
bool GpuRng::Init() { bool GpuRng::Init() {
mutex_lock lock{mu_}; absl::MutexLock lock{&mu_};
CHECK(rng_ == nullptr); CHECK(rng_ == nullptr);
hiprandStatus_t ret = hiprandStatus_t ret =
@ -161,7 +161,7 @@ constexpr bool ComplexIsConsecutiveFloats() {
template <typename T> template <typename T>
bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) { bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) {
mutex_lock lock{mu_}; absl::MutexLock lock{&mu_};
static_assert(ComplexIsConsecutiveFloats(), static_assert(ComplexIsConsecutiveFloats(),
"std::complex values are not stored as consecutive values"); "std::complex values are not stored as consecutive values");
@ -220,7 +220,7 @@ bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean,
ElemT stddev, ElemT stddev,
DeviceMemory<ElemT>* v, DeviceMemory<ElemT>* v,
FuncT func) { FuncT func) {
mutex_lock lock{mu_}; absl::MutexLock lock{&mu_};
if (!SetStream(stream)) { if (!SetStream(stream)) {
return false; return false;
@ -252,7 +252,7 @@ bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
} }
bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) { bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
mutex_lock lock{mu_}; absl::MutexLock lock{&mu_};
CHECK(rng_ != nullptr); CHECK(rng_ != nullptr);
if (!CheckSeed(seed, seed_bytes)) { if (!CheckSeed(seed, seed_bytes)) {


@ -290,7 +290,7 @@ port::Status Stream::RefreshStatus() {
Stream &Stream::Init() { Stream &Stream::Init() {
VLOG_CALL(); VLOG_CALL();
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
CHECK_EQ(false, allocated_) CHECK_EQ(false, allocated_)
<< "stream appears to already have been initialized"; << "stream appears to already have been initialized";
CHECK(!ok_) << "stream should be in !ok() state pre-initialization"; CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
@ -1736,7 +1736,7 @@ Stream &Stream::ThenMemcpyH2DQuantized(
} }
Stream *Stream::GetOrCreateSubStream() { Stream *Stream::GetOrCreateSubStream() {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
// Look for the first reusable sub_stream that is ok, dropping !ok sub_streams // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
// we encounter along the way. // we encounter along the way.
@ -1783,7 +1783,7 @@ Stream *Stream::GetOrCreateSubStream() {
} }
void Stream::ReturnSubStream(Stream *sub_stream) { void Stream::ReturnSubStream(Stream *sub_stream) {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
// Look for the sub-stream. // Look for the sub-stream.
for (int64 index = 0; index < sub_streams_.size(); ++index) { for (int64 index = 0; index < sub_streams_.size(); ++index) {
@ -5285,7 +5285,7 @@ void Stream::CheckStatus(port::Status status) {
return; return;
} }
LOG(ERROR) << status; LOG(ERROR) << status;
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
ok_ = false; ok_ = false;
} }


@ -25,6 +25,7 @@ limitations under the License.
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory.h"
@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/stream_executor/kernel.h" #include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/lib/array_slice.h" #include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/temporary_memory_manager.h" #include "tensorflow/stream_executor/temporary_memory_manager.h"
@ -1964,7 +1964,7 @@ class Stream {
friend class ocl::CLBlas; // for parent_. friend class ocl::CLBlas; // for parent_.
bool InErrorState() const LOCKS_EXCLUDED(mu_) { bool InErrorState() const LOCKS_EXCLUDED(mu_) {
tf_shared_lock lock(mu_); absl::ReaderMutexLock lock(&mu_);
return !ok_; return !ok_;
} }
@ -1974,7 +1974,7 @@ class Stream {
if (operation_retcode) { if (operation_retcode) {
return; return;
} }
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
ok_ = false; ok_ = false;
} }
@ -1998,7 +1998,7 @@ class Stream {
// mutex that guards the allocation / error state flags. // mutex that guards the allocation / error state flags.
// Mutable so that it can be obtained via const reader lock. // Mutable so that it can be obtained via const reader lock.
mutable mutex mu_; mutable absl::Mutex mu_;
// Whether Init() was successfully called to allocate this stream on the // Whether Init() was successfully called to allocate this stream on the
// underlying platform. It simply flips from 0 to 1 with a sanity check. // underlying platform. It simply flips from 0 to 1 with a sanity check.
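The Stream members above show the reader/writer split the migration preserves: const accessors take an absl::ReaderMutexLock on a mutable absl::Mutex, while mutators take an exclusive absl::MutexLock. Roughly, with placeholder class and flag names:

#include "absl/synchronization/mutex.h"

class ErrorFlag {
 public:
  // Const reader: shared lock, several readers may hold it concurrently.
  bool ok() const {
    absl::ReaderMutexLock lock(&mu_);
    return ok_;
  }

  // Writer: exclusive lock.
  void MarkError() {
    absl::MutexLock lock(&mu_);
    ok_ = false;
  }

 private:
  mutable absl::Mutex mu_;  // mutable so const readers can still lock it
  bool ok_ = true;
};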


@ -23,6 +23,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "absl/base/const_init.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h" #include "absl/synchronization/notification.h"
@ -92,7 +93,7 @@ class ScopedTracer {
void Trace(CallbackT callback, TraceArgsT... args) { void Trace(CallbackT callback, TraceArgsT... args) {
{ {
// Instance tracers held in a block to limit the lock lifetime. // Instance tracers held in a block to limit the lock lifetime.
tf_shared_lock lock{stream_exec_->mu_}; absl::ReaderMutexLock lock{&stream_exec_->mu_};
for (TraceListener *listener : stream_exec_->listeners_) { for (TraceListener *listener : stream_exec_->listeners_) {
(listener->*callback)(correlation_id_, (listener->*callback)(correlation_id_,
std::forward<TraceArgsT>(args)...); std::forward<TraceArgsT>(args)...);
@ -121,7 +122,7 @@ MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
auto tracer = MakeScopedTracer(this, &LOC ## Begin, \ auto tracer = MakeScopedTracer(this, &LOC ## Begin, \
&LOC ## Complete, ## __VA_ARGS__); &LOC ## Complete, ## __VA_ARGS__);
/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED}; /* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};
// Get per-device memory limit in bytes. Returns 0 if // Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set. // TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
@ -217,7 +218,7 @@ void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
} }
void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) { void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
tf_shared_lock lock(mu_); absl::ReaderMutexLock lock(&mu_);
*records_out = mem_allocs_; *records_out = mem_allocs_;
} }
@ -247,7 +248,7 @@ port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
} }
const DeviceDescription &StreamExecutor::GetDeviceDescription() const { const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (device_description_ != nullptr) { if (device_description_ != nullptr) {
return *device_description_; return *device_description_;
} }
@ -393,7 +394,7 @@ StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
} }
dnn::DnnSupport *StreamExecutor::AsDnn() { dnn::DnnSupport *StreamExecutor::AsDnn() {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (dnn_ != nullptr) { if (dnn_ != nullptr) {
return dnn_.get(); return dnn_.get();
} }
@ -403,7 +404,7 @@ dnn::DnnSupport *StreamExecutor::AsDnn() {
} }
blas::BlasSupport *StreamExecutor::AsBlas() { blas::BlasSupport *StreamExecutor::AsBlas() {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (blas_ != nullptr) { if (blas_ != nullptr) {
return blas_.get(); return blas_.get();
} }
@ -413,7 +414,7 @@ blas::BlasSupport *StreamExecutor::AsBlas() {
} }
fft::FftSupport *StreamExecutor::AsFft() { fft::FftSupport *StreamExecutor::AsFft() {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (fft_ != nullptr) { if (fft_ != nullptr) {
return fft_.get(); return fft_.get();
} }
@ -423,7 +424,7 @@ fft::FftSupport *StreamExecutor::AsFft() {
} }
rng::RngSupport *StreamExecutor::AsRng() { rng::RngSupport *StreamExecutor::AsRng() {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (rng_ != nullptr) { if (rng_ != nullptr) {
return rng_.get(); return rng_.get();
} }
@ -781,7 +782,7 @@ void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) { void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) { if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
mem_allocs_[opaque] = AllocRecord{ mem_allocs_[opaque] = AllocRecord{
bytes, ""}; bytes, ""};
mem_alloc_bytes_ += bytes; mem_alloc_bytes_ += bytes;
@ -790,7 +791,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
void StreamExecutor::EraseAllocRecord(void *opaque) { void StreamExecutor::EraseAllocRecord(void *opaque) {
if (FLAGS_check_device_leaks && opaque != nullptr) { if (FLAGS_check_device_leaks && opaque != nullptr) {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (mem_allocs_.find(opaque) == mem_allocs_.end()) { if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
LOG(ERROR) << "Deallocating unknown pointer: " << opaque; LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
} else { } else {
@ -804,7 +805,7 @@ void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }
void StreamExecutor::RegisterTraceListener(TraceListener *listener) { void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
{ {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (listeners_.find(listener) != listeners_.end()) { if (listeners_.find(listener) != listeners_.end()) {
LOG(INFO) << "Attempt to register already-registered listener, " LOG(INFO) << "Attempt to register already-registered listener, "
<< listener; << listener;
@ -818,7 +819,7 @@ void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) { bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
{ {
mutex_lock lock(mu_); absl::MutexLock lock(&mu_);
if (listeners_.find(listener) == listeners_.end()) { if (listeners_.find(listener) == listeners_.end()) {
LOG(INFO) << "Attempt to unregister unknown listener, " << listener; LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
return false; return false;
@ -839,7 +840,7 @@ void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
if (tracing_enabled_) { if (tracing_enabled_) {
{ {
// instance tracers held in a block to limit the lock lifetime. // instance tracers held in a block to limit the lock lifetime.
tf_shared_lock lock(mu_); absl::ReaderMutexLock lock(&mu_);
for (TraceListener *listener : listeners_) { for (TraceListener *listener : listeners_) {
(listener->*trace_call)(std::forward<ArgsT>(args)...); (listener->*trace_call)(std::forward<ArgsT>(args)...);
} }
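The static_mu_ definition above uses absl::kConstInit, which constant-initializes a mutex with static storage duration so it can be locked before main() runs and is not subject to the static-initialization-order problems the old LINKER_INITIALIZED tag worked around. A minimal sketch; the registry names are illustrative:

#include "absl/base/const_init.h"
#include "absl/synchronization/mutex.h"

// Constant-initialized, so it is safe to lock even from other static
// initializers that run before main().
static absl::Mutex g_registry_mu(absl::kConstInit);

int g_registry_size = 0;  // guarded by g_registry_mu

void RegisterSomething() {
  absl::MutexLock lock(&g_registry_mu);
  ++g_registry_size;
}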


@ -23,13 +23,13 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/rng.h" #include "tensorflow/stream_executor/rng.h"
@ -619,13 +619,13 @@ class StreamExecutor {
void SubmitTrace(TraceCallT trace_call, ArgsT&&... args); void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);
// Reader/writer lock for class-static StreamExecutor members. // Reader/writer lock for class-static StreamExecutor members.
static mutex static_mu_; static absl::Mutex static_mu_;
// Reader/writer lock for mutable data structures on this StreamExecutor. // Reader/writer lock for mutable data structures on this StreamExecutor.
// //
// Mutable so that caching functions (like DeviceDescription, AsBlas, etc.) // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
// can acquire the lock on their first (mutating) call as well. // can acquire the lock on their first (mutating) call as well.
mutable mutex mu_; mutable absl::Mutex mu_;
// Reference to the platform that created this executor. // Reference to the platform that created this executor.
const Platform *platform_; const Platform *platform_;


@ -25,7 +25,7 @@ namespace stream_executor {
namespace internal { namespace internal {
void TemporaryMemoryManager::ForceDeallocateAll() { void TemporaryMemoryManager::ForceDeallocateAll() {
mutex_lock lock(mutex_); absl::MutexLock lock(&mutex_);
VLOG(1) << "force-deallocating " << records_.size() << " remaining records"; VLOG(1) << "force-deallocating " << records_.size() << " remaining records";
for (auto it = records_.begin(); it != records_.end(); ++it) { for (auto it = records_.begin(); it != records_.end(); ++it) {
DeviceMemoryBase device_memory = it->first; DeviceMemoryBase device_memory = it->first;
@ -35,7 +35,7 @@ void TemporaryMemoryManager::ForceDeallocateAll() {
void TemporaryMemoryManager::MarkFinalized( void TemporaryMemoryManager::MarkFinalized(
const DeviceMemoryBase& device_memory, uint64 generation, bool must_exist) { const DeviceMemoryBase& device_memory, uint64 generation, bool must_exist) {
mutex_lock lock(mutex_); absl::MutexLock lock(&mutex_);
auto it = records_.find(device_memory); auto it = records_.find(device_memory);
if (it == records_.end()) { if (it == records_.end()) {
if (must_exist) { if (must_exist) {
@ -48,7 +48,7 @@ void TemporaryMemoryManager::MarkFinalized(
} }
void TemporaryMemoryManager::DeallocateFinalizedTemporaries() { void TemporaryMemoryManager::DeallocateFinalizedTemporaries() {
mutex_lock lock(mutex_); absl::MutexLock lock(&mutex_);
int deallocated_count = 0; int deallocated_count = 0;
for (auto it = records_.begin(); it != records_.end();) { for (auto it = records_.begin(); it != records_.end();) {
if (it->second.finalized) { if (it->second.finalized) {
@ -65,7 +65,7 @@ void TemporaryMemoryManager::DeallocateFinalizedTemporaries() {
bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory, bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory,
uint64 allocation_generation) const { uint64 allocation_generation) const {
mutex_lock lock(mutex_); absl::MutexLock lock(&mutex_);
auto it = records_.find(device_memory); auto it = records_.find(device_memory);
if (it == records_.end()) { if (it == records_.end()) {
return true; // If there's no record present it's vacuously finalized. return true; // If there's no record present it's vacuously finalized.
@ -81,7 +81,7 @@ bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory,
bool TemporaryMemoryManager::HasAllocated(const DeviceMemoryBase& device_memory, bool TemporaryMemoryManager::HasAllocated(const DeviceMemoryBase& device_memory,
uint64 generation) const { uint64 generation) const {
mutex_lock lock(mutex_); absl::MutexLock lock(&mutex_);
auto it = records_.find(device_memory); auto it = records_.find(device_memory);
if (it == records_.end()) { if (it == records_.end()) {
return false; return false;
@ -106,7 +106,7 @@ TemporaryMemoryManager::AllocateArrayBase(uint64 element_count,
// Add the record before instantiating the device memory instance so we can // Add the record before instantiating the device memory instance so we can
// check the allocation invariant at TemporaryDeviceMemory construction time. // check the allocation invariant at TemporaryDeviceMemory construction time.
{ {
mutex_lock lock(mutex_); absl::MutexLock lock(&mutex_);
generation = ++generation_; generation = ++generation_;
DCHECK(records_.find(device_memory) == records_.end()); DCHECK(records_.find(device_memory) == records_.end());
records_[device_memory] = {generation, records_[device_memory] = {generation,


@ -24,10 +24,10 @@ limitations under the License.
#include <map> #include <map>
#include <memory> #include <memory>
#include "absl/synchronization/mutex.h"
#include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/temporary_device_memory.h" #include "tensorflow/stream_executor/temporary_device_memory.h"
@ -108,7 +108,7 @@ class TemporaryMemoryManager {
uint64 element_count, uint64 element_size); uint64 element_count, uint64 element_size);
// Mutex to guard temporary record state. // Mutex to guard temporary record state.
mutable mutex mutex_; mutable absl::Mutex mutex_;
// Mapping from device memory to the current (live) temporary memory record. // Mapping from device memory to the current (live) temporary memory record.
// //