Adding ROCm support to the grappler code that does operator cost calculations for devices

Deven Desai 2019-03-07 18:58:28 +00:00
parent 1ef3d8eb92
commit 2f589bfa2a
2 changed files with 35 additions and 0 deletions

@@ -23,6 +23,10 @@ limitations under the License.
#include "cuda/include/cudnn.h"
#endif
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#endif
#ifdef EIGEN_USE_LIBXSMM
#include "include/libxsmm.h"
#endif
@@ -109,6 +113,36 @@ DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
strings::StrCat(properties.major, ".", properties.minor);
(*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
(*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
#elif TENSORFLOW_USE_ROCM
hipDeviceProp_t properties;
hipError_t error =
hipGetDeviceProperties(&properties, platform_gpu_id.value());
if (error != hipSuccess) {
device.set_type("UNKNOWN");
LOG(ERROR) << "Failed to get device properties, error code: " << error;
return device;
}
// ROCM TODO review if numbers here are valid
device.set_vendor("Advanced Micro Devices, Inc");
device.set_model(properties.name);
device.set_frequency(properties.clockRate * 1e-3);
device.set_num_cores(properties.multiProcessorCount);
device.set_num_registers(properties.regsPerBlock);
device.set_l1_cache_size(16 * 1024);
device.set_l2_cache_size(properties.l2CacheSize);
device.set_l3_cache_size(0);
device.set_shared_memory_size_per_multiprocessor(
properties.maxSharedMemoryPerMultiProcessor);
device.set_memory_size(properties.totalGlobalMem);
// 8 is the number of bits per byte. 2 accounts for
// double data rate (DDR).
device.set_bandwidth(properties.memoryBusWidth / 8 *
properties.memoryClockRate * 2);
(*device.mutable_environment())["architecture"] =
strings::StrCat("gfx", properties.gcnArch);
#endif
return device;
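
For context (not part of the commit): the new ROCm branch reads hipDeviceProp_t fields and derives two numbers, a core frequency in MHz (clockRate is reported in kHz, hence the 1e-3 factor) and a peak memory bandwidth of bus-width-in-bytes * memoryClockRate * 2, where the 2 covers DDR. A minimal standalone HIP sketch of the same query, assuming a single visible device 0, looks like this:

#include <cstdio>
#include <hip/hip_runtime.h>

int main() {
  // Query the same properties the grappler ROCm branch reads.
  hipDeviceProp_t prop;
  hipError_t err = hipGetDeviceProperties(&prop, /*deviceId=*/0);
  if (err != hipSuccess) {
    std::fprintf(stderr, "hipGetDeviceProperties failed: %d\n",
                 static_cast<int>(err));
    return 1;
  }
  // clockRate and memoryClockRate are in kHz; memoryBusWidth is in bits.
  const double frequency_mhz = prop.clockRate * 1e-3;
  const double bandwidth_kbps =
      (prop.memoryBusWidth / 8.0) * prop.memoryClockRate * 2.0;  // x2 for DDR
  std::printf("model: %s (gfx%d)\n", prop.name, prop.gcnArch);
  std::printf("cores: %d  regs/block: %d  L2: %d bytes\n",
              prop.multiProcessorCount, prop.regsPerBlock, prop.l2CacheSize);
  std::printf("frequency: %.0f MHz  memory: %zu bytes  bandwidth: %.0f KB/s\n",
              frequency_mhz, prop.totalGlobalMem, bandwidth_kbps);
  return 0;
}

Worked example of the bandwidth formula with hypothetical numbers: a 2048-bit HBM2 bus at a 945 MHz (945,000 kHz) memory clock gives 2048 / 8 * 945,000 * 2 = 483,840,000 KB/s, roughly 484 GB/s.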

@@ -1209,6 +1209,7 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs)
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured(cuda_deps + [
# rocm_header placeholder
"@local_config_rocm//rocm:rocm_headers",
]),
copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs
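
As a side note on how these copts surface in C++ (a sketch, not taken from this commit): if_cuda adds -DGOOGLE_CUDA=1 and if_rocm adds -DTENSORFLOW_USE_ROCM=1, so a source compiled through tf_cuda_library can select a backend header the same way the grappler file above does:

// Exactly one of these macros is defined per configured GPU build.
#if GOOGLE_CUDA
#include "cuda/include/cudnn.h"            // CUDA build: -DGOOGLE_CUDA=1
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"  // ROCm build: -DTENSORFLOW_USE_ROCM=1
#endif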