[SE] Use absl mutex instead of TF mutex.

Fix layering violation.

PiperOrigin-RevId: 245851033
commit 481748b5c9
parent 9904676451
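The diff below applies one mechanical pattern across stream_executor: TensorFlow's mutex wrappers are replaced with their Abseil equivalents, so mutex becomes absl::Mutex, mutex_lock and tf_shared_lock become absl::MutexLock and absl::ReaderMutexLock (which take a pointer to the mutex rather than a reference), condition_variable becomes absl::CondVar, and LINKER_INITIALIZED becomes absl::kConstInit. As a minimal sketch of the resulting shape, assuming only absl/synchronization/mutex.h (the class below is illustrative and not part of the patch):

    #include "absl/synchronization/mutex.h"

    // Illustrative only: a counter guarded the way stream_executor code is
    // guarded after this change.
    class Counter {
     public:
      void Increment() {
        absl::MutexLock lock(&mu_);  // was: mutex_lock lock(mu_);
        ++value_;
      }

      int Read() const {
        absl::ReaderMutexLock lock(&mu_);  // was: tf_shared_lock lock(mu_);
        return value_;
      }

     private:
      mutable absl::Mutex mu_;  // was: mutable mutex mu_;
      int value_ = 0;           // conceptually GUARDED_BY(mu_)
    };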
@@ -93,6 +93,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -140,6 +141,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -152,6 +154,7 @@ cc_library(
     deps = [
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -199,6 +202,7 @@ cc_library(
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -231,6 +235,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "//third_party/eigen3",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -259,6 +264,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -295,6 +301,7 @@ cc_library(
         ":stream_header",
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -311,6 +318,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -399,6 +407,7 @@ cc_library(
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -440,10 +449,10 @@ cc_library(
         ":plugin",
         ":stream_executor_headers",
         "//tensorflow/stream_executor/lib",
-        "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -536,6 +545,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -612,6 +622,7 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
@@ -59,9 +59,11 @@ cc_library(
         "//tensorflow/stream_executor/platform",
         ],
     ) + tf_additional_cuda_platform_deps() + [
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
     alwayslink = True,  # Registers itself with the MultiPlatformManager.
 )
@@ -112,6 +114,7 @@ cc_library(
         "//tensorflow:using_cuda_clang_with_dynamic_build": ["cudart_stub"],
         "//conditions:default": ["//tensorflow/core:cuda"],
     }) + [
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
@@ -223,7 +226,10 @@ cc_library(
     ] + if_static(
         ["@local_config_cuda//cuda:cublas"],
         [":cublas_stub"],
-    )) + ["@com_google_absl//absl/strings:str_format"],
+    )) + [
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+    ],
     alwayslink = True,
 )
 
@@ -306,7 +312,7 @@ cc_library(
     ]) + tf_additional_cudnn_plugin_deps() + if_cuda_is_configured(if_static(
         ["@local_config_cuda//cuda:cudnn"],
         [":cudnn_stub"],
-    )),
+    )) + ["@com_google_absl//absl/synchronization"],
     alwayslink = True,
 )
 
@@ -402,7 +402,7 @@ template <typename FuncT, typename... Args>
 bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                   bool pointer_mode_host, bool err_on_failure,
                                   bool use_tensor_op_math, Args... args) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
 
   CHECK(blas_ != nullptr);
   if (!SetStream(stream)) {
@@ -20,9 +20,9 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/blas.h"
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
@@ -142,8 +142,8 @@ class CUDABlas : public blas::BlasSupport {
                       const T &beta, DeviceMemory<T> *y, int incy,
                       blas::ProfileResult *output_profile_result);
 
-  // mutex that guards the cuBLAS handle for this device.
-  mutex mu_;
+  // Guards the cuBLAS handle for this device.
+  absl::Mutex mu_;
 
   // GpuExecutor which instantiated this CUDABlas.
   // Immutable post-initialization.
@@ -133,8 +133,8 @@ class CudnnHandle {
  public:
   // Takes ownership of the executor context and the lock to access cuDNN
   // using handle.
-  CudnnHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock,
-              cudnnHandle_t handle)
+  CudnnHandle(gpu::ScopedActivateExecutorContext context,
+              std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
 
   // Returns cuDNN handle. To be passed directly to cuDNN APIs, don't keep
@@ -143,7 +143,7 @@ class CudnnHandle {
 
  private:
   gpu::ScopedActivateExecutorContext context_;
-  mutex_lock lock_;
+  std::unique_ptr<absl::MutexLock> lock_;
   cudnnHandle_t handle_;  // Not owned.
 };
 
@@ -160,7 +160,7 @@ class CudnnAccess {
   explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
 
   ~CudnnAccess() {
-    mutex_lock lock(mutex_);
+    absl::MutexLock lock(&mutex_);
     cudnnDestroy(handle_);
   }
 
@@ -181,7 +181,8 @@ class CudnnAccess {
   // therefore a bad idea (performance wise) to call any cuDNN APIs that
   // enqueue work in the stream.
   CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
-    mutex_lock lock(mutex_);
+    auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
+    mutex_.AssertHeld();
     gpu::ScopedActivateExecutorContext context(executor);
     CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
     const auto status = cudnnSetStream(handle_, cu_stream);
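The GetHandle() change above is the one place where a lock has to outlive its enclosing scope: the TF mutex_lock object was movable, while absl::MutexLock is not, so the lock is heap-allocated in a std::unique_ptr and handed to the RAII handle, and mutex_.AssertHeld() keeps the static thread-safety analysis informed that the mutex is still held. A reduced sketch of the same idiom, with illustrative names rather than the real cuDNN types:

    #include <memory>
    #include <utility>

    #include "absl/memory/memory.h"
    #include "absl/synchronization/mutex.h"

    // Illustrative only: a handle that keeps a mutex locked for its lifetime.
    class LockedHandle {
     public:
      LockedHandle(std::unique_ptr<absl::MutexLock> lock, int* value)
          : lock_(std::move(lock)), value_(value) {}
      int& value() { return *value_; }

     private:
      std::unique_ptr<absl::MutexLock> lock_;  // released when the handle dies
      int* value_;                             // not owned
    };

    class Registry {
     public:
      // Returns a handle whose lifetime keeps mu_ held, mirroring the shape of
      // CudnnAccess::GetHandle() after this change.
      LockedHandle Acquire() {
        auto lock = absl::make_unique<absl::MutexLock>(&mu_);
        mu_.AssertHeld();  // tells the analysis the lock is (still) held
        return LockedHandle(std::move(lock), &value_);
      }

     private:
      absl::Mutex mu_;
      int value_ = 0;
    };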
@@ -191,7 +192,7 @@ class CudnnAccess {
 
  private:
   // Guards the enqueueing of cuDNN operations via the handle_ below.
-  mutex mutex_;
+  absl::Mutex mutex_;
 
   // cuDNN library handle.
   cudnnHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
@@ -22,7 +22,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #include "tensorflow/stream_executor/dnn.h"
 #include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/temporary_device_memory.h"
@@ -23,11 +23,13 @@ limitations under the License.
 #include <utility>
 
 #include "absl/base/casts.h"
+#include "absl/base/const_init.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/debugging/leak_check.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
 #include "cuda/include/cuda_runtime_api.h"
 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@@ -38,7 +40,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/static_threadlocal.h"
 #include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/platform/logging.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 
 bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
@@ -65,14 +66,14 @@ class CreatedContexts {
  public:
   // Returns whether context is a member of the live set.
   static bool Has(CUcontext context) {
-    tf_shared_lock lock(mu_);
+    absl::ReaderMutexLock lock(&mu_);
     return Live()->find(context) != Live()->end();
   }
 
   // Adds context to the live set, or returns it if it's already present.
   static GpuContext* Add(CUcontext context) {
     CHECK(context != nullptr);
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     auto insert_result = Live()->insert(std::make_pair(context, nullptr));
     auto it = insert_result.first;
     if (insert_result.second) {
@@ -85,7 +86,7 @@ class CreatedContexts {
   // Removes context from the live set.
   static void Remove(CUcontext context) {
     CHECK(context != nullptr);
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     auto it = Live()->find(context);
     CHECK(it != Live()->end()) << context;
     Live()->erase(it);
@@ -100,11 +101,11 @@ class CreatedContexts {
   }
 
   // Lock that guards access-to/mutation-of the live set.
-  static mutex mu_;
+  static absl::Mutex mu_;
   static int64 next_id_;
 };
 
-/* static */ mutex CreatedContexts::mu_{LINKER_INITIALIZED};
+/* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit};
 /* static */ int64 CreatedContexts::next_id_ = 1;  // 0 means "no context"
 
 // Formats CUresult to output prettified values into a log stream.
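LINKER_INITIALIZED came from the TF mutex layer; its Abseil counterpart here is absl::kConstInit, which selects absl::Mutex's constexpr constructor so a static-duration mutex is constant-initialized and usable regardless of dynamic-initialization order. A hedged sketch of the pattern used for CreatedContexts::mu_ above, with an invented surrounding class:

    #include "absl/base/const_init.h"
    #include "absl/synchronization/mutex.h"

    // Illustrative only: a class-level static mutex that is safe to use from
    // any initialization order, analogous to CreatedContexts::mu_.
    class LiveSet {
     public:
      static void Touch() {
        absl::MutexLock lock(&mu_);
        ++count_;
      }

     private:
      static absl::Mutex mu_;
      static int count_;
    };

    // kConstInit makes this definition constant-initialized rather than
    // dynamically initialized, so it is ready before any other initializer runs.
    /* static */ absl::Mutex LiveSet::mu_{absl::kConstInit};
    /* static */ int LiveSet::count_ = 0;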
@@ -322,17 +323,10 @@ static port::Status InternalInit() {
 /* static */ port::Status GpuDriver::Init() {
   // Cached return value from calling InternalInit(), as cuInit need only be
   // called once, but GpuDriver::Init may be called many times.
-  static port::Status init_retval;
-  static bool set = false;
-  static mutex* init_mu = new mutex;
-  mutex_lock lock(*init_mu);
-  if (!set) {
-    init_retval = InternalInit();
-    set = true;
-  }
-
-  return init_retval;
+  static port::Status* init_retval = [] {
+    return new port::Status(InternalInit());
+  }();
+  return *init_retval;
 }
 
 /* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
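GpuDriver::Init() drops the hand-rolled mutex-and-flag pair in favor of a function-local static initialized by an immediately-invoked lambda: since C++11 that initialization is guaranteed to run exactly once and to be thread-safe, and storing a heap-allocated result avoids running a destructor at program exit. A minimal, self-contained sketch of the same idiom, with ExpensiveInit standing in for the real InternalInit():

    #include <string>

    // Stand-in for the real one-time initialization (e.g. cuInit plus checks).
    static std::string ExpensiveInit() { return "ok"; }

    std::string GetInitResult() {
      // Runs ExpensiveInit() exactly once, even with concurrent callers;
      // magic-static initialization is thread-safe in C++11 and later.
      static std::string* result = [] { return new std::string(ExpensiveInit()); }();
      return *result;
    }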
@@ -272,7 +272,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
   VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();
 
   if (spec.has_cuda_cubin_in_memory()) {
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     kernelname = &spec.cuda_cubin_in_memory().kernelname();
     const char *cubin = spec.cuda_cubin_in_memory().bytes();
     if (!LoadModuleFromCuBin(cubin, &module)) {
@@ -295,7 +295,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
       return false;
     }
 
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (!LoadModuleFromPtx(ptx, &module)) {
       return false;
     }
@@ -343,7 +343,7 @@ bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) {
 void GpuExecutor::UnloadKernel(const KernelBase* kernel) {
   VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name();
 
-  mutex_lock lock{in_memory_modules_mu_};
+  absl::MutexLock lock{&in_memory_modules_mu_};
   auto gpu_binary_it = kernel_to_gpu_binary_.find(kernel);
   if (kernel_to_gpu_binary_.end() == gpu_binary_it) {
     VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
@@ -362,7 +362,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
   // ModuleHandle::id().
   CUmodule cu_module;
   if (spec.has_cuda_cubin_in_memory()) {
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (!LoadModuleFromCuBin(
             reinterpret_cast<const char *>(spec.cuda_cubin_in_memory().data()),
             &cu_module)) {
@@ -380,7 +380,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
       return false;
     }
 
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (!LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)) {
       return false;
     }
@@ -394,7 +394,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
 
 bool GpuExecutor::UnloadModule(ModuleHandle module_handle) {
   const char *gpu_binary = reinterpret_cast<const char *>(module_handle.id());
-  mutex_lock lock{in_memory_modules_mu_};
+  absl::MutexLock lock{&in_memory_modules_mu_};
   return UnloadGpuBinary(gpu_binary);
 }
 
@@ -428,7 +428,7 @@ bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
   // whether we've done an occupancy check on this kernel before isn't free
   // (because we have to synchronize), so we only do this at -v 2+.
   if (VLOG_IS_ON(2)) {
-    mutex_lock lock(launched_kernels_mu_);
+    absl::MutexLock lock(&launched_kernels_mu_);
     if (!launched_kernels_.count(cufunc)) {
       VlogOccupancyInfo(kernel, thread_dims, block_dims);
       // TODO(rspringer): Remove elements from launched_kernels_...if we ever
@@ -892,7 +892,7 @@ bool GpuExecutor::GetSymbol(const string& symbol_name,
   };
 
   {  // give limited scope to mutex_lock
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (static_cast<bool>(module_handle)) {
       auto it = gpu_binary_to_module_.find(module_handle.id());
       CHECK(it != gpu_binary_to_module_.end());
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/cuda/cuda_platform.h"
 
+#include "absl/base/const_init.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -75,13 +76,8 @@ CudaPlatform::~CudaPlatform() {}
 void CudaPlatform::InspectNumaNodes() {
   // To get NUMA node information, we need to create all executors, so we can
   // examine their device descriptions to see their bus assignments.
-  static bool initialized = false;
-  static mutex numa_mutex(LINKER_INITIALIZED);
-  mutex_lock lock(numa_mutex);
-  if (initialized) {
-    return;
-  }
-
+  static std::once_flag once;
+  std::call_once(once, [&] {
     StreamExecutorConfig config;
     for (int i = 0; i < VisibleDeviceCount(); i++) {
       config.ordinal = i;
@@ -94,11 +90,11 @@ void CudaPlatform::InspectNumaNodes() {
       } else {
         min_numa_node_ =
             std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
-        limit_numa_node_ = std::max(limit_numa_node_,
-                                    exec->GetDeviceDescription().numa_node() + 1);
+        limit_numa_node_ = std::max(
+            limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
       }
     }
-  initialized = true;
+  });
 }
 
 int CudaPlatform::BusCount() {
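CudaPlatform::InspectNumaNodes() replaces its initialized flag plus linker-initialized mutex with std::once_flag and std::call_once, which give the same run-once guarantee without an explicit lock. A reduced sketch of the idiom, with InspectOnce and ScanBuses as hypothetical stand-ins for the real NUMA scan:

    #include <mutex>  // std::once_flag, std::call_once

    class Platform {
     public:
      void InspectOnce() {
        static std::once_flag once;
        // The lambda body runs at most once across all threads; later callers
        // block until the first call has finished.
        std::call_once(once, [&] { bus_count_ = ScanBuses(); });
      }

     private:
      int ScanBuses() { return 1; }  // placeholder for the real device scan
      int bus_count_ = 0;
    };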
@@ -24,7 +24,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/multi_platform_manager.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
@@ -71,7 +71,7 @@ GpuRng::~GpuRng() {
 }
 
 bool GpuRng::Init() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   CHECK(rng_ == nullptr);
 
   cuda::ScopedActivateExecutorContext sac(parent_);
@@ -106,7 +106,7 @@ constexpr bool ComplexIsConsecutiveFloats() {
 
 template <typename T>
 bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   static_assert(ComplexIsConsecutiveFloats(),
                 "std::complex values are not stored as consecutive values");
 
@@ -164,7 +164,7 @@ bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean,
                                             ElemT stddev,
                                             DeviceMemory<ElemT>* v,
                                             FuncT func) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
 
   if (!SetStream(stream)) {
     return false;
@@ -197,7 +197,7 @@ bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
 }
 
 bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   CHECK(rng_ != nullptr);
 
   if (!CheckSeed(seed, seed_bytes)) {
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/executor_cache.h"
 
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 
 namespace stream_executor {
 
@@ -32,7 +33,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
 
   Entry* entry = nullptr;
   {
-    mutex_lock lock{mutex_};
+    absl::MutexLock lock{&mutex_};
     entry = &cache_[config.ordinal];
     // Release the map lock; the address of 'entry' is stable because
     // std::map guarantees reference stability.
@@ -41,7 +42,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
   // Acquire the per-Entry mutex without holding the map mutex. Initializing
   // an Executor may be expensive, so we want to allow concurrent
   // initialization of different entries.
-  mutex_lock lock{entry->configurations_mutex};
+  absl::MutexLock lock{&entry->configurations_mutex};
   for (const auto& iter : entry->configurations) {
     if (iter.first.plugin_config == config.plugin_config &&
         iter.first.device_options == config.device_options) {
@@ -66,7 +67,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::Get(
     const StreamExecutorConfig& config) {
   Entry* entry = nullptr;
   {
-    tf_shared_lock lock{mutex_};
+    absl::ReaderMutexLock lock{&mutex_};
     auto it = cache_.find(config.ordinal);
     if (it != cache_.end()) {
       entry = &it->second;
@@ -77,7 +78,7 @@ port::StatusOr<StreamExecutor*> ExecutorCache::Get(
           config.ordinal));
     }
   }
-  tf_shared_lock lock{entry->configurations_mutex};
+  absl::ReaderMutexLock lock{&entry->configurations_mutex};
   if (entry->configurations.empty()) {
     return port::Status(
         port::error::NOT_FOUND,
@@ -96,12 +97,12 @@ port::StatusOr<StreamExecutor*> ExecutorCache::Get(
 }
 
 void ExecutorCache::DestroyAllExecutors() {
-  mutex_lock lock{mutex_};
+  absl::MutexLock lock{&mutex_};
   cache_.clear();
 }
 
 ExecutorCache::Entry::~Entry() {
-  mutex_lock lock{configurations_mutex};
+  absl::MutexLock lock{&configurations_mutex};
   configurations.clear();
 }
 
@@ -19,9 +19,9 @@ limitations under the License.
 #include <functional>
 #include <map>
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 
 namespace stream_executor {
@@ -54,11 +54,11 @@ class ExecutorCache {
   struct Entry {
     ~Entry();
 
-    // Mutex that locks the contents of each entry. The 'mutex_' of the
+    // Mutex that guards the contents of each entry. The 'mutex_' of the
     // ExecutorCache class protects both the 'cache_' and the existence of each
     // Entry, but not the Entry's contents. 'configurations_mutex' protects the
    // contents of the entry after 'mutex_' has been dropped.
-    mutex configurations_mutex;
+    absl::Mutex configurations_mutex;
 
     // Vector of cached {config, executor} pairs.
     std::vector<
@@ -69,7 +69,7 @@ class ExecutorCache {
   // Maps ordinal number to a list of cached executors for that ordinal.
   // We key off of ordinal (instead of just looking up all fields in the
   // StreamExecutorConfig) for a slight improvement in lookup time.
-  mutex mutex_;
+  absl::Mutex mutex_;
   std::map<int, Entry> cache_ GUARDED_BY(mutex_);
 
   SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
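The ExecutorCache hunks keep its two-level locking scheme while swapping the mutex type: mutex_ guards the map and the existence of each Entry (std::map is node-based, so the Entry's address stays valid after the map lock is dropped), while each Entry's configurations_mutex guards that entry's contents, so expensive executor creation for different ordinals can proceed concurrently. A condensed, illustrative sketch of the same structure with absl types and simplified names:

    #include <map>
    #include <string>

    #include "absl/synchronization/mutex.h"

    // Illustrative only: per-key locking layered under a map-wide lock, in the
    // shape ExecutorCache uses after this change.
    class Cache {
     public:
      std::string GetOrCreate(int ordinal) {
        Entry* entry = nullptr;
        {
          absl::MutexLock map_lock(&mu_);
          entry = &cache_[ordinal];  // node-based map: address stays stable
        }
        // Expensive work happens under the per-entry lock only, so different
        // ordinals can be initialized concurrently.
        absl::MutexLock entry_lock(&entry->mu);
        if (entry->value.empty()) {
          entry->value = "created-" + std::to_string(ordinal);
        }
        return entry->value;
      }

     private:
      struct Entry {
        absl::Mutex mu;
        std::string value;
      };
      absl::Mutex mu_;
      std::map<int, Entry> cache_;
    };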
@@ -98,6 +98,7 @@ cc_library(
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -126,6 +127,7 @@ cc_library(
         "//tensorflow/stream_executor:plugin_registry",
         "//tensorflow/stream_executor:rng",
         "//tensorflow/stream_executor/platform",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -26,12 +26,12 @@ limitations under the License.
 #include <unordered_map>
 
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/event.h"
 #include "tensorflow/stream_executor/gpu/gpu_kernel.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
@@ -287,7 +287,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
       EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
 
   // Guards the on-disk-module mapping.
-  mutex disk_modules_mu_;
+  absl::Mutex disk_modules_mu_;
 
   // Mapping from filename to GPUModuleHandle, if it was already retrieved.
   // Multiple GPUFunctionHandle are usually obtained from a single
@@ -296,7 +296,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
   std::map<string, GpuModuleHandle> disk_modules_ GUARDED_BY(disk_modules_mu_);
 
   // Guards the in-memory-module mapping.
-  mutex in_memory_modules_mu_;
+  absl::Mutex in_memory_modules_mu_;
 
   std::map<const char*, GpuModuleHandle> in_memory_modules_
       GUARDED_BY(in_memory_modules_mu_);
@@ -309,7 +309,7 @@ class GpuExecutor : public internal::StreamExecutorInterface {
       gpu_binary_to_module_ GUARDED_BY(in_memory_modules_mu_);
 
   // Guards the launched kernel set.
-  mutex launched_kernels_mu_;
+  absl::Mutex launched_kernels_mu_;
 
   // Keeps track of the set of launched kernels. Currently used to suppress the
   // occupancy check on subsequent launches.
@@ -16,14 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_
 #define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_RNG_H_
 
-#include "tensorflow/stream_executor/platform/mutex.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/stream_executor/gpu/gpu_types.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/rng.h"
 
-#include "tensorflow/stream_executor/gpu/gpu_types.h"
-
 namespace stream_executor {
 
 class Stream;
@@ -83,8 +82,8 @@ class GpuRng : public rng::RngSupport {
   // with random number generation.
   bool SetStream(Stream* stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  // mutex that guards the gpu rng library handle for this device.
-  mutex mu_;
+  // Guards the gpu rng library handle for this device.
+  absl::Mutex mu_;
 
   // GpuExecutor which instantiated this GpuRng.
   // Immutable post-initialization.
@@ -52,6 +52,7 @@ cc_library(
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
     ],
     alwayslink = True,  # Registers itself with the MultiPlatformManager.
 )
@@ -67,6 +68,7 @@ cc_library(
     deps = [
         "//tensorflow/stream_executor:kernel",
         "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -27,7 +27,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/multi_platform_manager.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
@@ -38,15 +38,15 @@ bool HostStream::EnqueueTask(std::function<void()> task) {
       // should guarantee that all tasks are destroyed.
       task = std::function<void()>();
       {
-        mutex_lock lock(stream->mu_);
+        absl::MutexLock lock(&stream->mu_);
         --stream->pending_tasks_;
       }
-      stream->completion_condition_.notify_all();
+      stream->completion_condition_.SignalAll();
     }
   };
 
   {
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     ++pending_tasks_;
   }
   host_executor_->Schedule(NotifiedTask{this, std::move(task)});
@@ -54,9 +54,9 @@ bool HostStream::EnqueueTask(std::function<void()> task) {
 }
 
 void HostStream::BlockUntilDone() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   while (pending_tasks_ != 0) {
-    completion_condition_.wait(lock);
+    completion_condition_.Wait(&mu_);
   }
 }
 
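HostStream's completion signalling moves from the TF condition_variable to absl::CondVar: Wait() takes the absl::Mutex itself rather than a lock object, and notify_all() becomes SignalAll(). A small self-contained sketch of the same pending-task counter, reduced from the HostStream logic above (names simplified):

    #include <thread>

    #include "absl/synchronization/mutex.h"

    // Illustrative only: count outstanding tasks and block until they finish,
    // mirroring HostStream::EnqueueTask/BlockUntilDone after this change.
    class PendingTasks {
     public:
      void Add() {
        absl::MutexLock lock(&mu_);
        ++pending_;
      }

      void Done() {
        {
          absl::MutexLock lock(&mu_);
          --pending_;
        }
        cv_.SignalAll();  // was: completion_condition_.notify_all();
      }

      void BlockUntilDone() {
        absl::MutexLock lock(&mu_);
        while (pending_ != 0) {
          cv_.Wait(&mu_);  // atomically releases mu_ and reacquires on wakeup
        }
      }

     private:
      absl::Mutex mu_;
      int pending_ = 0;
      absl::CondVar cv_;
    };

    int main() {
      PendingTasks tasks;
      tasks.Add();
      std::thread worker([&] { tasks.Done(); });
      tasks.BlockUntilDone();
      worker.join();
      return 0;
    }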
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/stream_executor/lib/threadpool.h"
|
#include "tensorflow/stream_executor/lib/threadpool.h"
|
||||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||||
|
|
||||||
@ -45,9 +46,9 @@ class HostStream : public internal::StreamInterface {
|
|||||||
static const int kExecutorThreads = 1;
|
static const int kExecutorThreads = 1;
|
||||||
std::unique_ptr<port::ThreadPool> host_executor_;
|
std::unique_ptr<port::ThreadPool> host_executor_;
|
||||||
|
|
||||||
mutex mu_;
|
absl::Mutex mu_;
|
||||||
int pending_tasks_ GUARDED_BY(mu_) = 0;
|
int pending_tasks_ GUARDED_BY(mu_) = 0;
|
||||||
condition_variable completion_condition_;
|
absl::CondVar completion_condition_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace host
|
} // namespace host
|
||||||
|
@ -94,7 +94,7 @@ const char *CudaPtxInMemory::default_text() const {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
mutex_lock lock(mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
|
|
||||||
auto ptx = ptx_by_compute_capability_.begin()->second;
|
auto ptx = ptx_by_compute_capability_.begin()->second;
|
||||||
// Check if there is an entry in decompressed ptx table.
|
// Check if there is an entry in decompressed ptx table.
|
||||||
@ -128,7 +128,7 @@ const char *CudaPtxInMemory::text(int compute_capability_major,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
mutex_lock lock(mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
|
|
||||||
// Check if there is an entry in decompressed ptx table.
|
// Check if there is an entry in decompressed ptx table.
|
||||||
auto decompressed_ptx_iter = decompressed_ptx_.find(ptx_iter->second);
|
auto decompressed_ptx_iter = decompressed_ptx_.find(ptx_iter->second);
|
||||||
|
@ -47,13 +47,13 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/stream_executor/platform/logging.h"
|
#include "tensorflow/stream_executor/platform/logging.h"
|
||||||
#include "tensorflow/stream_executor/platform/mutex.h"
|
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
@ -206,7 +206,7 @@ class CudaPtxInMemory : public KernelLoaderSpec {
|
|||||||
// Stores all decompressed ptx strings, with original ptx string as keys.
|
// Stores all decompressed ptx strings, with original ptx string as keys.
|
||||||
// It is marked as mutable for lazy decompression.
|
// It is marked as mutable for lazy decompression.
|
||||||
mutable std::map<const char *, string> decompressed_ptx_;
|
mutable std::map<const char *, string> decompressed_ptx_;
|
||||||
mutable mutex mu_;
|
mutable absl::Mutex mu_;
|
||||||
|
|
||||||
// Defines the minimum compute capability possible. Used when PTX has no
|
// Defines the minimum compute capability possible. Used when PTX has no
|
||||||
// compute capability specified (in the single-PTX constructor).
|
// compute capability specified (in the single-PTX constructor).
|
||||||
|
@ -16,7 +16,6 @@ cc_library(
|
|||||||
name = "platform",
|
name = "platform",
|
||||||
textual_hdrs = [
|
textual_hdrs = [
|
||||||
"logging.h",
|
"logging.h",
|
||||||
"mutex.h",
|
|
||||||
"platform.h",
|
"platform.h",
|
||||||
"port.h",
|
"port.h",
|
||||||
"thread_annotations.h",
|
"thread_annotations.h",
|
||||||
|
@ -6,7 +6,6 @@ cc_library(
|
|||||||
name = "platform",
|
name = "platform",
|
||||||
textual_hdrs = [
|
textual_hdrs = [
|
||||||
"initialize.h",
|
"initialize.h",
|
||||||
"mutex.h",
|
|
||||||
],
|
],
|
||||||
deps = ["//tensorflow/core:lib"],
|
deps = ["//tensorflow/core:lib"],
|
||||||
)
|
)
|
||||||
|
@ -20,13 +20,13 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/stream_executor/lib/status.h"
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
#include "tensorflow/stream_executor/platform.h"
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
#include "tensorflow/stream_executor/platform/mutex.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
@ -1,40 +0,0 @@
|
|||||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
|
|
||||||
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
|
|
||||||
|
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
|
||||||
|
|
||||||
namespace stream_executor {
|
|
||||||
|
|
||||||
#undef mutex_lock
|
|
||||||
#undef tf_shared_lock
|
|
||||||
|
|
||||||
using tensorflow::ConditionResult;
|
|
||||||
using tensorflow::WaitForMilliseconds;
|
|
||||||
using tensorflow::condition_variable;
|
|
||||||
using tensorflow::mutex;
|
|
||||||
using tensorflow::mutex_lock;
|
|
||||||
using tensorflow::tf_shared_lock;
|
|
||||||
|
|
||||||
#define mutex_lock(x) static_assert(0, "mutex_lock_decl_missing_var_name");
|
|
||||||
#define tf_shared_lock(x) \
|
|
||||||
static_assert(0, "tf_shared_lock_decl_missing_var_name");
|
|
||||||
|
|
||||||
} // namespace stream_executor
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
|
|
@ -1,27 +0,0 @@
|
|||||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
|
|
||||||
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
|
|
||||||
|
|
||||||
#include "tensorflow/stream_executor/platform/platform.h"
|
|
||||||
|
|
||||||
#if defined(PLATFORM_GOOGLE)
|
|
||||||
#include "tensorflow/stream_executor/platform/google/mutex.h"
|
|
||||||
#else
|
|
||||||
#include "tensorflow/stream_executor/platform/default/mutex.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
|
|
@ -19,7 +19,6 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
|
||||||
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
@ -38,9 +37,6 @@ using tensorflow::uint64;
|
|||||||
using std::string;
|
using std::string;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
using tensorflow::LinkerInitialized;
|
|
||||||
using tensorflow::LINKER_INITIALIZED;
|
|
||||||
|
|
||||||
#define SE_FALLTHROUGH_INTENDED TF_FALLTHROUGH_INTENDED
|
#define SE_FALLTHROUGH_INTENDED TF_FALLTHROUGH_INTENDED
|
||||||
|
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -15,8 +15,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||||
|
|
||||||
|
#include "absl/base/const_init.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/stream_executor/lib/error.h"
|
#include "tensorflow/stream_executor/lib/error.h"
|
||||||
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
||||||
|
|
||||||
@ -44,9 +46,9 @@ string PluginKindString(PluginKind plugin_kind) {
|
|||||||
PluginRegistry::DefaultFactories::DefaultFactories() :
|
PluginRegistry::DefaultFactories::DefaultFactories() :
|
||||||
blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
|
blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
|
||||||
|
|
||||||
static mutex& GetPluginRegistryMutex() {
|
static absl::Mutex& GetPluginRegistryMutex() {
|
||||||
static mutex* mu = new mutex;
|
static absl::Mutex mu(absl::kConstInit);
|
||||||
return *mu;
|
return mu;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
|
/* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
|
||||||
@@ -54,7 +56,7 @@ static mutex& GetPluginRegistryMutex() {
 PluginRegistry::PluginRegistry() {}
 
 /* static */ PluginRegistry* PluginRegistry::Instance() {
-  mutex_lock lock{GetPluginRegistryMutex()};
+  absl::MutexLock lock{&GetPluginRegistryMutex()};
   if (instance_ == nullptr) {
     instance_ = new PluginRegistry();
   }
@@ -70,7 +72,7 @@ template <typename FACTORY_TYPE>
 port::Status PluginRegistry::RegisterFactoryInternal(
     PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
     std::map<PluginId, FACTORY_TYPE>* factories) {
-  mutex_lock lock{GetPluginRegistryMutex()};
+  absl::MutexLock lock{&GetPluginRegistryMutex()};
 
   if (factories->find(plugin_id) != factories->end()) {
     return port::Status(
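Note the call-site change that repeats throughout this commit: TF's mutex_lock bound to a mutex reference, while absl::MutexLock takes a Mutex*, hence the added &. A small hedged example of the resulting RAII pattern (the registry map below is invented for the sketch):

#include <map>
#include <string>

#include "absl/synchronization/mutex.h"

absl::Mutex registry_mu;
std::map<std::string, int> registry;  // guarded by registry_mu

// The constructor of absl::MutexLock acquires the mutex and the destructor
// releases it, so both return paths below leave registry_mu unlocked.
bool RegisterOnce(const std::string& name, int value) {
  absl::MutexLock lock(&registry_mu);  // note the address-of
  if (registry.count(name) > 0) return false;  // already registered
  registry[name] = value;
  return true;
}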
@@ -25,7 +25,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/plugin.h"
 #include "tensorflow/stream_executor/rng.h"
 
@@ -382,7 +382,7 @@ template <typename FuncT, typename... Args>
 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
                                   bool pointer_mode_host, bool err_on_failure,
                                   Args... args) {
-  mutex_lock lock{mu_};
+  absl::MutexLock lock{&mu_};
 
   CHECK(blas_ != nullptr);
   if (!SetStream(stream)) {
@@ -20,8 +20,8 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/blas.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
@@ -141,7 +141,7 @@ class ROCMBlas : public blas::BlasSupport {
                           blas::ProfileResult *output_profile_result);
 
   // mutex that guards the rocBLAS handle for this device.
-  mutex mu_;
+  absl::Mutex mu_;
 
   // GpuExecutor which instantiated this ROCMBlas.
   // Immutable post-initialization.
@@ -98,8 +98,8 @@ class MIOpenHandle {
  public:
   // Takes ownership of the executor context and the lock to access MIOpen
   // using handle.
-  MIOpenHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock,
-               miopenHandle_t handle)
+  MIOpenHandle(gpu::ScopedActivateExecutorContext context,
+               std::unique_ptr<absl::MutexLock> lock, miopenHandle_t handle)
       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
 
   // Returns MIOpen handle. To be passed directly to MIOpen APIs, don't keep
@@ -108,7 +108,7 @@ class MIOpenHandle {
 
  private:
   gpu::ScopedActivateExecutorContext context_;
-  mutex_lock lock_;
+  std::unique_ptr<absl::MutexLock> lock_;
   miopenHandle_t handle_;  // Not owned.
 };
 
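TF's mutex_lock was movable, so the handle object could own it by value; absl::MutexLock is neither copyable nor movable, which is why the diff stores it behind std::unique_ptr instead. A stripped-down sketch of handing a held lock to a RAII handle (the Handle type here is illustrative, not the real MIOpenHandle):

#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"

absl::Mutex handle_mu;
int library_handle = 0;  // stands in for the real library handle

class Handle {
 public:
  // Takes ownership of the lock; the mutex stays held until ~Handle runs.
  Handle(std::unique_ptr<absl::MutexLock> lock, int* value)
      : lock_(std::move(lock)), value_(value) {}
  int* get() { return value_; }

 private:
  std::unique_ptr<absl::MutexLock> lock_;  // released in the destructor
  int* value_;                             // not owned
};

Handle AcquireHandle() {
  // Wrapping the non-movable MutexLock in unique_ptr makes it transferable.
  auto lock = absl::make_unique<absl::MutexLock>(&handle_mu);
  return Handle(std::move(lock), &library_handle);
}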
@@ -318,7 +318,7 @@ class CachedFusionPlans {
                            miopenFusionPlanDescriptor_t* fusion_plan,
                            miopenFusionDirection_t fusion_direction,
                            miopenTensorDescriptor_t input_descriptor) {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
 
     bool found_cached_plan = false;
 
@@ -342,7 +342,7 @@ class CachedFusionPlans {
 
   // Need to figure out the right place to call this routine.
   static void Clear() {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
 
     for (auto it : cached_plans) {
      auto status = wrap::miopenDestroyFusionPlan(it.second);
@@ -359,19 +359,19 @@ class CachedFusionPlans {
 
   // Is the Fusion plan corresponding to this hash unsupported.
   static bool IsUnsupportedFusionPlan(uint64 hash) {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
     return unsupported_plans.count(hash) > 0;
   }
 
   // Mark the given hash value as corresponding to an unsupported fusion plan.
   static void MarkFusionPlanUnsupported(uint64 hash) {
-    mutex_lock lock{cached_plans_mutex};
+    absl::MutexLock lock{&cached_plans_mutex};
     unsupported_plans.insert(hash);
   }
 
  private:
   // Mutex to guard access to all data within this class.
-  static mutex cached_plans_mutex;
+  static absl::Mutex cached_plans_mutex;
 
   // Map of hash-value to MIOpen Fusion plan descriptors.
   // Need to be able share this across more than one stream and hence static.
@@ -382,7 +382,7 @@ class CachedFusionPlans {
   static std::set<uint64> unsupported_plans;
 };
 
-mutex CachedFusionPlans::cached_plans_mutex;
+absl::Mutex CachedFusionPlans::cached_plans_mutex;
 std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
 std::set<uint64> CachedFusionPlans::unsupported_plans;
 
@@ -449,7 +449,7 @@ class MIOpenAccess {
   explicit MIOpenAccess(miopenHandle_t handle) : handle_(handle) {}
 
   ~MIOpenAccess() {
-    mutex_lock lock(mutex_);
+    absl::MutexLock lock(&mutex_);
     wrap::miopenDestroy(handle_);
   }
 
@@ -468,7 +468,8 @@ class MIOpenAccess {
   // therefore a bad idea (performance wise) to call any MIOpen APIs that
   // enqueue work in the stream.
   MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) {
-    mutex_lock lock(mutex_);
+    auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
+    mutex_.AssertHeld();
     gpu::ScopedActivateExecutorContext context(executor);
     hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr;
     auto status = wrap::miopenSetStream(handle_, hip_stream);
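Because the lock is now created through absl::make_unique, Clang's thread-safety analysis no longer sees a scoped locker and would flag access to members annotated with GUARDED_BY(mutex_); the added mutex_.AssertHeld() declares (and can check at runtime) that the mutex really is held. A hedged sketch of that combination, using made-up names:

#include <memory>

#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"

class Counter {
 public:
  int Next() {
    // A heap-allocated lock is opaque to the static analysis, so assert the
    // invariant explicitly; AssertHeld is annotated as an assertion of the
    // lock, which silences the ABSL_GUARDED_BY warning below.
    auto lock = absl::make_unique<absl::MutexLock>(&mu_);
    mu_.AssertHeld();
    return ++value_;
  }

 private:
  absl::Mutex mu_;
  int value_ ABSL_GUARDED_BY(mu_) = 0;
};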
@@ -478,7 +479,7 @@ class MIOpenAccess {
 
  private:
  // Guards the enqueueing of MIOpen operations via the handle_ below.
-  mutex mutex_;
+  absl::Mutex mutex_;
 
   // MIOpen library handle.
   miopenHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
@@ -19,9 +19,9 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/dnn.h"
 #include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/temporary_device_memory.h"
@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
 #include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
 #include "tensorflow/stream_executor/gpu/gpu_driver.h"
@@ -34,7 +35,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/static_threadlocal.h"
 #include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/platform/logging.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h"
 
@@ -305,17 +305,10 @@ static port::Status InternalInit() {
 /* static */ port::Status GpuDriver::Init() {
   // Cached return value from calling InternalInit(), as hipInit need only be
   // called once, but GpuDriver::Init may be called many times.
-  static port::Status init_retval;
-  static bool set = false;
-  static mutex* init_mu = new mutex;
-
-  mutex_lock lock(*init_mu);
-  if (!set) {
-    init_retval = InternalInit();
-    set = true;
-  }
-
-  return init_retval;
+  static port::Status* init_retval = [&] {
+    return new port::Status(InternalInit());
+  }();
+  return *init_retval;
 }
 
 /* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
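The rewrite above drops the hand-rolled mutex-plus-flag guard in favor of C++11's thread-safe initialization of function-local statics: an immediately-invoked lambda runs the real initialization exactly once, and the result is deliberately leaked so no destructor-ordering issues arise at shutdown. Roughly, with a stand-in ExpensiveInit() instead of the real driver call:

#include <string>

// Stand-in for the expensive one-time initialization.
static std::string ExpensiveInit() { return "initialized"; }

const std::string& CachedInit() {
  // The initializer of a function-local static runs exactly once, even with
  // concurrent callers; the pointer is leaked on purpose.
  static const std::string* result = [] {
    return new std::string(ExpensiveInit());
  }();
  return *result;
}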
@@ -114,7 +114,7 @@ GpuExecutor::~GpuExecutor() {
 }
 bool GpuExecutor::UnloadModule(ModuleHandle module_handle) {
   const char* gpu_binary = reinterpret_cast<const char*>(module_handle.id());
-  mutex_lock lock{in_memory_modules_mu_};
+  absl::MutexLock lock{&in_memory_modules_mu_};
   return UnloadGpuBinary(gpu_binary);
 }
 
@@ -236,7 +236,7 @@ bool GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
     kernelname = &spec.cuda_cubin_in_memory().kernelname();
 
     const char* hsaco = spec.cuda_cubin_in_memory().bytes();
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     module = in_memory_modules_[hsaco];
 
     if (module == nullptr) {
@@ -294,7 +294,7 @@ bool GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
   // whether we've done an occupancy check on this kernel before isn't free
   // (because we have to synchronize), so we only do this at -v 2+.
   if (VLOG_IS_ON(2)) {
-    mutex_lock lock(launched_kernels_mu_);
+    absl::MutexLock lock(&launched_kernels_mu_);
     if (!launched_kernels_.count(hipfunc)) {
       VlogOccupancyInfo(kernel, thread_dims, block_dims);
       // TODO(rspringer): Remove elements from launched_kernels_...if we ever
@@ -366,7 +366,7 @@ bool GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
   hipModule_t hip_module = nullptr;
   // TODO(ROCm): Need generic term instead of cubin/cuda/ptx
   if (spec.has_cuda_cubin_in_memory()) {
-    mutex_lock lock{in_memory_modules_mu_};
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (!LoadModuleFromHsaco(
             reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()),
             &hip_module)) {
@@ -765,8 +765,8 @@ bool GpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
 bool GpuExecutor::GetSymbol(const string& symbol_name,
                             ModuleHandle module_handle, void** mem,
                             size_t* bytes) {
-  {  // give limited scope to mutex_lock
-    mutex_lock lock{disk_modules_mu_};
+  {  // give limited scope to lock
+    absl::MutexLock lock{&disk_modules_mu_};
     for (auto& it : disk_modules_) {
       if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(),
                                      reinterpret_cast<hipDeviceptr_t*>(mem),
@@ -776,8 +776,8 @@ bool GpuExecutor::GetSymbol(const string& symbol_name,
     }
   }
 
-  {  // give limited scope to mutex_lock
-    mutex_lock lock{in_memory_modules_mu_};
+  {  // give limited scope to lock
+    absl::MutexLock lock{&in_memory_modules_mu_};
     for (auto& it : in_memory_modules_) {
       if (GpuDriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(),
                                      reinterpret_cast<hipDeviceptr_t*>(mem),
@@ -787,8 +787,8 @@ bool GpuExecutor::GetSymbol(const string& symbol_name,
     }
   }
 
-  {  // give limited scope to mutex_lock
-    mutex_lock lock{in_memory_modules_mu_};
+  {  // give limited scope to lock
+    absl::MutexLock lock{&in_memory_modules_mu_};
     if (static_cast<bool>(module_handle)) {
       auto it = gpu_binary_to_module_.find(module_handle.id());
       CHECK(it != gpu_binary_to_module_.end());
@@ -38,13 +38,8 @@ ROCmPlatform::~ROCmPlatform() {}
 void ROCmPlatform::InspectNumaNodes() {
   // To get NUMA node information, we need to create all executors, so we can
   // examine their device descriptions to see their bus assignments.
-  static bool initialized = false;
-  static mutex numa_mutex(LINKER_INITIALIZED);
-  mutex_lock lock(numa_mutex);
-  if (initialized) {
-    return;
-  }
-
+  std::once_flag once;
+  std::call_once(once, [&] {
   StreamExecutorConfig config;
   for (int i = 0; i < VisibleDeviceCount(); i++) {
     config.ordinal = i;
@@ -57,11 +52,11 @@ void ROCmPlatform::InspectNumaNodes() {
     } else {
       min_numa_node_ =
           std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
-      limit_numa_node_ = std::max(limit_numa_node_,
-                                  exec->GetDeviceDescription().numa_node() + 1);
+      limit_numa_node_ = std::max(
+          limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
     }
   }
-  initialized = true;
+  });
 }
 
 int ROCmPlatform::BusCount() {
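Here the linker-initialized mutex plus "initialized" flag is replaced by std::call_once, which runs the lambda at most once per flag and blocks concurrent callers until it finishes. A minimal sketch of the idiom (note that the once_flag must outlive and be shared by all callers, typically as a static or a member; the class below is invented for illustration):

#include <mutex>  // std::once_flag, std::call_once

class NumaScanner {
 public:
  void Scan() {
    // Every call funnels through call_once; only the first caller actually
    // runs the lambda, the rest wait for it and then return.
    std::call_once(once_, [&] { node_count_ = CountNodes(); });
  }
  int node_count() const { return node_count_; }

 private:
  static int CountNodes() { return 2; }  // placeholder for the real probe
  std::once_flag once_;
  int node_count_ = 0;
};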
@@ -19,11 +19,11 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/executor_cache.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/multi_platform_manager.h"
 #include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
@@ -90,7 +90,7 @@ class ROCmPlatform : public Platform {
   string name_;
 
   // mutex that guards internal state.
-  mutable mutex mu_;
+  mutable absl::Mutex mu_;
 
   // Cache of created executors.
   ExecutorCache executor_cache_;
@@ -126,7 +126,7 @@ GpuRng::~GpuRng() {
 }
 
 bool GpuRng::Init() {
-  mutex_lock lock{mu_};
+  absl::MutexLock lock{&mu_};
   CHECK(rng_ == nullptr);
 
   hiprandStatus_t ret =
@@ -161,7 +161,7 @@ constexpr bool ComplexIsConsecutiveFloats() {
 
 template <typename T>
 bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) {
-  mutex_lock lock{mu_};
+  absl::MutexLock lock{&mu_};
   static_assert(ComplexIsConsecutiveFloats(),
                 "std::complex values are not stored as consecutive values");
 
@@ -220,7 +220,7 @@ bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean,
                                             ElemT stddev,
                                             DeviceMemory<ElemT>* v,
                                             FuncT func) {
-  mutex_lock lock{mu_};
+  absl::MutexLock lock{&mu_};
 
   if (!SetStream(stream)) {
     return false;
@@ -252,7 +252,7 @@ bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
 }
 
 bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
-  mutex_lock lock{mu_};
+  absl::MutexLock lock{&mu_};
   CHECK(rng_ != nullptr);
 
   if (!CheckSeed(seed, seed_bytes)) {
@@ -290,7 +290,7 @@ port::Status Stream::RefreshStatus() {
 Stream &Stream::Init() {
   VLOG_CALL();
 
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   CHECK_EQ(false, allocated_)
       << "stream appears to already have been initialized";
   CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
@@ -1736,7 +1736,7 @@ Stream &Stream::ThenMemcpyH2DQuantized(
 }
 
 Stream *Stream::GetOrCreateSubStream() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
 
   // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
   // we encounter along the way.
@@ -1783,7 +1783,7 @@ Stream *Stream::GetOrCreateSubStream() {
 }
 
 void Stream::ReturnSubStream(Stream *sub_stream) {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
 
   // Look for the sub-stream.
   for (int64 index = 0; index < sub_streams_.size(); ++index) {
@@ -5285,7 +5285,7 @@ void Stream::CheckStatus(port::Status status) {
     return;
   }
   LOG(ERROR) << status;
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   ok_ = false;
 }
 
@@ -25,6 +25,7 @@ limitations under the License.
 #include <functional>
 #include <memory>
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/stream_executor/blas.h"
 #include "tensorflow/stream_executor/device_memory.h"
@@ -35,7 +36,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/kernel.h"
 #include "tensorflow/stream_executor/launch_dim.h"
 #include "tensorflow/stream_executor/lib/array_slice.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/temporary_memory_manager.h"
@@ -1964,7 +1964,7 @@ class Stream {
   friend class ocl::CLBlas;    // for parent_.
 
   bool InErrorState() const LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock lock(mu_);
+    absl::ReaderMutexLock lock(&mu_);
     return !ok_;
   }
 
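tf_shared_lock becomes absl::ReaderMutexLock here: the const accessor only reads the flag, so a shared (reader) lock suffices, and the mutable qualifier on mu_ (kept a few hunks below) is what lets a const method lock it at all. A compact sketch of the same shape, with invented names:

#include "absl/synchronization/mutex.h"

class StatusHolder {
 public:
  // Readers take a shared lock; many of them can run concurrently.
  bool ok() const {
    absl::ReaderMutexLock lock(&mu_);
    return ok_;
  }

  // The writer takes the exclusive lock.
  void MarkFailed() {
    absl::MutexLock lock(&mu_);
    ok_ = false;
  }

 private:
  mutable absl::Mutex mu_;  // mutable so const readers can lock it
  bool ok_ = true;
};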
@@ -1974,7 +1974,7 @@ class Stream {
     if (operation_retcode) {
       return;
     }
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     ok_ = false;
   }
 
@@ -1998,7 +1998,7 @@ class Stream {
 
   // mutex that guards the allocation / error state flags.
   // Mutable so that it can be obtained via const reader lock.
-  mutable mutex mu_;
+  mutable absl::Mutex mu_;
 
   // Whether Init() was successfully called to allocate this stream on the
   // underlying platform. It simply flips from 0 to 1 with a sanity check.
@@ -23,6 +23,7 @@ limitations under the License.
 #include <memory>
 #include <utility>
 
+#include "absl/base/const_init.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/synchronization/notification.h"
@@ -92,7 +93,7 @@ class ScopedTracer {
   void Trace(CallbackT callback, TraceArgsT... args) {
     {
       // Instance tracers held in a block to limit the lock lifetime.
-      tf_shared_lock lock{stream_exec_->mu_};
+      absl::ReaderMutexLock lock{&stream_exec_->mu_};
       for (TraceListener *listener : stream_exec_->listeners_) {
         (listener->*callback)(correlation_id_,
                               std::forward<TraceArgsT>(args)...);
@@ -121,7 +122,7 @@ MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
   auto tracer = MakeScopedTracer(this, &LOC ## Begin, \
                                  &LOC ## Complete, ## __VA_ARGS__);
 
-/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};
+/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};
 
 // Get per-device memory limit in bytes. Returns 0 if
 // TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
@@ -217,7 +218,7 @@ void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
 }
 
 void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
-  tf_shared_lock lock(mu_);
+  absl::ReaderMutexLock lock(&mu_);
   *records_out = mem_allocs_;
 }
 
@@ -247,7 +248,7 @@ port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
 }
 
 const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   if (device_description_ != nullptr) {
     return *device_description_;
   }
@@ -393,7 +394,7 @@ StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
 }
 
 dnn::DnnSupport *StreamExecutor::AsDnn() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   if (dnn_ != nullptr) {
     return dnn_.get();
   }
@@ -403,7 +404,7 @@ dnn::DnnSupport *StreamExecutor::AsDnn() {
 }
 
 blas::BlasSupport *StreamExecutor::AsBlas() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   if (blas_ != nullptr) {
     return blas_.get();
   }
@@ -413,7 +414,7 @@ blas::BlasSupport *StreamExecutor::AsBlas() {
 }
 
 fft::FftSupport *StreamExecutor::AsFft() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   if (fft_ != nullptr) {
     return fft_.get();
   }
@@ -423,7 +424,7 @@ fft::FftSupport *StreamExecutor::AsFft() {
 }
 
 rng::RngSupport *StreamExecutor::AsRng() {
-  mutex_lock lock(mu_);
+  absl::MutexLock lock(&mu_);
   if (rng_ != nullptr) {
     return rng_.get();
   }
@@ -781,7 +782,7 @@ void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
 
 void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
   if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     mem_allocs_[opaque] = AllocRecord{
         bytes, ""};
     mem_alloc_bytes_ += bytes;
@@ -790,7 +791,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
 
 void StreamExecutor::EraseAllocRecord(void *opaque) {
   if (FLAGS_check_device_leaks && opaque != nullptr) {
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
       LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
     } else {
@@ -804,7 +805,7 @@ void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }
 
 void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
   {
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
@@ -818,7 +819,7 @@ void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
 
 bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
   {
-    mutex_lock lock(mu_);
+    absl::MutexLock lock(&mu_);
     if (listeners_.find(listener) == listeners_.end()) {
       LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
       return false;
@@ -839,7 +840,7 @@ void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
   if (tracing_enabled_) {
     {
       // instance tracers held in a block to limit the lock lifetime.
-      tf_shared_lock lock(mu_);
+      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
@@ -23,13 +23,13 @@ limitations under the License.
 #include <vector>
 
 #include "absl/base/macros.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/types/optional.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/platform.h"
 #include "tensorflow/stream_executor/platform/logging.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/rng.h"
@@ -619,13 +619,13 @@ class StreamExecutor {
   void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);
 
   // Reader/writer lock for class-static StreamExecutor members.
-  static mutex static_mu_;
+  static absl::Mutex static_mu_;
 
   // Reader/writer lock for mutable data structures on this StreamExecutor.
   //
   // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
   // can acquire the lock on their first (mutating) call as well.
-  mutable mutex mu_;
+  mutable absl::Mutex mu_;
 
   // Reference to the platform that created this executor.
   const Platform *platform_;
@@ -25,7 +25,7 @@ namespace stream_executor {
 namespace internal {
 
 void TemporaryMemoryManager::ForceDeallocateAll() {
-  mutex_lock lock(mutex_);
+  absl::MutexLock lock(&mutex_);
   VLOG(1) << "force-deallocating " << records_.size() << " remaining records";
   for (auto it = records_.begin(); it != records_.end(); ++it) {
     DeviceMemoryBase device_memory = it->first;
@@ -35,7 +35,7 @@ void TemporaryMemoryManager::ForceDeallocateAll() {
 
 void TemporaryMemoryManager::MarkFinalized(
     const DeviceMemoryBase& device_memory, uint64 generation, bool must_exist) {
-  mutex_lock lock(mutex_);
+  absl::MutexLock lock(&mutex_);
   auto it = records_.find(device_memory);
   if (it == records_.end()) {
     if (must_exist) {
@@ -48,7 +48,7 @@ void TemporaryMemoryManager::MarkFinalized(
 }
 
 void TemporaryMemoryManager::DeallocateFinalizedTemporaries() {
-  mutex_lock lock(mutex_);
+  absl::MutexLock lock(&mutex_);
   int deallocated_count = 0;
   for (auto it = records_.begin(); it != records_.end();) {
     if (it->second.finalized) {
@@ -65,7 +65,7 @@ void TemporaryMemoryManager::DeallocateFinalizedTemporaries() {
 
 bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory,
                                          uint64 allocation_generation) const {
-  mutex_lock lock(mutex_);
+  absl::MutexLock lock(&mutex_);
   auto it = records_.find(device_memory);
   if (it == records_.end()) {
     return true;  // If there's no record present it's vacuously finalized.
@@ -81,7 +81,7 @@ bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory,
 
 bool TemporaryMemoryManager::HasAllocated(const DeviceMemoryBase& device_memory,
                                           uint64 generation) const {
-  mutex_lock lock(mutex_);
+  absl::MutexLock lock(&mutex_);
   auto it = records_.find(device_memory);
   if (it == records_.end()) {
     return false;
@@ -106,7 +106,7 @@ TemporaryMemoryManager::AllocateArrayBase(uint64 element_count,
   // Add the record before instantiating the device memory instance so we can
   // check the allocation invariant at TemporaryDeviceMemory construction time.
   {
-    mutex_lock lock(mutex_);
+    absl::MutexLock lock(&mutex_);
     generation = ++generation_;
     DCHECK(records_.find(device_memory) == records_.end());
     records_[device_memory] = {generation,
@@ -24,10 +24,10 @@ limitations under the License.
 #include <map>
 #include <memory>
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/temporary_device_memory.h"
 
@@ -108,7 +108,7 @@ class TemporaryMemoryManager {
       uint64 element_count, uint64 element_size);
 
   // Mutex to guard temporary record state.
-  mutable mutex mutex_;
+  mutable absl::Mutex mutex_;
 
   // Mapping from device memory to the current (live) temporary memory record.
   //