Add ROCm support to the Grappler code that performs operator cost calculations for devices
This commit is contained in:
parent
1ef3d8eb92
commit
2f589bfa2a
@ -23,6 +23,10 @@ limitations under the License.
|
||||
#include "cuda/include/cudnn.h"
|
||||
#endif
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hip/hip_runtime.h"
|
||||
#endif
|
||||
|
||||
#ifdef EIGEN_USE_LIBXSMM
|
||||
#include "include/libxsmm.h"
|
||||
#endif
|
||||
@ -109,6 +113,36 @@ DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
|
||||
strings::StrCat(properties.major, ".", properties.minor);
|
||||
(*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
|
||||
(*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
|
||||
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
hipDeviceProp_t properties;
|
||||
hipError_t error =
|
||||
hipGetDeviceProperties(&properties, platform_gpu_id.value());
|
||||
if (error != hipSuccess) {
|
||||
device.set_type("UNKNOWN");
|
||||
LOG(ERROR) << "Failed to get device properties, error code: " << error;
|
||||
return device;
|
||||
}
|
||||
|
||||
// TODO(ROCm): review whether the numbers here are valid.
|
||||
device.set_vendor("Advanced Micro Devices, Inc");
|
||||
device.set_model(properties.name);
|
||||
device.set_frequency(properties.clockRate * 1e-3);
|
||||
device.set_num_cores(properties.multiProcessorCount);
|
||||
device.set_num_registers(properties.regsPerBlock);
|
||||
device.set_l1_cache_size(16 * 1024);
|
||||
device.set_l2_cache_size(properties.l2CacheSize);
|
||||
device.set_l3_cache_size(0);
|
||||
device.set_shared_memory_size_per_multiprocessor(
|
||||
properties.maxSharedMemoryPerMultiProcessor);
|
||||
device.set_memory_size(properties.totalGlobalMem);
|
||||
// 8 is the number of bits per byte. The factor of 2 accounts for
|
||||
// double data rate (DDR).
|
||||
device.set_bandwidth(properties.memoryBusWidth / 8 *
|
||||
properties.memoryClockRate * 2);
|
||||
|
||||
(*device.mutable_environment())["architecture"] =
|
||||
strings::StrCat("gfx", properties.gcnArch);
|
||||
#endif
|
||||
|
||||
return device;
|
||||
|
@ -1209,6 +1209,7 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs)
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
]) + if_rocm_is_configured(cuda_deps + [
|
||||
# rocm_header placeholder
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
]),
|
||||
copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
|
||||
**kwargs
|
||||
|
Loading…
Reference in New Issue
Block a user