add ROCm support for GPU types

Wen-Heng (Jack) Chung 2019-03-15 23:52:58 +00:00
parent 143cad4d52
commit b0b0a5b1be
2 changed files with 4 additions and 4 deletions
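The diff below widens each GOOGLE_CUDA guard so that the DeviceName<Eigen::GpuDevice> specialization is also compiled for ROCm builds, which define TENSORFLOW_USE_ROCM rather than GOOGLE_CUDA. As a rough illustration of the effect, here is a minimal, self-contained sketch of the same guard pattern; the stand-in device structs and main() are assumptions made for the example and are not code from the commit.

// Minimal sketch of the guard pattern (not code from this commit): the
// device structs, trait, and main() are stand-ins chosen for illustration;
// only the GOOGLE_CUDA / TENSORFLOW_USE_ROCM macros come from the real build.
#include <iostream>
#include <string>

struct ThreadPoolDevice {};  // stand-in for Eigen::ThreadPoolDevice
struct GpuDevice {};         // stand-in for Eigen::GpuDevice

template <typename Device>
struct DeviceName;

template <>
struct DeviceName<ThreadPoolDevice> {
  static const std::string value;
};
const std::string DeviceName<ThreadPoolDevice>::value = "CPU";

// Before the change this block was guarded by GOOGLE_CUDA alone, so a ROCm
// build (which defines TENSORFLOW_USE_ROCM, not GOOGLE_CUDA) got no GPU
// specialization at all.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <>
struct DeviceName<GpuDevice> {
  static const std::string value;
};
const std::string DeviceName<GpuDevice>::value = "GPU";
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

int main() {
  std::cout << DeviceName<ThreadPoolDevice>::value << "\n";  // prints CPU
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  std::cout << DeviceName<GpuDevice>::value << "\n";         // prints GPU
#endif
  return 0;
}

Compiling the sketch with -DTENSORFLOW_USE_ROCM=1 (or -DGOOGLE_CUDA=1) enables the GPU branch, which is exactly what the two hunks below do for the real trait.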


@@ -40,9 +40,9 @@ const char* const DEVICE_GPU = "GPU";
 const char* const DEVICE_SYCL = "SYCL";
 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
 #endif // TENSORFLOW_USE_SYCL


@@ -83,12 +83,12 @@ struct DeviceName<Eigen::ThreadPoolDevice> {
   static const std::string value;
 };
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template <>
 struct DeviceName<Eigen::GpuDevice> {
   static const std::string value;
 };
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 template <>