Implementing GpuManagedAllocator for ROCm

Enabling several common runtime unit tests for ROCm
Eugene Kuznetsov 2020-01-15 17:12:25 -08:00
parent db8a74a737
commit af54994072
10 changed files with 53 additions and 36 deletions
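For context, GpuManagedAllocator is the small Allocator subclass TensorFlow's GPU tests use to hand out memory that is addressable from both host and device; this commit gives it a ROCm code path and re-enables the common-runtime tests that depend on it. Roughly, the interface being implemented looks like the sketch below, inferred from the AllocateRaw/DeallocateRaw overrides visible in the diff; the exact declaration in gpu_managed_allocator.h may differ slightly.

// Sketch of the allocator interface this commit ports; not copied verbatim
// from gpu_managed_allocator.h.
class GpuManagedAllocator : public Allocator {
 public:
  string Name() override { return "GpuManagedAllocator"; }
  void* AllocateRaw(size_t alignment, size_t num_bytes) override;
  void DeallocateRaw(void* ptr) override;
};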

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/common_runtime/gpu/gpu_device.h"

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
@@ -859,4 +859,4 @@ BENCHMARK(BM_chain_1M_100_true)->Arg(8);
 } // namespace
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -13,30 +13,42 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
 #define EIGEN_USE_GPU
 #endif
+#if TENSORFLOW_USE_ROCM
+#include "rocm/include/hip/hip_runtime.h"
+#define EIGEN_USE_GPU
+#endif
 #include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
 namespace tensorflow {
 void* GpuManagedAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
   void* ptr = nullptr;
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA
   CUdeviceptr result = 0;
   CHECK_EQ(cuMemAllocManaged(&result, num_bytes, CU_MEM_ATTACH_GLOBAL),
            CUDA_SUCCESS);
   ptr = reinterpret_cast<void*>(result);
+#elif TENSORFLOW_USE_ROCM
+  void** result = 0;
+  CHECK_EQ(hipHostMalloc(&result, num_bytes, 0), 0);
+  ptr = reinterpret_cast<void*>(result);
 #endif
   CHECK(!(reinterpret_cast<uintptr_t>(ptr) & (alignment - 1)));
   return ptr;
 }
 void GpuManagedAllocator::DeallocateRaw(void* ptr) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA
   CHECK_EQ(cudaFree(ptr), cudaSuccess);
+#elif TENSORFLOW_USE_ROCM
+  CHECK_EQ(hipFree(ptr), hipSuccess);
 #endif
 }
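The two backends take different routes here: the CUDA branch asks the driver API for true managed (unified) memory with cuMemAllocManaged and later releases it with cudaFree, while the new ROCm branch allocates pinned, device-accessible host memory with hipHostMalloc and releases it with hipFree. Below is a standalone sketch, not part of the commit, of that HIP allocation pattern; the buffer size is hypothetical, and it uses hipHostFree, the documented counterpart of hipHostMalloc, purely for illustration.

// Standalone sketch: pinned, device-accessible host memory via the HIP
// runtime API. Compile with hipcc on a ROCm machine.
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
  void* ptr = nullptr;
  // flags = 0 requests the default pinned host allocation, which GPU kernels
  // can access directly.
  if (hipHostMalloc(&ptr, 1 << 20, /*flags=*/0) != hipSuccess) {
    std::fprintf(stderr, "hipHostMalloc failed\n");
    return 1;
  }
  // ... use ptr from host code or pass it to a kernel ...
  if (hipHostFree(ptr) != hipSuccess) {
    std::fprintf(stderr, "hipHostFree failed\n");
    return 1;
  }
  return 0;
}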

View File

@@ -13,20 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/common_runtime/pool_allocator.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/test.h"
+#include "gpu_init.h"
 namespace tensorflow {
 namespace {
 TEST(PoolAllocatorTest, ZeroSizeBuffers) {
   se::Platform* platform =
-      se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+      se::MultiPlatformManager::PlatformWithName(GpuPlatformName())
+          .ValueOrDie();
   PoolAllocator pool(
       2 /*pool_size_limit*/, false /*auto_resize*/,
       new GpuHostAllocator(
@@ -45,7 +46,8 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) {
 TEST(PoolAllocatorTest, ZeroSizePool) {
   se::Platform* platform =
-      se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+      se::MultiPlatformManager::PlatformWithName(GpuPlatformName())
+          .ValueOrDie();
   PoolAllocator pool(
       0 /*pool_size_limit*/, false /*auto_resize*/,
       new GpuHostAllocator(
@@ -79,7 +81,8 @@ TEST(PoolAllocatorTest, ZeroSizePool) {
 TEST(PoolAllocatorTest, Alignment) {
   se::Platform* platform =
-      se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+      se::MultiPlatformManager::PlatformWithName(GpuPlatformName())
+          .ValueOrDie();
   PoolAllocator pool(
       0 /*pool_size_limit*/, false /*auto_resize*/,
       new GpuHostAllocator(
@@ -141,7 +144,8 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
     free_size += size;
   };
   se::Platform* platform =
-      se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+      se::MultiPlatformManager::PlatformWithName(GpuPlatformName())
+          .ValueOrDie();
   GpuHostAllocator* sub_allocator = new GpuHostAllocator(
       platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
           .ValueOrDie(),
@@ -244,7 +248,8 @@ TEST(PoolAllocatorTest, Pow2Rounder) {
 TEST(PoolAllocatorTest, Name) {
   se::Platform* platform =
-      se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+      se::MultiPlatformManager::PlatformWithName(GpuPlatformName())
+          .ValueOrDie();
   PoolAllocator pool(
       2 /*pool_size_limit*/, false /*auto_resize*/,
       new GpuHostAllocator(
@@ -258,4 +263,4 @@ TEST(PoolAllocatorTest, Name) {
 } // namespace
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
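The pool_allocator tests above stop hard-coding the "cuda" platform string: GpuPlatformName(), declared in gpu_init.h, returns the name of the GPU StreamExecutor platform for whichever backend the build targets (CUDA or ROCm), so the same test source resolves the right platform on both. The shared lookup pattern, sketched from the lines in this diff with the surrounding test scaffolding omitted:

// Resolve the GPU StreamExecutor platform without naming a backend.
se::Platform* platform =
    se::MultiPlatformManager::PlatformWithName(GpuPlatformName()).ValueOrDie();
se::StreamExecutor* executor =
    platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)).ValueOrDie();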

View File

@@ -201,7 +201,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
     if (col_exec_) col_exec_->Unref();
   }
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   void InitGPUDevices() {
     auto device_factory = DeviceFactory::GetFactory("GPU");
     CHECK(device_factory);
@@ -214,7 +214,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
   void Init(int num_workers, int num_devices_per_worker, DataType dtype,
             const DeviceType& device_type, int fail_after) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     InitGPUDevices();
 #endif
     VLOG(2) << "num_workers=" << num_workers
@@ -871,7 +871,7 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
   } \
   }
-#ifndef GOOGLE_CUDA
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 // B T W D L A F
 DEF_TEST(FLOAT, CPU, 1, 2, 1, 0, false)
 DEF_TEST(FLOAT, CPU, 1, 2, 1001, 0, true)
@@ -889,7 +889,7 @@ DEF_TEST(FLOAT, CPU, 2, 4, 128, 1, true)
 DEF_TEST(FLOAT, CPU, 2, 4, 128, 5, false)
 #endif
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Can only set W=1 for GPU tests.
 // B T W D L A F
 DEF_TEST(FLOAT, GPU, 1, 2, 1, 0, true)

View File

@@ -30,10 +30,10 @@ TEST(MemoryTypeChecker, Int32OK) {
   auto in1 = test::graph::Constant(g, v);
   test::graph::Add(g, in0, in1);
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_CPU, g));
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   // There is a kernel for adding two int32s on host memory.
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
 #endif // TENSORFLOW_USE_SYCL
@@ -47,7 +47,7 @@ TEST(MemoryTypeChecker, Int32NotOk) {
   auto x = test::graph::Constant(g, v);
   test::graph::Cast(g, x, DT_FLOAT);
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_CPU, g));
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   // There is no kernel for casting int32/host memory to float/device
   // memory.
   EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_GPU, g)));
@@ -55,7 +55,7 @@ TEST(MemoryTypeChecker, Int32NotOk) {
   // But we can insert _HostSend/_HostRecv to ensure the invariant.
   TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/device:GPU:0", g));
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
   // There is no kernel for casting int32/host memory to float/device
   // memory.
@@ -80,12 +80,12 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) {
   TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_CPU, g, sf, 0, &memory_type));
   // float Switch's output on CPU doesn't have HOST_MEMORY constraint.
   EXPECT_EQ(memory_type, DEVICE_MEMORY);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   auto si = test::graph::Switch(g, test::graph::Constant(g, vi), pred);
   TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_GPU, g, si, 0, &memory_type));
   // int Switch's output on GPU has HOST_MEMORY constraint.
   EXPECT_EQ(memory_type, HOST_MEMORY);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
   auto si = test::graph::Switch(g, test::graph::Constant(g, vi), pred);
   TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type));

View File

@@ -1724,7 +1724,7 @@ TEST_F(PlacerTest, TestNonExistentDevice) {
   EXPECT_TRUE(absl::StrContains(s.error_message(), "but available devices"));
 }
-#if !GOOGLE_CUDA
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 // Test that we inform the user if they appear to be explicitly placing nodes
 // on a GPU when CUDA is not available
 TEST_F(PlacerTest, TestUseGpuWithNoCuda) {

View File

@@ -153,7 +153,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
     return cpu_tensor;
 #else
     CHECK(false);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   }
   Tensor CPUToGPU(const Tensor& cpu_tensor) {
@@ -178,7 +178,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
     return device_tensor;
 #else
     CHECK(false);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   }
   Status RunWithRuntime(
@@ -479,7 +479,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
 }
 bool IsCUDATensor(const Tensor& t) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA
   cudaPointerAttributes attributes;
   cudaError_t err =
       cudaPointerGetAttributes(&attributes, t.tensor_data().data());
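The last hunk above only loosens the guard around the CUDA-specific IsCUDATensor helper; it does not show a ROCm branch. Purely as a hypothetical sketch (not taken from this commit), an equivalent check on ROCm could query the pointer with hipPointerGetAttributes, for example:

// Hypothetical ROCm analogue of IsCUDATensor (not in this diff): report
// whether a pointer refers to device memory. Assumes the ROCm-3.x-era field
// name `memoryType` on hipPointerAttribute_t; newer HIP releases rename it.
#include <hip/hip_runtime.h>

bool IsDevicePointerROCm(const void* ptr) {
  hipPointerAttribute_t attributes;
  hipError_t err = hipPointerGetAttributes(&attributes, ptr);
  if (err == hipErrorInvalidValue) return false;  // ordinary host pointer
  if (err != hipSuccess) return false;            // treat other errors as "not device"
  return attributes.memoryType == hipMemoryTypeDevice;
}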

View File

@@ -116,7 +116,7 @@ class RingGathererTest : public ::testing::Test {
  protected:
   RingGathererTest() : device_type_(DEVICE_CPU) {}
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   void InitGPUDevices() {
     auto device_factory = DeviceFactory::GetFactory("GPU");
     CHECK(device_factory);
@@ -135,7 +135,7 @@ class RingGathererTest : public ::testing::Test {
   void Init(int num_workers, int num_devices, DataType dtype,
             const DeviceType& device_type, int num_subdivs, int fail_after) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     InitGPUDevices();
 #endif
     device_type_ = device_type;
@@ -603,7 +603,7 @@ TEST_F(RingGathererTest, InitializeParams) {
   } \
   }
-#ifndef GOOGLE_CUDA
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 // Success tests
 DEF_TEST(FLOAT, CPU, 1, 2, 1, 1, 0)
 DEF_TEST(FLOAT, CPU, 1, 2, 1, 2, 0)
@@ -628,7 +628,7 @@ DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
 DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 11)
 #endif
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // GPU tests. So long as the device names are all in a single tasks we
 // bypass inter-worker routing code and can fake multiple GPUs with a single
 // GPU, from the perspective of the RingGatherer logic. So these tests

View File

@@ -138,7 +138,7 @@ class RingReducerTest : public ::testing::Test {
  protected:
   RingReducerTest() : device_type_(DEVICE_CPU) {}
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   void InitGPUDevices() {
     auto device_factory = DeviceFactory::GetFactory("GPU");
     CHECK(device_factory);
@@ -157,7 +157,7 @@ class RingReducerTest : public ::testing::Test {
   void Init(int num_workers, int num_devices, DataType dtype,
             const DeviceType& device_type, int num_subdivs, int fail_after) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     InitGPUDevices();
 #endif
     device_type_ = device_type;
@@ -683,7 +683,7 @@ TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
   } \
   }
-#ifndef GOOGLE_CUDA
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 // Success tests
 DEF_TEST(FLOAT, CPU, 1, 2, 1, 1, 0)
 DEF_TEST(FLOAT, CPU, 1, 2, 1, 2, 0)
@@ -710,7 +710,7 @@ DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
 DEF_TEST(FLOAT, CPU, 2, 8, 2, 9408, 11)
 #endif
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // GPU tests. So long as the device names are all in a single tasks we
 // bypass inter-worker routing code and can fake multiple GPUs with a single
 // GPU, from the perspective of the RingReducer logic. So these tests