add ROCm support for GPU types
This commit is contained in:
parent
143cad4d52
commit
b0b0a5b1be
@ -40,9 +40,9 @@ const char* const DEVICE_GPU = "GPU";
|
||||
const char* const DEVICE_SYCL = "SYCL";
|
||||
|
||||
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
@ -83,12 +83,12 @@ struct DeviceName<Eigen::ThreadPoolDevice> {
|
||||
static const std::string value;
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
template <>
|
||||
struct DeviceName<Eigen::GpuDevice> {
|
||||
static const std::string value;
|
||||
};
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <>
|
||||
|
Loading…
Reference in New Issue
Block a user