adding ROCm support in grappler code that does operator cost calculations for devices
parent 1ef3d8eb92
commit 2f589bfa2a
@@ -23,6 +23,10 @@ limitations under the License.
 #include "cuda/include/cudnn.h"
 #endif
 
+#if TENSORFLOW_USE_ROCM
+#include "rocm/include/hip/hip_runtime.h"
+#endif
+
 #ifdef EIGEN_USE_LIBXSMM
 #include "include/libxsmm.h"
 #endif
@@ -109,6 +113,36 @@ DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
       strings::StrCat(properties.major, ".", properties.minor);
   (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
   (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
+
+#elif TENSORFLOW_USE_ROCM
+  hipDeviceProp_t properties;
+  hipError_t error =
+      hipGetDeviceProperties(&properties, platform_gpu_id.value());
+  if (error != hipSuccess) {
+    device.set_type("UNKNOWN");
+    LOG(ERROR) << "Failed to get device properties, error code: " << error;
+    return device;
+  }
+
+  // ROCM TODO review if numbers here are valid
+  device.set_vendor("Advanced Micro Devices, Inc");
+  device.set_model(properties.name);
+  device.set_frequency(properties.clockRate * 1e-3);
+  device.set_num_cores(properties.multiProcessorCount);
+  device.set_num_registers(properties.regsPerBlock);
+  device.set_l1_cache_size(16 * 1024);
+  device.set_l2_cache_size(properties.l2CacheSize);
+  device.set_l3_cache_size(0);
+  device.set_shared_memory_size_per_multiprocessor(
+      properties.maxSharedMemoryPerMultiProcessor);
+  device.set_memory_size(properties.totalGlobalMem);
+  // 8 is the number of bits per byte. 2 is accounted for
+  // double data rate (DDR).
+  device.set_bandwidth(properties.memoryBusWidth / 8 *
+                       properties.memoryClockRate * 2);
+
+  (*device.mutable_environment())["architecture"] =
+      strings::StrCat("gfx", properties.gcnArch);
 #endif
 
   return device;
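For context, here is a minimal standalone sketch (not part of the commit; the file name, build line, and output format are illustrative assumptions) that queries the same hipDeviceProp_t fields the new #elif TENSORFLOW_USE_ROCM branch maps into DeviceProperties, including the DDR bandwidth estimate. Assumed build: hipcc hip_device_info.cc -o hip_device_info.

// Sketch only: prints the hipDeviceProp_t fields used by GetLocalGPUInfo() above.
#include <cstdio>
#include <hip/hip_runtime.h>

int main() {
  hipDeviceProp_t props;
  hipError_t err = hipGetDeviceProperties(&props, /*deviceId=*/0);
  if (err != hipSuccess) {
    std::fprintf(stderr, "hipGetDeviceProperties failed: %s\n",
                 hipGetErrorString(err));
    return 1;
  }
  // clockRate and memoryClockRate are reported in kHz; memoryBusWidth in bits.
  // Bandwidth estimate mirrors the diff: bits / 8 (bytes) * clock (kHz) * 2 (DDR).
  const double bandwidth_kb_per_s =
      static_cast<double>(props.memoryBusWidth) / 8 * props.memoryClockRate * 2;
  std::printf("model:      %s (gfx%d)\n", props.name, props.gcnArch);
  std::printf("frequency:  %.0f MHz\n", props.clockRate * 1e-3);
  std::printf("CUs:        %d\n", props.multiProcessorCount);
  std::printf("regs/block: %d\n", props.regsPerBlock);
  std::printf("L2 cache:   %d bytes\n", props.l2CacheSize);
  std::printf("shared/CU:  %zu bytes\n", props.maxSharedMemoryPerMultiProcessor);
  std::printf("memory:     %zu bytes\n", props.totalGlobalMem);
  std::printf("bandwidth:  %.0f KB/s (estimate)\n", bandwidth_kb_per_s);
  return 0;
}

The bandwidth arithmetic works out to bytes times kHz, i.e. roughly KB/s: the bus width in bits is divided by 8 to get bytes per transfer, multiplied by the memory clock in kHz, and doubled for double data rate, matching the comment carried over in the diff.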
@@ -1209,6 +1209,7 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs)
             "@local_config_cuda//cuda:cuda_headers",
         ]) + if_rocm_is_configured(cuda_deps + [
             # rocm_header placeholder
+            "@local_config_rocm//rocm:rocm_headers",
         ]),
         copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
         **kwargs