add ROCm support for GPU types
This commit is contained in:
parent
143cad4d52
commit
b0b0a5b1be
@ -40,9 +40,9 @@ const char* const DEVICE_GPU = "GPU";
|
|||||||
const char* const DEVICE_SYCL = "SYCL";
|
const char* const DEVICE_SYCL = "SYCL";
|
||||||
|
|
||||||
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
|
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
|
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
|
const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
@ -83,12 +83,12 @@ struct DeviceName<Eigen::ThreadPoolDevice> {
|
|||||||
static const std::string value;
|
static const std::string value;
|
||||||
};
|
};
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
template <>
|
template <>
|
||||||
struct DeviceName<Eigen::GpuDevice> {
|
struct DeviceName<Eigen::GpuDevice> {
|
||||||
static const std::string value;
|
static const std::string value;
|
||||||
};
|
};
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
template <>
|
template <>
|
||||||
|
Loading…
Reference in New Issue
Block a user