diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index b073d1ae568..4743fb637e9 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -51,9 +51,11 @@ limitations under the License.
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/include/hip/hip_runtime.h"
 #endif  // GOOGLE_CUDA
 
 namespace tensorflow {
@@ -2089,6 +2091,12 @@ bool IsCUDATensor(const Tensor& t) {
   if (err == cudaErrorInvalidValue) return false;
   CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
   return (attributes.memoryType == cudaMemoryTypeDevice);
+#elif TENSORFLOW_USE_ROCM
+  hipPointerAttribute_t attributes;
+  hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
+  if (err == hipErrorInvalidValue) return false;
+  CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
+  return (attributes.memoryType == hipMemoryTypeDevice);
 #else
   return false;
 #endif
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index f848bdf7471..623cd479364 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -33,9 +33,11 @@ limitations under the License.
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/include/hip/hip_runtime.h"
 #endif  // GOOGLE_CUDA
 
 namespace tensorflow {
@@ -122,7 +124,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
   }
 
   Tensor GPUToCPU(const Tensor& device_tensor) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     CHECK(gpu_device_);
     CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
     DeviceContext* device_context =
@@ -146,7 +148,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
   }
 
   Tensor CPUToGPU(const Tensor& cpu_tensor) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     CHECK(gpu_device_);
     CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
     DeviceContext* device_context =
@@ -461,6 +463,12 @@ bool IsCUDATensor(const Tensor& t) {
   if (err == cudaErrorInvalidValue) return false;
   CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
   return (attributes.memoryType == cudaMemoryTypeDevice);
+#elif TENSORFLOW_USE_ROCM
+  hipPointerAttribute_t attributes;
+  hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
+  if (err == hipErrorInvalidValue) return false;
+  CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
+  return (attributes.memoryType == hipMemoryTypeDevice);
 #else
   CHECK(false)
       << "IsCUDATensor should not be called when CUDA is not available";
diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc
index 3cf72fd8170..6b7013d3038 100644
--- a/tensorflow/core/grappler/clusters/utils_test.cc
+++ b/tensorflow/core/grappler/clusters/utils_test.cc
@@ -40,6 +40,18 @@ TEST(UtilsTest, GetLocalGPUInfo) {
   properties = GetLocalGPUInfo(PlatformGpuId(0));
   EXPECT_EQ("GPU", properties.type());
   EXPECT_EQ("NVIDIA", properties.vendor());
+#elif TENSORFLOW_USE_ROCM
+  LOG(INFO) << "ROCm is enabled.";
+  DeviceProperties properties;
+
+  // Invalid platform GPU ID.
+  properties = GetLocalGPUInfo(PlatformGpuId(100));
+  EXPECT_EQ("UNKNOWN", properties.type());
+
+  // Succeed when a valid platform GPU id was inserted.
+  properties = GetLocalGPUInfo(PlatformGpuId(0));
+  EXPECT_EQ("GPU", properties.type());
+  EXPECT_EQ("Advanced Micro Devices, Inc", properties.vendor());
 #else
   LOG(INFO) << "CUDA is not enabled.";
   DeviceProperties properties;
@@ -73,6 +85,8 @@ TEST(UtilsTest, GetDeviceInfo) {
   EXPECT_EQ("GPU", properties.type());
 #if GOOGLE_CUDA
   EXPECT_EQ("NVIDIA", properties.vendor());
+#elif TENSORFLOW_USE_ROCM
+  EXPECT_EQ("Advanced Micro Devices, Inc", properties.vendor());
 #endif
 
   // TF to platform GPU id mapping entry doesn't exist.
@@ -81,7 +95,7 @@ TEST(UtilsTest, GetDeviceInfo) {
   properties = GetDeviceInfo(device);
   EXPECT_EQ("UNKNOWN", properties.type());
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   // Invalid platform GPU id.
   TF_ASSERT_OK(
       GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100)));
@@ -94,7 +108,11 @@ TEST(UtilsTest, GetDeviceInfo) {
   device.id = 1;
   properties = GetDeviceInfo(device);
   EXPECT_EQ("GPU", properties.type());
+#if GOOGLE_CUDA
   EXPECT_EQ("NVIDIA", properties.vendor());
+#elif TENSORFLOW_USE_ROCM
+  EXPECT_EQ("Advanced Micro Devices, Inc", properties.vendor());
+#endif
 #endif
 }
 
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 7a9110e72ab..a346856745d 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -203,7 +203,7 @@ TEST_F(PinToHostOptimizerTest, Identity) {
     // If CUDA, then there is a GPU kernel registration that is pinned to Host
     // memory. Consequently, `b` will be mapped to Host correct if there is
     // a GPU kernel registered.
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     EXPECT_EQ(node.device(), "/device:CPU:0");
 #else
     EXPECT_TRUE(node.device().empty());
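
Note: both IsCUDATensor branches in the diff rely on the same runtime idiom: ask the driver to classify a pointer, and treat "the runtime has never seen this pointer" (an invalid-value error) as host memory. Below is a minimal standalone sketch of that idiom against the HIP runtime API; the IsDevicePointer helper and the allocation in main are illustrative additions, not part of the patch, and newer ROCm releases rename hipPointerAttribute_t::memoryType to type.

#include <cstdio>

#include <hip/hip_runtime.h>

// Returns true iff `ptr` refers to device-resident memory, mirroring the
// logic of the IsCUDATensor() ROCm branch in the diff above.
bool IsDevicePointer(const void* ptr) {
  hipPointerAttribute_t attributes;
  hipError_t err = hipPointerGetAttributes(&attributes, ptr);
  // Pointers unknown to the HIP runtime (e.g. ordinary host memory) are
  // reported as hipErrorInvalidValue rather than as an attribute record.
  if (err == hipErrorInvalidValue) return false;
  if (err != hipSuccess) {
    std::fprintf(stderr, "hipPointerGetAttributes: %s\n",
                 hipGetErrorString(err));
    return false;
  }
  return attributes.memoryType == hipMemoryTypeDevice;
}

int main() {
  void* device_buffer = nullptr;
  // Fails (and we exit early) on a machine without a usable ROCm device.
  if (hipMalloc(&device_buffer, 64) != hipSuccess) return 1;
  int host_value = 0;
  std::printf("device alloc: %d, host stack: %d\n",
              IsDevicePointer(device_buffer), IsDevicePointer(&host_value));
  hipFree(device_buffer);
  return 0;
}

Compiled with hipcc, this prints "device alloc: 1, host stack: 0"; the equivalent CUDA branch substitutes cudaPointerGetAttributes and cudaMemoryTypeDevice.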