Use float2/double2 to pass complex float/double values to cublas.
We can't just pass pointers to std::complex<> as they have different alignment. PiperOrigin-RevId: 306756926 Change-Id: I60aa0b79ab2aa33f50d015e6d8ab6908092c496d
This commit is contained in:
parent
6b2166de41
commit
1add21a1b4
|
@ -44,8 +44,6 @@ limitations under the License.
|
|||
#define EIGEN_HAS_CUDA_FP16
|
||||
#endif
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
@ -490,8 +488,9 @@ bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
|
|||
std::complex<float> alpha,
|
||||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
DeviceMemory<std::complex<float>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
|
||||
elem_count, GpuComplex(&alpha),
|
||||
elem_count, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
@ -500,8 +499,9 @@ bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
|
|||
std::complex<double> alpha,
|
||||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
DeviceMemory<std::complex<double>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
|
||||
elem_count, GpuComplex(&alpha),
|
||||
elem_count, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
@ -752,30 +752,32 @@ bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
|
|||
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
|
||||
DeviceMemory<std::complex<float>> *x, int incx) {
|
||||
return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
|
||||
elem_count, GpuComplex(&alpha),
|
||||
GpuComplex(GpuMemoryMutable(x)), incx);
|
||||
elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
|
||||
incx);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
|
||||
DeviceMemory<std::complex<double>> *x, int incx) {
|
||||
return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
|
||||
elem_count, GpuComplex(&alpha),
|
||||
GpuComplex(GpuMemoryMutable(x)), incx);
|
||||
elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
|
||||
incx);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
|
||||
std::complex<float> alpha,
|
||||
DeviceMemory<std::complex<float>> *x, int incx) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
|
||||
elem_count, GpuComplex(&alpha),
|
||||
elem_count, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemoryMutable(x)), incx);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
|
||||
std::complex<double> alpha,
|
||||
DeviceMemory<std::complex<double>> *x, int incx) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
|
||||
elem_count, GpuComplex(&alpha),
|
||||
elem_count, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemoryMutable(x)), incx);
|
||||
}
|
||||
|
||||
|
@ -904,10 +906,12 @@ bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasCgbmv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(trans), m, n, kl, ku,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(x)), incx, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -918,10 +922,12 @@ bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZgbmv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(trans), m, n, kl, ku,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(x)), incx, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -951,10 +957,12 @@ bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(trans), m, n, GpuComplex(&alpha),
|
||||
CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -964,10 +972,12 @@ bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(trans), m, n, GpuComplex(&alpha),
|
||||
CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -994,9 +1004,10 @@ bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
const DeviceMemory<std::complex<float>> &y, int incy,
|
||||
DeviceMemory<std::complex<float>> *a, int lda) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCgerc, stream, true /* = pointer_mode_host */, m,
|
||||
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(a)), lda);
|
||||
}
|
||||
|
||||
|
@ -1005,9 +1016,10 @@ bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
const DeviceMemory<std::complex<double>> &y, int incy,
|
||||
DeviceMemory<std::complex<double>> *a, int lda) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZgerc, stream, true /* = pointer_mode_host */, m,
|
||||
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(a)), lda);
|
||||
}
|
||||
|
||||
|
@ -1016,9 +1028,10 @@ bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
const DeviceMemory<std::complex<float>> &y, int incy,
|
||||
DeviceMemory<std::complex<float>> *a, int lda) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCgeru, stream, true /* = pointer_mode_host */, m,
|
||||
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(a)), lda);
|
||||
}
|
||||
|
||||
|
@ -1027,9 +1040,10 @@ bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
const DeviceMemory<std::complex<double>> &y, int incy,
|
||||
DeviceMemory<std::complex<double>> *a, int lda) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZgeru, stream, true /* = pointer_mode_host */, m,
|
||||
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(a)), lda);
|
||||
}
|
||||
|
||||
|
@ -1039,10 +1053,12 @@ bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasChbmv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, k, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -1052,10 +1068,12 @@ bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZhbmv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, k, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -1065,10 +1083,12 @@ bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasChemv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -1078,10 +1098,12 @@ bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZhemv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -1110,8 +1132,9 @@ bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
const DeviceMemory<std::complex<float>> &y, int incy,
|
||||
DeviceMemory<std::complex<float>> *a, int lda) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCher2, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(a)), lda);
|
||||
|
@ -1122,8 +1145,9 @@ bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
const DeviceMemory<std::complex<double>> &y, int incy,
|
||||
DeviceMemory<std::complex<double>> *a, int lda) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZher2, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(a)), lda);
|
||||
|
@ -1135,10 +1159,12 @@ bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasChpmv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -1148,10 +1174,12 @@ bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *y, int incy) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZhpmv, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(&beta),
|
||||
incx, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(y)), incy);
|
||||
}
|
||||
|
||||
|
@ -1160,7 +1188,7 @@ bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
DeviceMemory<std::complex<float>> *ap) {
|
||||
return DoBlasInternal(cublasChpr, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, &alpha,
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemoryMutable(ap)));
|
||||
}
|
||||
|
@ -1170,7 +1198,7 @@ bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
DeviceMemory<std::complex<double>> *ap) {
|
||||
return DoBlasInternal(cublasZhpr, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
|
||||
CUDABlasUpperLower(uplo), n, &alpha,
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemoryMutable(ap)));
|
||||
}
|
||||
|
@ -1180,10 +1208,12 @@ bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
const DeviceMemory<std::complex<float>> &y, int incy,
|
||||
DeviceMemory<std::complex<float>> *ap) {
|
||||
return DoBlasInternal(
|
||||
cublasChpr2, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(GpuMemory(y)), incy, GpuComplex(GpuMemoryMutable(ap)));
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasChpr2, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(ap)));
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
|
||||
|
@ -1191,10 +1221,12 @@ bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
const DeviceMemory<std::complex<double>> &y, int incy,
|
||||
DeviceMemory<std::complex<double>> *ap) {
|
||||
return DoBlasInternal(
|
||||
cublasZhpr2, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)),
|
||||
incx, GpuComplex(GpuMemory(y)), incy, GpuComplex(GpuMemoryMutable(ap)));
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZhpr2, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
|
||||
GpuComplex(GpuMemory(x)), incx,
|
||||
GpuComplex(GpuMemory(y)), incy,
|
||||
GpuComplex(GpuMemoryMutable(ap)));
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
|
||||
|
@ -1684,11 +1716,14 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasCgemm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
|
||||
n, k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
|
||||
lda, GpuComplex(GpuMemory(b)), ldb,
|
||||
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
|
||||
ldc);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
|
@ -1698,11 +1733,14 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZgemm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
|
||||
n, k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
|
||||
lda, GpuComplex(GpuMemory(b)), ldb,
|
||||
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
|
||||
ldc);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasGemvWithProfiling(
|
||||
|
@ -2149,6 +2187,15 @@ struct HalfAsFloat<Eigen::half> {
|
|||
typedef float type;
|
||||
};
|
||||
|
||||
namespace {
|
||||
// pass-through for non-complex types that don't need conversion to
|
||||
// cublas-specific type.
|
||||
template <typename T>
|
||||
T inline GpuComplexValue(T v) {
|
||||
return v;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename Scalar, typename FuncT>
|
||||
port::Status CUDABlas::DoBlasGemmBatchedInternal(
|
||||
FuncT cublas_func, Stream *stream, blas::Transpose transa,
|
||||
|
@ -2250,11 +2297,13 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
|
|||
#endif
|
||||
// either CUDA_VERSION < 9.1 or SM < 5.0
|
||||
if (data_type != CUDA_R_16F) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
bool ok = DoBlasInternal(
|
||||
cublas_func, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
|
||||
GpuComplex(&alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
|
||||
const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
|
||||
const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
|
||||
const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
|
||||
if (ok) {
|
||||
return port::Status::OK();
|
||||
|
@ -2454,11 +2503,13 @@ bool CUDABlas::DoBlasGemmStridedBatched(
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
|
||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(
|
||||
cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
|
||||
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
|
||||
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
|
||||
}
|
||||
|
||||
|
@ -2469,11 +2520,13 @@ bool CUDABlas::DoBlasGemmStridedBatched(
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
|
||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(
|
||||
cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
|
||||
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
|
||||
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
|
||||
}
|
||||
|
||||
|
@ -2484,10 +2537,12 @@ bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasChemm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
|
@ -2498,10 +2553,12 @@ bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZhemm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
|
@ -2513,8 +2570,8 @@ bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
|
|||
int ldc) {
|
||||
return DoBlasInternal(cublasCherk, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
&beta, GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
|
||||
|
@ -2525,8 +2582,8 @@ bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
|
|||
int ldc) {
|
||||
return DoBlasInternal(cublasZherk, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
&beta, GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
|
||||
|
@ -2536,9 +2593,10 @@ bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
||||
float beta, DeviceMemory<std::complex<float>> *c,
|
||||
int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCher2k, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, &beta,
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
@ -2550,9 +2608,10 @@ bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
||||
double beta, DeviceMemory<std::complex<double>> *c,
|
||||
int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZher2k, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, &beta,
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
@ -2586,10 +2645,12 @@ bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasCsymm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
|
@ -2600,10 +2661,12 @@ bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZsymm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
|
@ -2633,10 +2696,12 @@ bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
|
|||
const DeviceMemory<std::complex<float>> &a, int lda,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasCsyrk, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&beta), GpuComplex(GpuMemoryMutable(c)),
|
||||
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
|
||||
ldc);
|
||||
}
|
||||
|
||||
|
@ -2646,10 +2711,12 @@ bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
|
|||
const DeviceMemory<std::complex<double>> &a, int lda,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZsyrk, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&beta), GpuComplex(GpuMemoryMutable(c)),
|
||||
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
|
||||
ldc);
|
||||
}
|
||||
|
||||
|
@ -2682,10 +2749,12 @@ bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasCsyr2k, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
|
@ -2696,10 +2765,12 @@ bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(cublasZsyr2k, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
|
||||
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
|
||||
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
|
||||
GpuComplex(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
|
@ -2733,10 +2804,11 @@ bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
|
|||
std::complex<float> alpha,
|
||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
||||
DeviceMemory<std::complex<float>> *b, int ldb) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCtrmm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo),
|
||||
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb);
|
||||
}
|
||||
|
@ -2747,10 +2819,11 @@ bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
|
|||
std::complex<double> alpha,
|
||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
||||
DeviceMemory<std::complex<double>> *b, int ldb) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZtrmm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo),
|
||||
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb);
|
||||
}
|
||||
|
@ -2783,10 +2856,11 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
|||
std::complex<float> alpha,
|
||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
||||
DeviceMemory<std::complex<float>> *b, int ldb) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo),
|
||||
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb);
|
||||
}
|
||||
|
||||
|
@ -2796,10 +2870,11 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
|||
std::complex<double> alpha,
|
||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
||||
DeviceMemory<std::complex<double>> *b, int ldb) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasSide(side), CUDABlasUpperLower(uplo),
|
||||
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
|
||||
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb);
|
||||
}
|
||||
|
||||
|
|
|
@ -112,7 +112,10 @@ cc_library(
|
|||
cc_library(
|
||||
name = "gpu_helpers_header",
|
||||
hdrs = if_gpu_is_configured(["gpu_helpers.h"]),
|
||||
deps = [":gpu_types_header"],
|
||||
deps = [
|
||||
":gpu_types_header",
|
||||
"//tensorflow/core/platform:logging",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
@ -22,8 +22,10 @@ limitations under the License.
|
|||
#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_HELPERS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_types.h"
|
||||
|
||||
namespace stream_executor {
|
||||
|
@ -83,12 +85,18 @@ struct GpuComplexT<std::complex<double>> {
|
|||
|
||||
template <typename T>
|
||||
inline const typename GpuComplexT<T>::type* GpuComplex(const T* p) {
|
||||
return reinterpret_cast<const typename GpuComplexT<T>::type*>(p);
|
||||
auto* result = reinterpret_cast<const typename GpuComplexT<T>::type*>(p);
|
||||
CHECK_EQ(reinterpret_cast<uintptr_t>(p) % alignof(decltype(*result)), 0)
|
||||
<< "Source pointer is not aligned by " << alignof(decltype(*result));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline typename GpuComplexT<T>::type* GpuComplex(T* p) {
|
||||
return reinterpret_cast<typename GpuComplexT<T>::type*>(p);
|
||||
auto* result = reinterpret_cast<typename GpuComplexT<T>::type*>(p);
|
||||
CHECK_EQ(reinterpret_cast<uintptr_t>(p) % alignof(decltype(*result)), 0)
|
||||
<< "Source pointer is not aligned by " << alignof(decltype(*result));
|
||||
return result;
|
||||
}
|
||||
|
||||
// Converts values of std::complex<float/double> to values of
|
||||
|
|
Loading…
Reference in New Issue