From 2f589bfa2a4a917ab89175a070b46e993f19830e Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Thu, 7 Mar 2019 18:58:28 +0000 Subject: [PATCH] adding ROCm support in grappler code that does operator cost calculations for devices --- tensorflow/core/grappler/clusters/utils.cc | 34 ++++++++++++++++++++++ tensorflow/tensorflow.bzl | 1 + 2 files changed, 35 insertions(+) diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index 567e7c075e0..f1d3a77e3f0 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -23,6 +23,10 @@ limitations under the License. #include "cuda/include/cudnn.h" #endif +#if TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#endif + #ifdef EIGEN_USE_LIBXSMM #include "include/libxsmm.h" #endif @@ -109,6 +113,36 @@ DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) { strings::StrCat(properties.major, ".", properties.minor); (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION); + +#elif TENSORFLOW_USE_ROCM + hipDeviceProp_t properties; + hipError_t error = + hipGetDeviceProperties(&properties, platform_gpu_id.value()); + if (error != hipSuccess) { + device.set_type("UNKNOWN"); + LOG(ERROR) << "Failed to get device properties, error code: " << error; + return device; + } + + // ROCM TODO review if numbers here are valid + device.set_vendor("Advanced Micro Devices, Inc"); + device.set_model(properties.name); + device.set_frequency(properties.clockRate * 1e-3); + device.set_num_cores(properties.multiProcessorCount); + device.set_num_registers(properties.regsPerBlock); + device.set_l1_cache_size(16 * 1024); + device.set_l2_cache_size(properties.l2CacheSize); + device.set_l3_cache_size(0); + device.set_shared_memory_size_per_multiprocessor( + properties.maxSharedMemoryPerMultiProcessor); + device.set_memory_size(properties.totalGlobalMem); + // 8 is the number of bits per byte. 2 is accounted for + // double data rate (DDR). + device.set_bandwidth(properties.memoryBusWidth / 8 * + properties.memoryClockRate * 2); + + (*device.mutable_environment())["architecture"] = + strings::StrCat("gfx", properties.gcnArch); #endif return device; diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index a7684a17ae7..f22bbc7bdd0 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1209,6 +1209,7 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs) "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured(cuda_deps + [ # rocm_header placeholder + "@local_config_rocm//rocm:rocm_headers", ]), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), **kwargs