PR #25676: [ROCm] Adding ROCm support for code in the tensorflow/core/common_runtime/gpu directory
Imported from GitHub PR #25676

This PR contains commits to add ROCm support for the code within the directory `tensorflow/core/common_runtime/gpu`. Prior to this PR, the GPU-implementation-specific code within that directory is CUDA-only; this PR adds the ROCm equivalent. Some files/variables have been renamed from *cuda* to *gpu* to make the code more generic. The PR is broken down into several commits for ease of review; I can squash them into one prior to the merge if need be.

@tatianashp, @whchung: adding you to the cc-list just as an FYI.

Copybara import of the project:

- c62c032d92e17a775db791ad03ae3502ea6d8e8e adding ROCm support in tensorflow/core/common_runtime/gpu... by Deven Desai <deven.desai.amd@gmail.com>
- e94e3f408d9956dfa3aa7849a4ff41c584a6e0f0 adding ROCm support in tensorflow/core/common_runtime/gpu... by Deven Desai <deven.desai.amd@gmail.com>
- 96e522ddf21fa675308a38b91d0a95efd771033e adding ROCm support in tensorflow/core/common_runtime/gpu... by Deven Desai <deven.desai.amd@gmail.com>
- 929bbc2d4c620e31b2f7786600f4d7a271de0dba renaming cuda_host_allocator.h to gpu_host_allocator.h by Deven Desai <deven.desai.amd@gmail.com>
- 4e255e01d26f74e003c11a894ebf39fcaf64ca49 misc changes by Deven Desai <deven.desai.amd@gmail.com>
- a11b8117f1b27ba7f1abd43c4849e29d21c610d2 changes requested in the code review by Deven Desai <deven.desai.amd@gmail.com>
- 07b49644d27aeac2e8646609640206faf5908ca5 using string instead of std::string + removing the #defi... by Deven Desai <deven.desai.amd@gmail.com>
- 46bc499856a769cd00e9ddcdf5c922e3c17f0603 Merge 07b49644d27aeac2e8646609640206faf5908ca5 into afab5... by Deven Desai <36858332+deven-amd@users.noreply.github.com>

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/25676 from ROCmSoftwarePlatform:google_upstream_common_runtime_changes 07b49644d27aeac2e8646609640206faf5908ca5
PiperOrigin-RevId: 237437557
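For readers skimming the diff below, the recurring change is that CUDA-only preprocessor guards (`#if GOOGLE_CUDA`) become dual guards (`#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM`), and CUDA-specific types and runtime calls gain a HIP counterpart selected at compile time. A minimal sketch of that pattern follows; it is illustrative only: the gpuStream_t/gpuError_t aliases mirror the aliases the PR introduces in gpu_device.cc, while GetCurrentGpuDevice is a hypothetical helper that is not part of the PR.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include <cuda_runtime.h>
typedef cudaStream_t gpuStream_t;  // CUDA stream handle
typedef cudaError_t gpuError_t;    // CUDA error code
#elif TENSORFLOW_USE_ROCM
#include <hip/hip_runtime.h>
typedef hipStream_t gpuStream_t;   // HIP stream handle
typedef hipError_t gpuError_t;     // HIP error code
#endif

// Hypothetical helper: query the current device ordinal through whichever
// runtime the build was configured for.
inline gpuError_t GetCurrentGpuDevice(int* device) {
#if GOOGLE_CUDA
  return cudaGetDevice(device);
#elif TENSORFLOW_USE_ROCM
  return hipGetDevice(device);
#endif
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

The renames from CUDAHost* to GpuHost* in the hunks below follow the same idea: call sites stay backend-neutral, and anything platform-specific is confined to the guarded blocks.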
parent bc05464062
commit d280d3d9b9
@@ -250,10 +250,9 @@ Status GdrMemoryManager::Init() {
     LOG(INFO) << "Instrumenting CPU allocator(s)";
 
     for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) {
-      GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx,
-                                                            alloc_visitor);
-      GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx,
-                                                           free_visitor);
+      GPUProcessState::singleton()->AddGpuHostAllocVisitor(numa_idx,
+                                                           alloc_visitor);
+      GPUProcessState::singleton()->AddGpuHostFreeVisitor(numa_idx, free_visitor);
     }
 
     if (IsGDRAvailable()) {
@@ -1086,7 +1086,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
     // The tensor must be copied from GPU to CPU, because either:
     // 1. The tensor is located on a non GDR compatible GPU.
     // 2. The tensor's meta-data has changed.
-    Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
+    Allocator* alloc = GPUProcessState::singleton()->GetGpuHostAllocator(0);
     copy = Tensor(alloc, in.dtype(), in.shape());
     CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
                 (void*)DMAHelper::base(&copy), in.TotalBytes(), true);
@@ -1543,7 +1543,7 @@ bool RdmaTensorRequest::AllocateTensors() {
   if (mr_ == nullptr) {
     // Can't RDMA directly to result. Use a proxy.
     proxy_tensor_ =
-        new Tensor(GPUProcessState::singleton()->GetCUDAHostAllocator(0),
+        new Tensor(GPUProcessState::singleton()->GetGpuHostAllocator(0),
                    result_tensor_->dtype(), result_tensor_->shape());
     rdma_addr_ = DMAHelper::base(proxy_tensor_);
     mr_ =
@@ -277,8 +277,8 @@ void RdmaMgr::InitAllocators() {
   ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
 
 #if GOOGLE_CUDA
-  GPUProcessState::singleton()->AddCUDAHostAllocVisitor(0, alloc_visitor);
-  GPUProcessState::singleton()->AddCUDAHostFreeVisitor(0, free_visitor);
+  GPUProcessState::singleton()->AddGpuHostAllocVisitor(0, alloc_visitor);
+  GPUProcessState::singleton()->AddGpuHostFreeVisitor(0, free_visitor);
 
   if (IsGDRAvailable()) {
     // Note we don't free allocated GPU memory so there is no free visitor
@@ -1705,6 +1705,7 @@ filegroup(
         "platform/**/logger.cc",
         "platform/default/test_benchmark.*",
         "platform/cuda.h",
+        "platform/rocm.h",
         "platform/google/**/*",
         "platform/hadoop/**/*",
         "platform/gif.h",
@@ -2259,6 +2260,7 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob(
         "platform/jpeg.h",
         "platform/png.h",
         "platform/**/cuda.h",
+        "platform/**/rocm.h",
         "platform/**/stream_executor.h",
     ],
 )
@@ -2371,6 +2373,7 @@ cc_library(
         "**/*test*",
         "platform/**/cuda.h",
        "platform/**/cuda_libdevice_path.cc",
+        "platform/**/rocm.h",
         "platform/**/stream_executor.h",
         "platform/**/env_time.cc",
         "platform/**/device_tracer.cc",
@@ -2866,6 +2869,7 @@ tf_cuda_library(
     srcs = ["platform/stream_executor.h"],
     hdrs = [
         "platform/cuda.h",
+        "platform/rocm.h",
         "platform/stream_executor.h",
     ],
     deps = [
@@ -3301,7 +3305,7 @@ cc_library(
 )
 
 GPU_RUNTIME_HEADERS = [
-    "common_runtime/gpu/cuda_host_allocator.h",
+    "common_runtime/gpu/gpu_host_allocator.h",
     "common_runtime/gpu/gpu_bfc_allocator.h",
     "common_runtime/gpu/gpu_cudamalloc_allocator.h",
     "common_runtime/gpu/gpu_debug_allocator.h",
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
 
@@ -580,4 +580,4 @@ TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) {
 
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
 
@@ -249,4 +249,4 @@ TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
 }  // namespace
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -15,7 +15,11 @@ limitations under the License.
 
 // TODO(opensource): Use a more generic sounding preprocessor name than
 // GOOGLE_CUDA
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
+#if TENSORFLOW_USE_ROCM
+#include "rocm/include/hip/hip_runtime.h"
+#endif
+
 #define EIGEN_USE_GPU
 
@@ -55,7 +59,11 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#if GOOGLE_CUDA
 #include "tensorflow/core/platform/cuda.h"
+#elif TENSORFLOW_USE_ROCM
+#include "tensorflow/core/platform/rocm.h"
+#endif
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor.h"
@@ -67,18 +75,36 @@ limitations under the License.
 #include "tensorflow/core/util/stream_executor_util.h"
 
 #if !defined(PLATFORM_GOOGLE)
+#if GOOGLE_CUDA
 #include "cuda/cuda_config.h"
 #endif
+#endif
 
 namespace tensorflow {
 
+#if GOOGLE_CUDA
+
+typedef cudaStream_t gpuStream_t;
+typedef cudaDeviceProp gpuDeviceProp_t;
+#define EIGEN_GPU_SCRATCH_SIZE (Eigen::kGpuScratchSize)
+using se::cuda::ScopedActivateExecutorContext;
+
+#elif TENSORFLOW_USE_ROCM
+
+typedef hipStream_t gpuStream_t;
+typedef hipDeviceProp_t gpuDeviceProp_t;
+#define EIGEN_GPU_SCRATCH_SIZE (Eigen::kGpuScratchSize)
+using se::rocm::ScopedActivateExecutorContext;
+
+#endif
+
 // Eigen Ops directly allocate memory only for temporary buffers used
 // during OpKernel::Compute(). The recommended way of allocating such
 // memory is via OpKernelContext::allocate_temp(). However, Eigen Ops
 // don't have access to OpKernelContext, instead they get access to
 // memory directly through the device allocator. As an Open Source
 // project, Eigen assumes allocator semantics similar to those of the
-// CUDA memory allocator, and may not work correctly due to race
+// CUDA or ROCm memory allocator, and may not work correctly due to race
 // conditions if used with some other allocator. For safety, we need
 // to delay deallocation calls out of Eigen until all events on the
 // corresponding stream have completed. The following two classes
@@ -91,7 +117,7 @@ class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
     Eigen::initializeDeviceProp();
   }
   ~EigenGpuStreamDevice() override {}
-  void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
+  void Reinitialize(OpKernelContext* context, const gpuStream_t* gpu_stream,
                     TfGpuId tf_gpu_id, ::tensorflow::Allocator* alloc,
                     char* scratch) {
     if (LogMemory::IsEnabled()) {
@@ -102,15 +128,15 @@ class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
     scratch_ = scratch;
     semaphore_ =
         reinterpret_cast<unsigned int*>(scratch + Eigen::kGpuScratchSize);
-    stream_ = cuda_stream;
+    stream_ = gpu_stream;
     allocator_ = alloc;
     PlatformGpuId platform_gpu_id;
     TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
     device_prop_ = &Eigen::m_deviceProperties[platform_gpu_id.value()];
   }
 
-  const cudaStream_t& stream() const override { return *stream_; }
-  const cudaDeviceProp& deviceProperties() const override {
+  const gpuStream_t& stream() const override { return *stream_; }
+  const gpuDeviceProp_t& deviceProperties() const override {
     return *device_prop_;
   }
 
@@ -140,8 +166,13 @@ class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
     }
     AsyncFreeData* afData =
         new AsyncFreeData(allocator_, buffer, operation_, step_id_);
+#if GOOGLE_CUDA
     cudaError_t err = cudaStreamAddCallback(*stream_, asyncFree, afData, 0);
     CHECK_EQ(err, cudaSuccess);
+#elif TENSORFLOW_USE_ROCM
+    hipError_t err = hipStreamAddCallback(*stream_, asyncFree, afData, 0);
+    CHECK_EQ(err, hipSuccess);
+#endif
   }
 
   // Return a pointer to a per stream scratchpad of 1024 bytes residing
@@ -165,8 +196,12 @@ class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
     const int64 step_id_;
   };
 
-  static void CUDART_CB asyncFree(cudaStream_t stream, cudaError_t status,
+#if GOOGLE_CUDA
+  static void CUDART_CB asyncFree(gpuStream_t stream, cudaError_t status,
                                   void* userData) {
+#elif TENSORFLOW_USE_ROCM
+  static void asyncFree(gpuStream_t stream, hipError_t status, void* userData) {
+#endif
     AsyncFreeData* data = static_cast<AsyncFreeData*>(userData);
     if (LogMemory::IsEnabled()) {
       LogMemory::RecordRawDeallocation(data->operation_, data->step_id_,
@@ -178,8 +213,8 @@ class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
 
     string operation_;
     int64 step_id_;
-  const cudaStream_t* stream_;          // Not owned.
-  const cudaDeviceProp* device_prop_;   // Not owned.
+  const gpuStream_t* stream_;           // Not owned.
+  const gpuDeviceProp_t* device_prop_;  // Not owned.
   ::tensorflow::Allocator* allocator_;  // Not owned.
   mutable char* scratch_;
   mutable unsigned int* semaphore_;
@@ -454,7 +489,7 @@ Status BaseGPUDevice::FillContextMap(const Graph* graph,
 void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
   // NOTE(tucker): We need to discriminate between Eigen GPU
   // operations and all others. If an operation is Eigen
-  // implemented (or otherwise tries to launch a cuda kernel
+  // implemented (or otherwise tries to launch a GPU kernel
   // directly), we need to establish a stacked-scoped environment
   // that directs it to execute on the proper device. Otherwise we
   // expect the Op to use StreamExecutor directly and correctly. The
@@ -530,7 +565,7 @@ void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
     DCHECK(kernel_tracker_);
     kernel_tracker_->PauseWhilePendingExceeds(pending_cap_);
   }
-  se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
+  ScopedActivateExecutorContext scoped_activation{stream->parent()};
   op_kernel->Compute(context);
   if (context->status().ok()) {
     if (sync_every_op_) {
@@ -596,7 +631,7 @@ void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
   // activity is simple enough that its overhead is negligible.
   tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
                                    op_kernel->IsExpensive());
-  se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
+  ScopedActivateExecutorContext scoped_activation{stream->parent()};
   op_kernel->ComputeAsync(context, done);
 }
 
@@ -715,10 +750,10 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
  public:
   ConcretePerOpGpuDevice() : device_(&stream_device_) {}
 
-  void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
+  void Reinitialize(OpKernelContext* context, const gpuStream_t* gpu_stream,
                     TfGpuId tf_gpu_id, Allocator* base_allocator,
                     char* scratch) {
-    stream_device_.Reinitialize(context, cuda_stream, tf_gpu_id, base_allocator,
+    stream_device_.Reinitialize(context, gpu_stream, tf_gpu_id, base_allocator,
                                 scratch);
   }
 
@@ -898,9 +933,9 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
   ConcretePerOpGpuDevice* concrete_device =
       static_cast<ConcretePerOpGpuDevice*>(device);
   DCHECK(concrete_device);
-  const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
+  const gpuStream_t* gpu_stream = reinterpret_cast<const gpuStream_t*>(
       streams_[stream_id]->compute->implementation()->GpuStreamMemberHack());
-  concrete_device->Reinitialize(context, cuda_stream, tf_gpu_id_, allocator,
+  concrete_device->Reinitialize(context, gpu_stream, tf_gpu_id_, allocator,
                                 scratch_[stream_id]);
 }
 
@@ -977,14 +1012,24 @@ Status BaseGPUDeviceFactory::CreateDevices(
   if (!valid_platform_gpu_ids.empty()) {
     // Save the original device.
     int original_device = 0;
+#if GOOGLE_CUDA
     cudaError_t err = cudaGetDevice(&original_device);
     if (err != cudaSuccess) {
       return errors::Internal("cudaGetDevice() failed. Status: ",
                               cudaGetErrorString(err));
     }
+#elif TENSORFLOW_USE_ROCM
+    hipError_t err = hipGetDevice(&original_device);
+    if (err != hipSuccess) {
+      return errors::Internal("hipGetDevice() failed. Status: ",
+                              hipGetErrorString(err));
+    }
+#endif
 
     // Force to implicitly initialize CUDA runtime on each valid GPU before
     // CreateGPUDevice().
     for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+#if GOOGLE_CUDA
       err = cudaSetDevice(platform_gpu_id.value());
       if (err != cudaSuccess) {
         return errors::Internal(
@@ -997,13 +1042,35 @@ Status BaseGPUDeviceFactory::CreateDevices(
             platform_gpu_id.value(),
             " failed. Status: ", cudaGetErrorString(err));
       }
+#elif TENSORFLOW_USE_ROCM
+      err = hipSetDevice(platform_gpu_id.value());
+      if (err != hipSuccess) {
+        return errors::Internal(
+            "hipSetDevice() on GPU:", platform_gpu_id.value(),
+            " failed. Status: ", hipGetErrorString(err));
+      }
+      err = hipFree(nullptr);
+      if (err != hipSuccess) {
+        return errors::Internal("ROCm runtime implicit initialization on GPU:",
+                                platform_gpu_id.value(),
+                                " failed. Status: ", hipGetErrorString(err));
+      }
+#endif
     }
     // Reset to the original device.
+#if GOOGLE_CUDA
     err = cudaSetDevice(original_device);
     if (err != cudaSuccess) {
       return errors::Internal("cudaSetDevice() on GPU:", original_device,
                               " failed. Status: ", cudaGetErrorString(err));
     }
+#elif TENSORFLOW_USE_ROCM
+    err = hipSetDevice(original_device);
+    if (err != hipSuccess) {
+      return errors::Internal("hipSetDevice() on GPU:", original_device,
+                              " failed. Status: ", hipGetErrorString(err));
+    }
+#endif
   }
 
   std::vector<InterconnectMap> interconnect_maps;
@@ -1093,6 +1160,7 @@ Status BaseGPUDeviceFactory::CreateDevices(
 
 static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
                                         const se::DeviceDescription& desc) {
+#if GOOGLE_CUDA
   int cc_major;
   int cc_minor;
   if (!desc.cuda_compute_capability(&cc_major, &cc_minor)) {
@@ -1104,6 +1172,11 @@ static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
                          desc.name(), ", pci bus id: ", desc.pci_bus_id(),
                          ", compute capability: ", cc_major, ".", cc_minor);
   // LINT.ThenChange(//tensorflow/python/platform/test.py)
+#elif TENSORFLOW_USE_ROCM
+  return strings::StrCat("device: ", platform_gpu_id.value(),
+                         ", name: ", desc.name(),
+                         ", pci bus id: ", desc.pci_bus_id());
+#endif
 }
 
 Status BaseGPUDeviceFactory::CreateGPUDevice(
@@ -1329,6 +1402,7 @@ static int GetMinGPUMultiprocessorCount(
 
 namespace {
 
+#if GOOGLE_CUDA
 struct CudaVersion {
   // Initialize from version_name in the form of "3.5"
   explicit CudaVersion(const std::string& version_name) {
@@ -1380,6 +1454,15 @@ std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
 #endif
   return cuda_caps;
 }
+#endif  // GOOGLE_CUDA
+
+#if TENSORFLOW_USE_ROCM
+std::vector<int> supported_amdgpu_isa_versions = {803, 900, 906};
+
+std::vector<int> GetSupportedAMDGPUISAVersions() {
+  return supported_amdgpu_isa_versions;
+}
+#endif  // TENSORFLOW_USE_ROCM
 
 Status EnablePeerAccess(se::Platform* platform,
                         const std::vector<PlatformGpuId>& visible_gpu_order) {
@@ -1457,6 +1540,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
       total_bytes = 0;
     }
     const auto& description = stream_exec->GetDeviceDescription();
+#if GOOGLE_CUDA
     int cc_major;
     int cc_minor;
     if (!description.cuda_compute_capability(&cc_major, &cc_minor)) {
@@ -1471,6 +1555,21 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
               << "\npciBusID: " << description.pci_bus_id() << "\ntotalMemory: "
               << strings::HumanReadableNumBytes(total_bytes)
               << " freeMemory: " << strings::HumanReadableNumBytes(free_bytes);
+#elif TENSORFLOW_USE_ROCM
+    int isa_version;
+    if (!description.rocm_amdgpu_isa_version(&isa_version)) {
+      // Logs internally on failure.
+      isa_version = 0;
+    }
+    LOG(INFO) << "Found device " << i << " with properties: "
+              << "\nname: " << description.name() << "\nAMDGPU ISA: gfx"
+              << isa_version << "\nmemoryClockRate (GHz) "
+              << description.clock_rate_ghz() << "\npciBusID "
+              << description.pci_bus_id() << "\nTotal memory: "
+              << strings::HumanReadableNumBytes(total_bytes)
+              << "\nFree memory: "
+              << strings::HumanReadableNumBytes(free_bytes);
+#endif
   }
   // Checking peering and shows matrix if more than one gpu found.
   if (new_gpu_found && visible_gpu_order.size() > 1) {
@@ -1478,6 +1577,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
     TF_RETURN_IF_ERROR(EnablePeerAccess(gpu_manager, visible_gpu_order));
   }
 
+#if GOOGLE_CUDA
   auto cuda_supported_capabilities = GetSupportedCudaComputeCapabilities();
   if (cuda_supported_capabilities.empty()) {
     return errors::FailedPrecondition(
@@ -1485,6 +1585,15 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
   }
   CudaVersion min_supported_capability = *std::min_element(
       cuda_supported_capabilities.begin(), cuda_supported_capabilities.end());
+#elif TENSORFLOW_USE_ROCM
+  auto rocm_supported_isas = GetSupportedAMDGPUISAVersions();
+  if (rocm_supported_isas.empty()) {
+    return errors::FailedPrecondition(
+        "No supported rocm capabilities in binary.");
+  }
+  int min_supported_isa =
+      *std::min_element(rocm_supported_isas.begin(), rocm_supported_isas.end());
+#endif
 
   int min_gpu_core_count =
       GetMinGPUMultiprocessorCount(gpu_manager, visible_gpu_order);
@@ -1502,6 +1611,8 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
     }
     se::StreamExecutor* se = exec_status.ValueOrDie();
     const se::DeviceDescription& desc = se->GetDeviceDescription();
+
+#if GOOGLE_CUDA
     CudaVersion device_capability;
     if (!desc.cuda_compute_capability(&device_capability.major_part,
                                       &device_capability.minor_part)) {
@@ -1522,6 +1633,23 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
                 << min_supported_capability << ".";
       continue;
     }
+#elif TENSORFLOW_USE_ROCM
+    int device_isa;
+    if (!desc.rocm_amdgpu_isa_version(&device_isa)) {
+      continue;
+    }
+    // Only GPUs with no less than the minimum supported compute capability is
+    // accepted.
+    if (device_isa < min_supported_isa) {
+      LOG(INFO) << "Ignoring visible gpu device "
+                << "(" << GetShortDeviceDescription(visible_gpu_id, desc)
+                << ") "
+                << "with AMDGPU ISA gfx" << device_isa
+                << ". The minimum required AMDGPU ISA is gfx"
+                << min_supported_isa << ".";
+      continue;
+    }
+#endif
 
     // Filter out slow GPUs. By default, GPUs with a lower multiprocessor
     // count than the fastest GPU are filtered out, unless they have 8 or more
@@ -1531,7 +1659,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
       LOG(INFO) << "Ignoring visible gpu device "
                 << "(" << GetShortDeviceDescription(visible_gpu_id, desc)
                 << ") "
-                << "with Cuda multiprocessor count: " << desc.core_count()
+                << "with core count: " << desc.core_count()
                 << ". The minimum required count is " << min_gpu_core_count
                 << ". You can adjust this requirement with the env var "
                    "TF_MIN_GPU_MULTIPROCESSOR_COUNT.";
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if !GOOGLE_CUDA
-#error This file must only be included when building with Cuda support
+#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
+#error This file must only be included when building with Cuda or ROCm support
 #endif
 
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
@@ -98,7 +98,7 @@ class BaseGPUDevice : public LocalDevice {
                              Allocator* allocator) override;
 
   // Returns the platform GPU id of this device within the native driver system;
-  // e.g., for CUDA this is the ordinal of the GPU within the system.
+  // e.g., for CUDA and ROCm this is the ordinal of the GPU within the system.
   int gpu_id() const {
     PlatformGpuId platform_gpu_id;
     TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
@@ -311,8 +311,8 @@ class BaseGPUDeviceFactory : public DeviceFactory {
   // Returns into 'ids' the list of valid platform GPU ids, in the order that
   // they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
   // based upon 'visible_gpu_order' which was generated by parsing
-  // GPUOptions::visible_device_list which is a comma-separated list of CUDA GPU
-  // ids.
+  // GPUOptions::visible_device_list which is a comma-separated list of CUDA or
+  // ROCm GPU ids.
   Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
                            std::vector<PlatformGpuId>* ids);
 
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -45,7 +45,7 @@ class GPUDevice : public BaseGPUDevice {
     if (attr.on_host()) {
       if (attr.gpu_compatible() || force_gpu_compatible_) {
         GPUProcessState* ps = GPUProcessState::singleton();
-        return ps->GetCUDAHostAllocator(0);
+        return ps->GetGpuHostAllocator(0);
       } else {
         return cpu_allocator_;
       }
@@ -94,7 +94,7 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
   Allocator* GetAllocator(AllocatorAttributes attr) override {
     GPUProcessState* ps = GPUProcessState::singleton();
     if (attr.gpu_compatible() || force_gpu_compatible_) {
-      return ps->GetCUDAHostAllocator(numa_node_);
+      return ps->GetGpuHostAllocator(numa_node_);
     } else {
       // Call the parent's implementation.
       return ThreadPoolDevice::GetAllocator(attr);
@@ -136,4 +136,4 @@ REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 70);
 
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_CUDA_HOST_ALLOCATOR_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_CUDA_HOST_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_HOST_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_HOST_ALLOCATOR_H_
 
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor.h"
 
 namespace tensorflow {
-// Allocator for pinned CPU RAM that is made known to CUDA for the
+// Allocator for pinned CPU RAM that is made known to GPU for the
 // purpose of efficient DMA with a GPU.
-class CUDAHostAllocator : public SubAllocator {
+class GpuHostAllocator : public SubAllocator {
  public:
   // Note: stream_exec cannot be null.
-  explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node,
+  explicit GpuHostAllocator(se::StreamExecutor* stream_exec, int numa_node,
                              const std::vector<Visitor>& alloc_visitors,
                              const std::vector<Visitor>& free_visitors)
       : SubAllocator(alloc_visitors, free_visitors),
@@ -34,7 +34,7 @@ class CUDAHostAllocator : public SubAllocator {
         numa_node_(numa_node) {
     CHECK(stream_exec_ != nullptr);
   }
-  ~CUDAHostAllocator() override {}
+  ~GpuHostAllocator() override {}
 
   void* Alloc(size_t alignment, size_t num_bytes) override {
     void* ptr = nullptr;
@@ -61,8 +61,8 @@ class CUDAHostAllocator : public SubAllocator {
   se::StreamExecutor* stream_exec_;  // not owned, non-null
   const int numa_node_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
+  TF_DISALLOW_COPY_AND_ASSIGN(GpuHostAllocator);
 };
 
 }  // namespace tensorflow
-#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_CUDA_HOST_ALLOCATOR_H_
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_HOST_ALLOCATOR_H_
@@ -29,17 +29,27 @@ limitations under the License.
 namespace tensorflow {
 
 Status ValidateGPUMachineManager() {
-  return se::MultiPlatformManager::PlatformWithName("CUDA").status();
+  return se::MultiPlatformManager::PlatformWithName(GpuPlatformName()).status();
 }
 
 se::Platform* GPUMachineManager() {
-  auto result = se::MultiPlatformManager::PlatformWithName("CUDA");
+  auto result = se::MultiPlatformManager::PlatformWithName(GpuPlatformName());
   if (!result.ok()) {
-    LOG(FATAL) << "Could not find Platform with name CUDA";
+    LOG(FATAL) << "Could not find Platform with name " << GpuPlatformName();
     return nullptr;
   }
 
   return result.ValueOrDie();
 }
 
+string GpuPlatformName() {
+#if TENSORFLOW_USE_ROCM
+  return "ROCM";
+#else
+  // This function will return "CUDA" even when building TF without GPU support
+  // This is done to preserve existing functionality
+  return "CUDA";
+#endif
+}
+
 }  // namespace tensorflow
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
 
+#include <string>
 #include "tensorflow/core/lib/core/status.h"
 
 namespace stream_executor {
@@ -24,7 +25,7 @@ class Platform;
 
 namespace tensorflow {
 
-// Initializes the CUDA platform and returns OK if the CUDA
+// Initializes the GPU platform and returns OK if the GPU
 // platform could be initialized.
 Status ValidateGPUMachineManager();
 
@@ -34,6 +35,11 @@ Status ValidateGPUMachineManager();
 // in the process (e.g., ValidateGPUMachineManager() returns OK).
 stream_executor::Platform* GPUMachineManager();
 
+// Returns the string describing the name of the GPU platform in use.
+// This value is "CUDA" by default, and
+// "ROCM" when TF is built with `--config==rocm`
+string GpuPlatformName();
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
@@ -18,10 +18,10 @@ limitations under the License.
 #include <cstring>
 #include <vector>
 
-#include "tensorflow/core/common_runtime/gpu/cuda_host_allocator.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
@@ -81,7 +81,7 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
                                             TfGpuId tf_gpu_id,
                                             size_t total_bytes) {
   CHECK(process_state_);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   const string& allocator_type = options.allocator_type();
   mutex_lock lock(mu_);
   GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
@@ -155,14 +155,15 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
     return allocator_parts.allocator.get();
   }
 #else
-  LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
+  LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda or "
+                "--config=rocm.";
   return nullptr;
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 }
 
 SharedCounter* GPUProcessState::GPUAllocatorCounter(TfGpuId tf_gpu_id) {
   DCHECK(process_state_);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
   mutex_lock l(mu_);
   if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
@@ -173,37 +174,37 @@ SharedCounter* GPUProcessState::GPUAllocatorCounter(TfGpuId tf_gpu_id) {
   return allocator_parts.counter.get();
 #else
   return nullptr;
-#endif
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 }
 
-Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
+Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) {
   CHECK(process_state_);
   if (!HasGPUDevice() ||
-      !process_state_->ProcessState::FLAGS_brain_mem_reg_cuda_dma) {
+      !process_state_->ProcessState::FLAGS_brain_mem_reg_gpu_dma) {
     return process_state_->GetCPUAllocator(numa_node);
   }
   if (numa_node == port::kNUMANoAffinity) {
    numa_node = 0;
   }
   {
-    // Here we optimize the most common use case where cuda_host_allocators_
-    // and cuda_al_ have already been populated and since we're only reading
+    // Here we optimize the most common use case where gpu_host_allocators_
+    // have already been populated and since we're only reading
     // these vectors, we can get by with a shared lock. In the slower case,
     // we take a unique lock and populate these vectors.
     tf_shared_lock lock(mu_);
 
     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
-        !cuda_host_allocators_.empty() &&
-        cuda_host_allocators_[0].recording_allocator != nullptr) {
-      return cuda_host_allocators_[0].recording_allocator.get();
+        !gpu_host_allocators_.empty() &&
+        gpu_host_allocators_[0].recording_allocator != nullptr) {
+      return gpu_host_allocators_[0].recording_allocator.get();
     }
-    if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) {
-      return cuda_host_allocators_[0].allocator.get();
+    if (static_cast<int>(gpu_host_allocators_.size()) > numa_node) {
+      return gpu_host_allocators_[0].allocator.get();
     }
   }
 
   mutex_lock lock(mu_);
-  // Find the first valid StreamExecutor to request CUDA host memory
+  // Find the first valid StreamExecutor to request CUDA or ROCm host memory
   // through, since any will work.
   //
   // This search isn't super clean, and it would be nice to use a
@@ -220,39 +221,39 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
 
   CHECK_NE(nullptr, se);
 
-  while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
-    while (cuda_host_alloc_visitors_.size() <= numa_node) {
-      cuda_host_alloc_visitors_.push_back({});
+  while (static_cast<int>(gpu_host_allocators_.size()) <= numa_node) {
+    while (gpu_host_alloc_visitors_.size() <= numa_node) {
+      gpu_host_alloc_visitors_.push_back({});
     }
-    while (cuda_host_free_visitors_.size() <= numa_node) {
-      cuda_host_free_visitors_.push_back({});
+    while (gpu_host_free_visitors_.size() <= numa_node) {
+      gpu_host_free_visitors_.push_back({});
    }
-    SubAllocator* sub_allocator = new CUDAHostAllocator(
-        se, numa_node, cuda_host_alloc_visitors_[numa_node],
-        cuda_host_free_visitors_[numa_node]);
+    SubAllocator* sub_allocator =
+        new GpuHostAllocator(se, numa_node, gpu_host_alloc_visitors_[numa_node],
+                             gpu_host_free_visitors_[numa_node]);
     // TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
-    int64 cuda_host_mem_limit_in_mb = -1;
-    Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB",
+    int64 gpu_host_mem_limit_in_mb = -1;
+    Status status = ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB",
                                         1LL << 16 /*64GB max by default*/,
-                                        &cuda_host_mem_limit_in_mb);
+                                        &gpu_host_mem_limit_in_mb);
     if (!status.ok()) {
-      LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
+      LOG(ERROR) << "GetGpuHostAllocator: " << status.error_message();
    }
-    int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
+    int64 gpu_host_mem_limit = gpu_host_mem_limit_in_mb * (1LL << 20);
     Allocator* allocator =
-        new BFCAllocator(sub_allocator, cuda_host_mem_limit,
-                         true /*allow_growth*/, "cuda_host_bfc" /*name*/);
+        new BFCAllocator(sub_allocator, gpu_host_mem_limit,
+                         true /*allow_growth*/, "gpu_host_bfc" /*name*/);
 
     if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
       // Wrap the allocator to track allocation ids for better logging
       // at the cost of performance.
       allocator = new TrackingAllocator(allocator, true);
    }
-    cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
+    gpu_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
                                      std::unique_ptr<SharedCounter>(nullptr),
                                      sub_allocator,
                                      std::unique_ptr<Allocator>(nullptr)});
-    AllocatorParts& allocator_parts = cuda_host_allocators_.back();
+    AllocatorParts& allocator_parts = gpu_host_allocators_.back();
     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
       ProcessState::MemDesc md;
       md.loc = ProcessState::MemDesc::CPU;
@ -266,15 +267,15 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
|
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
|
||||||
return cuda_host_allocators_[0].recording_allocator.get();
|
return gpu_host_allocators_[0].recording_allocator.get();
|
||||||
} else {
|
} else {
|
||||||
return cuda_host_allocators_[0].allocator.get();
|
return gpu_host_allocators_[0].allocator.get();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUProcessState::AddGPUAllocVisitor(int bus_id,
|
void GPUProcessState::AddGPUAllocVisitor(int bus_id,
|
||||||
const SubAllocator::Visitor& visitor) {
|
const SubAllocator::Visitor& visitor) {
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
mutex_lock lock(mu_);
|
mutex_lock lock(mu_);
|
||||||
CHECK(gpu_allocators_.empty()) // Crash OK
|
CHECK(gpu_allocators_.empty()) // Crash OK
|
||||||
<< "AddGPUAllocVisitor must be called before "
|
<< "AddGPUAllocVisitor must be called before "
|
||||||
@ -284,35 +285,35 @@ void GPUProcessState::AddGPUAllocVisitor(int bus_id,
|
|||||||
gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
|
gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
|
||||||
}
|
}
|
||||||
gpu_visitors_[bus_id].push_back(visitor);
|
gpu_visitors_[bus_id].push_back(visitor);
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUProcessState::AddCUDAHostAllocVisitor(
|
void GPUProcessState::AddGpuHostAllocVisitor(
|
||||||
int numa_node, const SubAllocator::Visitor& visitor) {
|
int numa_node, const SubAllocator::Visitor& visitor) {
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
mutex_lock lock(mu_);
|
mutex_lock lock(mu_);
|
||||||
CHECK(cuda_host_allocators_.empty()) // Crash OK
|
CHECK(gpu_host_allocators_.empty()) // Crash OK
|
||||||
<< "AddCUDAHostAllocVisitor must be called before "
|
<< "AddGpuHostAllocVisitor must be called before "
|
||||||
"first call to GetCUDAHostAllocator.";
|
"first call to GetGpuHostAllocator.";
|
||||||
while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) {
|
while (numa_node >= static_cast<int64>(gpu_host_alloc_visitors_.size())) {
|
||||||
cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
|
gpu_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
|
||||||
}
|
}
|
||||||
cuda_host_alloc_visitors_[numa_node].push_back(visitor);
|
gpu_host_alloc_visitors_[numa_node].push_back(visitor);
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUProcessState::AddCUDAHostFreeVisitor(
|
void GPUProcessState::AddGpuHostFreeVisitor(
|
||||||
int numa_node, const SubAllocator::Visitor& visitor) {
|
int numa_node, const SubAllocator::Visitor& visitor) {
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
mutex_lock lock(mu_);
|
mutex_lock lock(mu_);
|
||||||
CHECK(cuda_host_allocators_.empty()) // Crash OK
|
CHECK(gpu_host_allocators_.empty()) // Crash OK
|
||||||
<< "AddCUDAHostFreeVisitor must be called before "
|
<< "AddGpuHostFreeVisitor must be called before "
|
||||||
"first call to GetCUDAHostAllocator.";
|
"first call to GetGpuHostAllocator.";
|
||||||
while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) {
|
while (numa_node >= static_cast<int64>(gpu_host_free_visitors_.size())) {
|
||||||
cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
|
gpu_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
|
||||||
}
|
}
|
||||||
cuda_host_free_visitors_[numa_node].push_back(visitor);
|
gpu_host_free_visitors_[numa_node].push_back(visitor);
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUProcessState::TestOnlyReset() {
|
void GPUProcessState::TestOnlyReset() {
|
||||||
@ -324,9 +325,9 @@ void GPUProcessState::TestOnlyReset() {
|
|||||||
gpu_device_enabled_ = false;
|
gpu_device_enabled_ = false;
|
||||||
gpu_allocators_.clear();
|
gpu_allocators_.clear();
|
||||||
gpu_visitors_.clear();
|
gpu_visitors_.clear();
|
||||||
cuda_host_allocators_.clear();
|
gpu_host_allocators_.clear();
|
||||||
cuda_host_alloc_visitors_.clear();
|
gpu_host_alloc_visitors_.clear();
|
||||||
cuda_host_free_visitors_.clear();
|
gpu_host_free_visitors_.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
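Note (illustrative, not part of the diff): the renamed visitor-registration API keeps the ordering contract enforced by the CHECKs above — visitors must be registered before the first call to GetGpuHostAllocator. Below is a minimal sketch of how a client might register host alloc/free visitors; the SubAllocator::Visitor parameter list (pointer, index, byte count), the RegisterHostMemoryVisitors helper name, and the include paths are assumptions made for illustration only.

// Sketch only: registering host-memory visitors with the renamed API.
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void RegisterHostMemoryVisitors(int numa_node) {
  // Visitor signature (ptr, index, num_bytes) is assumed here.
  SubAllocator::Visitor alloc_visitor = [](void* ptr, int index,
                                           size_t num_bytes) {
    VLOG(1) << "pinned host chunk allocated: " << num_bytes << " bytes";
  };
  SubAllocator::Visitor free_visitor = [](void* ptr, int index,
                                          size_t num_bytes) {
    VLOG(1) << "pinned host chunk freed: " << num_bytes << " bytes";
  };
  GPUProcessState* ps = GPUProcessState::singleton();
  // Both registrations must precede the first GetGpuHostAllocator call.
  ps->AddGpuHostAllocVisitor(numa_node, alloc_visitor);
  ps->AddGpuHostFreeVisitor(numa_node, free_visitor);
}

}  // namespace tensorflow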
@@ -83,7 +83,7 @@ class GPUProcessState {
   virtual Allocator* GetGPUAllocator(const GPUOptions& options,
                                      TfGpuId tf_gpu_id, size_t total_bytes);
 
-  virtual Allocator* GetCUDAHostAllocator(int numa_node);
+  virtual Allocator* GetGpuHostAllocator(int numa_node);
 
   // Registers a Visitor to be invoked on new chunks of memory allocated by the
   // SubAllocator of every GPU proximate to the specified bus. The AllocVisitor
@@ -98,13 +98,13 @@ class GPUProcessState {
                                 const SubAllocator::Visitor& visitor);
 
   // Registers a Visitor to be invoked on new chunks of memory allocated by
-  // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+  // the SubAllocator of the GpuHostAllocator for the given numa_node.
-  virtual void AddCUDAHostAllocVisitor(int numa_node,
+  virtual void AddGpuHostAllocVisitor(int numa_node,
                                       const SubAllocator::Visitor& visitor);
 
   // Registers a Visitor to be invoked on each chunk handed back for freeing to
-  // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+  // the SubAllocator of the GpuHostAllocator for the given numa_node.
-  virtual void AddCUDAHostFreeVisitor(int numa_node,
+  virtual void AddGpuHostFreeVisitor(int numa_node,
                                      const SubAllocator::Visitor& visitor);
 
   // Returns bus_id for the given GPU id.
@@ -143,10 +143,10 @@ class GPUProcessState {
   std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_);
   std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_);
 
-  std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_);
+  std::vector<AllocatorParts> gpu_host_allocators_ GUARDED_BY(mu_);
-  std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_
+  std::vector<std::vector<SubAllocator::Visitor>> gpu_host_alloc_visitors_
       GUARDED_BY(mu_);
-  std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_
+  std::vector<std::vector<SubAllocator::Visitor>> gpu_host_free_visitors_
       GUARDED_BY(mu_);
 };
 
@@ -150,7 +150,7 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
   const int64 total_bytes = is_dead ? 0 : tensor.TotalBytes();
   if (total_bytes > 0) {
     tracing::ScopedAnnotation annotation("SetProtoFromGPU");
-    alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
+    alloc = GPUProcessState::singleton()->GetGpuHostAllocator(0);
     buf = alloc->Allocate<char>(total_bytes);
     if (LogMemory::IsEnabled()) {
       LogMemory::RecordRawAllocation("SetProtoFromGPU",
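Note (illustrative, not part of the diff): callers such as the SetProtoFromGPU hunk above treat the renamed accessor as an ordinary Allocator, regardless of whether the backend is CUDA or ROCm. A minimal sketch follows, assuming the standard Allocator::Allocate<T>/Deallocate<T> interface; the AllocateStagingBuffer helper name is hypothetical.

// Sketch only: staging a host-side buffer through the renamed accessor.
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/allocator.h"

namespace tensorflow {

char* AllocateStagingBuffer(size_t total_bytes) {
  // Returns the pinned host allocator for NUMA node 0 (CUDA- or ROCm-backed).
  Allocator* alloc = GPUProcessState::singleton()->GetGpuHostAllocator(0);
  char* buf = alloc->Allocate<char>(total_bytes);
  // Caller is expected to release it later with
  // alloc->Deallocate<char>(buf, total_bytes).
  return buf;
}

}  // namespace tensorflow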
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/pool_allocator.h"
 
-#include "tensorflow/core/common_runtime/gpu/cuda_host_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -29,7 +29,7 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) {
       se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
   PoolAllocator pool(
       2 /*pool_size_limit*/, false /*auto_resize*/,
-      new CUDAHostAllocator(
+      new GpuHostAllocator(
          platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
              .ValueOrDie(),
          0 /*numa_node*/, {}, {}),
@@ -48,7 +48,7 @@ TEST(PoolAllocatorTest, ZeroSizePool) {
       se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
   PoolAllocator pool(
       0 /*pool_size_limit*/, false /*auto_resize*/,
-      new CUDAHostAllocator(
+      new GpuHostAllocator(
          platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
              .ValueOrDie(),
          0 /*numa_node*/, {}, {}),
@@ -82,7 +82,7 @@ TEST(PoolAllocatorTest, Alignment) {
       se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
   PoolAllocator pool(
       0 /*pool_size_limit*/, false /*auto_resize*/,
-      new CUDAHostAllocator(
+      new GpuHostAllocator(
          platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
              .ValueOrDie(),
          0 /*numa_node*/, {}, {}),
@@ -142,7 +142,7 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
   };
   se::Platform* platform =
       se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
-  CUDAHostAllocator* sub_allocator = new CUDAHostAllocator(
+  GpuHostAllocator* sub_allocator = new GpuHostAllocator(
      platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
          .ValueOrDie(),
      0 /*numa_node*/, {alloc_visitor}, {free_visitor});
@@ -247,7 +247,7 @@ TEST(PoolAllocatorTest, Name) {
       se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
   PoolAllocator pool(
       2 /*pool_size_limit*/, false /*auto_resize*/,
-      new CUDAHostAllocator(
+      new GpuHostAllocator(
          platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
              .ValueOrDie(),
          0 /*numa_node*/, {}, {}),
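Note (illustrative, not part of the diff): outside of these tests, a PoolAllocator can be layered over the renamed GpuHostAllocator in the same way. The sketch below assumes the remaining PoolAllocator constructor arguments (a NoopRounder and the "pool" name) and the AllocateRaw/DeallocateRaw calls match the existing pool_allocator interfaces; the PooledPinnedAllocationExample function is hypothetical.

// Sketch only: pooling pinned host memory behind a GpuHostAllocator.
#include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

void PooledPinnedAllocationExample() {
  se::Platform* platform =
      se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
  PoolAllocator pool(
      2 /*pool_size_limit*/, false /*auto_resize*/,
      new GpuHostAllocator(
          platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
              .ValueOrDie(),
          0 /*numa_node*/, {}, {}),
      new NoopRounder, "pool");
  // Pinned host memory comes out of the pool like any other Allocator.
  void* p = pool.AllocateRaw(4 /*alignment*/, 64 /*num_bytes*/);
  pool.DeallocateRaw(p);
}

}  // namespace tensorflow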
@@ -83,7 +83,7 @@ class ProcessState : public ProcessStateInterface {
 
   // If these flags need to be runtime configurable consider adding
   // them to ConfigProto.
-  static const bool FLAGS_brain_mem_reg_cuda_dma = true;
+  static const bool FLAGS_brain_mem_reg_gpu_dma = true;
   static const bool FLAGS_brain_gpu_record_mem_types = false;
 
   // Helper method for unit tests to reset the ProcessState singleton by
@@ -34,7 +34,10 @@ cc_library(
 
 tf_cuda_library(
     name = "stream_executor",
-    cuda_deps = ["//tensorflow/stream_executor/cuda:cuda_activation"],
+    cuda_deps = [
+        "//tensorflow/stream_executor/cuda:cuda_activation",
+        "//tensorflow/stream_executor/rocm:rocm_activation",
+    ],
     deps = [
         "//tensorflow/stream_executor",
         "//tensorflow/stream_executor:dnn",
tensorflow/core/platform/rocm.h (new file, 22 lines)
@@ -0,0 +1,22 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_ROCM_H_
+#define TENSORFLOW_CORE_PLATFORM_ROCM_H_
+
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/stream_executor/rocm/rocm_activation.h"
+
+#endif  // TENSORFLOW_CORE_PLATFORM_ROCM_H_
@@ -62,6 +62,20 @@ cc_library(
     ]),
 )
 
+cc_library(
+    name = "rocm_activation",
+    srcs = [],
+    hdrs = if_rocm_is_configured(["rocm_activation.h"]),
+    deps = if_rocm_is_configured([
+        ":rocm_driver",
+        "@local_config_rocm//rocm:rocm_headers",
+        "//tensorflow/stream_executor",
+        "//tensorflow/stream_executor:stream_executor_internal",
+        "//tensorflow/stream_executor/gpu:gpu_activation",
+        "//tensorflow/stream_executor/platform",
+    ]),
+)
+
 cc_library(
     name = "rocm_event",
     srcs = if_rocm_is_configured(["rocm_event.cc"]),
tensorflow/stream_executor/rocm/rocm_activation.h (new file, 39 lines)
@@ -0,0 +1,39 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file contains APIs that assume a StreamExecutor is backed by ROCM.
+// It reaches into the ROCM implementation to activate an underlying ROCM
+// context.
+//
+// Having this file separate from rocm/rocm_gpu_executor.h means that dependent
+// code does not also have to depend on rocm.h.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_ACTIVATION_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_ACTIVATION_H_
+
+#include "tensorflow/stream_executor/gpu/gpu_activation.h"
+
+namespace stream_executor {
+
+class StreamExecutor;
+
+namespace rocm {
+
+using ScopedActivateExecutorContext = gpu::ScopedActivateExecutorContext;
+
+}  // namespace rocm
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_ACTIVATION_H_
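Note (illustrative, not part of the diff): with rocm_activation.h aliasing the shared gpu:: implementation, dependent code can activate a ROCm context with the same RAII pattern used on the CUDA side. A minimal sketch follows, assuming ScopedActivateExecutorContext accepts a StreamExecutor*; the RunWithActivatedRocmContext helper is hypothetical.

// Sketch only: RAII activation of a ROCm-backed executor's context.
#include "tensorflow/core/platform/rocm.h"

namespace stream_executor {

void RunWithActivatedRocmContext(StreamExecutor* executor) {
  // The underlying ROCm context stays active for the lifetime of this scope
  // and is restored when the object is destroyed.
  rocm::ScopedActivateExecutorContext activation(executor);
  // ... issue platform-specific calls that require the active context here ...
}

}  // namespace stream_executor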