Make cuda_solvers_gpu.cu.cc compile with nvcc8.

PiperOrigin-RevId: 167754383
This commit is contained in:
A. Unique TensorFlower 2017-09-06 12:12:56 -07:00 committed by TensorFlower Gardener
parent d937d8695f
commit 0f6a17c51e

View File

@ -51,55 +51,57 @@ namespace {
// Hacks around missing support for complex arithmetic in nvcc. // Hacks around missing support for complex arithmetic in nvcc.
template <typename Scalar> template <typename Scalar>
__host__ __device__ inline Scalar Multiply(Scalar x, Scalar y) { __device__ inline Scalar Multiply(Scalar x, Scalar y) {
return x * y; return x * y;
} }
template <> template <>
__host__ __device__ inline cuComplex Multiply(cuComplex x, cuComplex y) { __device__ inline cuComplex Multiply(cuComplex x, cuComplex y) {
return cuCmulf(x, y); return cuCmulf(x, y);
} }
template <> template <>
__host__ __device__ inline cuDoubleComplex Multiply(cuDoubleComplex x, __device__ inline cuDoubleComplex Multiply(cuDoubleComplex x,
cuDoubleComplex y) { cuDoubleComplex y) {
return cuCmul(x, y); return cuCmul(x, y);
} }
template <typename Scalar> template <typename Scalar>
__host__ __device__ inline Scalar Negate(Scalar x) { __device__ inline Scalar Negate(Scalar x) {
return -x; return -x;
} }
template <> template <>
__host__ __device__ inline cuComplex Negate(cuComplex x) { __device__ inline cuComplex Negate(cuComplex x) {
return make_cuComplex(-cuCrealf(x), -cuCimagf(x)); return make_cuComplex(-cuCrealf(x), -cuCimagf(x));
} }
template <> template <>
__host__ __device__ inline cuDoubleComplex Negate(cuDoubleComplex x) { __device__ inline cuDoubleComplex Negate(cuDoubleComplex x) {
return make_cuDoubleComplex(-cuCreal(x), -cuCimag(x)); return make_cuDoubleComplex(-cuCreal(x), -cuCimag(x));
} }
template <typename Scalar> template <typename Scalar>
__host__ __device__ inline bool IsFinite(Scalar x) { __device__ inline bool IsFinite(Scalar x) {
return isfinite(x); return Eigen::numext::isfinite(x);
} }
template <> template <>
__host__ __device__ inline bool IsFinite(cuComplex x) { __device__ inline bool IsFinite(cuComplex x) {
return isfinite(cuCrealf(x)) && isfinite(cuCimagf(x)); return Eigen::numext::isfinite(cuCrealf(x)) &&
Eigen::numext::isfinite(cuCimagf(x));
} }
template <> template <>
__host__ __device__ inline bool IsFinite(cuDoubleComplex x) { __device__ inline bool IsFinite(cuDoubleComplex x) {
return isfinite(cuCreal(x)) && isfinite(cuCimag(x)); return Eigen::numext::isfinite(cuCreal(x)) &&
Eigen::numext::isfinite(cuCimag(x));
} }
template <typename Scalar> template <typename Scalar>
struct Const { struct Const {
template <typename RealScalar> template <typename RealScalar>
__host__ __device__ static inline Scalar make_const(const RealScalar x) { __device__ static inline Scalar make_const(const RealScalar x) {
return Scalar(x); return Scalar(x);
} }
}; };
@ -107,7 +109,7 @@ struct Const {
template <> template <>
struct Const<cuComplex> { struct Const<cuComplex> {
template <typename RealScalar> template <typename RealScalar>
__host__ __device__ static inline cuComplex make_const(const RealScalar x) { __device__ static inline cuComplex make_const(const RealScalar x) {
return make_cuComplex(x, 0.0f); return make_cuComplex(x, 0.0f);
} }
}; };
@ -115,8 +117,7 @@ struct Const<cuComplex> {
template <> template <>
struct Const<cuDoubleComplex> { struct Const<cuDoubleComplex> {
template <typename RealScalar> template <typename RealScalar>
__host__ __device__ static inline cuDoubleComplex make_const( __device__ static inline cuDoubleComplex make_const(const RealScalar x) {
const RealScalar x) {
return make_cuDoubleComplex(x, 0.0f); return make_cuDoubleComplex(x, 0.0f);
} }
}; };