Use float2/double2 to pass complex float/double values to cublas.

We can't just pass pointers to std::complex<> values, because std::complex<float/double> and the float2/double2 types that cublas expects have different alignment requirements.

PiperOrigin-RevId: 306756926
Change-Id: I60aa0b79ab2aa33f50d015e6d8ab6908092c496d
This commit is contained in:
Artem Belevich 2020-04-15 17:56:50 -07:00 committed by TensorFlower Gardener
parent 6b2166de41
commit 1add21a1b4
3 changed files with 177 additions and 91 deletions

View File

@ -44,8 +44,6 @@ limitations under the License.
#define EIGEN_HAS_CUDA_FP16
#endif
#include <assert.h>
#include <complex>
#include "absl/strings/str_cat.h"
@ -490,8 +488,9 @@ bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &x, int incx,
DeviceMemory<std::complex<float>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
elem_count, GpuComplex(&alpha),
elem_count, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -500,8 +499,9 @@ bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &x, int incx,
DeviceMemory<std::complex<double>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
elem_count, GpuComplex(&alpha),
elem_count, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -752,30 +752,32 @@ bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
DeviceMemory<std::complex<float>> *x, int incx) {
return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
elem_count, GpuComplex(&alpha),
GpuComplex(GpuMemoryMutable(x)), incx);
elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
incx);
}
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
DeviceMemory<std::complex<double>> *x, int incx) {
return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
elem_count, GpuComplex(&alpha),
GpuComplex(GpuMemoryMutable(x)), incx);
elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
incx);
}
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
std::complex<float> alpha,
DeviceMemory<std::complex<float>> *x, int incx) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
elem_count, GpuComplex(&alpha),
elem_count, GpuComplex(&cb_alpha),
GpuComplex(GpuMemoryMutable(x)), incx);
}
bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
std::complex<double> alpha,
DeviceMemory<std::complex<double>> *x, int incx) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
elem_count, GpuComplex(&alpha),
elem_count, GpuComplex(&cb_alpha),
GpuComplex(GpuMemoryMutable(x)), incx);
}
@ -904,10 +906,12 @@ bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasCgbmv, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(trans), m, n, kl, ku,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(x)), incx, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -918,10 +922,12 @@ bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZgbmv, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(trans), m, n, kl, ku,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(x)), incx, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -951,10 +957,12 @@ bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(trans), m, n, GpuComplex(&alpha),
CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -964,10 +972,12 @@ bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(trans), m, n, GpuComplex(&alpha),
CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -994,9 +1004,10 @@ bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
const DeviceMemory<std::complex<float>> &y, int incy,
DeviceMemory<std::complex<float>> *a, int lda) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCgerc, stream, true /* = pointer_mode_host */, m,
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
incx, GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(a)), lda);
}
@ -1005,9 +1016,10 @@ bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
const DeviceMemory<std::complex<double>> &y, int incy,
DeviceMemory<std::complex<double>> *a, int lda) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZgerc, stream, true /* = pointer_mode_host */, m,
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
incx, GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(a)), lda);
}
@ -1016,9 +1028,10 @@ bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
const DeviceMemory<std::complex<float>> &y, int incy,
DeviceMemory<std::complex<float>> *a, int lda) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCgeru, stream, true /* = pointer_mode_host */, m,
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
incx, GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(a)), lda);
}
@ -1027,9 +1040,10 @@ bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
const DeviceMemory<std::complex<double>> &y, int incy,
DeviceMemory<std::complex<double>> *a, int lda) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZgeru, stream, true /* = pointer_mode_host */, m,
n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
incx, GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(a)), lda);
}
@ -1039,10 +1053,12 @@ bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasChbmv, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, k, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -1052,10 +1068,12 @@ bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZhbmv, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, k, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -1065,10 +1083,12 @@ bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasChemv, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -1078,10 +1098,12 @@ bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZhemv, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -1110,8 +1132,9 @@ bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
const DeviceMemory<std::complex<float>> &y, int incy,
DeviceMemory<std::complex<float>> *a, int lda) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCher2, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(a)), lda);
@ -1122,8 +1145,9 @@ bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
const DeviceMemory<std::complex<double>> &y, int incy,
DeviceMemory<std::complex<double>> *a, int lda) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZher2, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(a)), lda);
@ -1135,10 +1159,12 @@ bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasChpmv, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -1148,10 +1174,12 @@ bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZhpmv, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
incx, GpuComplex(&beta),
incx, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(y)), incy);
}
@ -1160,7 +1188,7 @@ bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
DeviceMemory<std::complex<float>> *ap) {
return DoBlasInternal(cublasChpr, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, &alpha,
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemoryMutable(ap)));
}
@ -1170,7 +1198,7 @@ bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
DeviceMemory<std::complex<double>> *ap) {
return DoBlasInternal(cublasZhpr, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha),
CUDABlasUpperLower(uplo), n, &alpha,
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemoryMutable(ap)));
}
@ -1180,10 +1208,12 @@ bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<float>> &x, int incx,
const DeviceMemory<std::complex<float>> &y, int incy,
DeviceMemory<std::complex<float>> *ap) {
return DoBlasInternal(
cublasChpr2, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)),
incx, GpuComplex(GpuMemory(y)), incy, GpuComplex(GpuMemoryMutable(ap)));
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasChpr2, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(ap)));
}
bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
@ -1191,10 +1221,12 @@ bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
const DeviceMemory<std::complex<double>> &x, int incx,
const DeviceMemory<std::complex<double>> &y, int incy,
DeviceMemory<std::complex<double>> *ap) {
return DoBlasInternal(
cublasZhpr2, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&alpha), GpuComplex(GpuMemory(x)),
incx, GpuComplex(GpuMemory(y)), incy, GpuComplex(GpuMemoryMutable(ap)));
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZhpr2, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
GpuComplex(GpuMemory(x)), incx,
GpuComplex(GpuMemory(y)), incy,
GpuComplex(GpuMemoryMutable(ap)));
}
bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
@ -1684,11 +1716,14 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasCgemm, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
n, k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
lda, GpuComplex(GpuMemory(b)), ldb,
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
ldc);
}
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
@ -1698,11 +1733,14 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZgemm, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
n, k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
lda, GpuComplex(GpuMemory(b)), ldb,
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
ldc);
}
bool CUDABlas::DoBlasGemvWithProfiling(
@ -2149,6 +2187,15 @@ struct HalfAsFloat<Eigen::half> {
typedef float type;
};
namespace {
// Identity overload of GpuComplexValue. Non-complex scalar types need no
// conversion to a cublas-specific type, so the argument is handed back as-is.
// This lets templated callers invoke GpuComplexValue() uniformly regardless
// of whether the scalar type is complex.
template <class T>
inline T GpuComplexValue(T v) {
  return v;
}
}  // namespace
template <typename T, typename Scalar, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
@ -2250,11 +2297,13 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
#endif
// either CUDA_VERSION < 9.1 or SM < 5.0
if (data_type != CUDA_R_16F) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
bool ok = DoBlasInternal(
cublas_func, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
GpuComplex(&alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&beta),
GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
if (ok) {
return port::Status::OK();
@ -2454,11 +2503,13 @@ bool CUDABlas::DoBlasGemmStridedBatched(
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
int64 stride_c, int batch_count) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(
cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
}
@ -2469,11 +2520,13 @@ bool CUDABlas::DoBlasGemmStridedBatched(
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
int64 stride_c, int batch_count) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(
cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
}
@ -2484,10 +2537,12 @@ bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasChemm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2498,10 +2553,12 @@ bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZhemm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2513,8 +2570,8 @@ bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
int ldc) {
return DoBlasInternal(cublasCherk, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
&beta, GpuComplex(GpuMemoryMutable(c)), ldc);
k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
GpuComplex(GpuMemoryMutable(c)), ldc);
}
bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
@ -2525,8 +2582,8 @@ bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
int ldc) {
return DoBlasInternal(cublasZherk, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
&beta, GpuComplex(GpuMemoryMutable(c)), ldc);
k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
GpuComplex(GpuMemoryMutable(c)), ldc);
}
bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
@ -2536,9 +2593,10 @@ bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
const DeviceMemory<std::complex<float>> &b, int ldb,
float beta, DeviceMemory<std::complex<float>> *c,
int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCher2k, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, &beta,
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2550,9 +2608,10 @@ bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
const DeviceMemory<std::complex<double>> &b, int ldb,
double beta, DeviceMemory<std::complex<double>> *c,
int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZher2k, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, &beta,
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2586,10 +2645,12 @@ bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasCsymm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2600,10 +2661,12 @@ bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZsymm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2633,10 +2696,12 @@ bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
const DeviceMemory<std::complex<float>> &a, int lda,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasCsyrk, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&beta), GpuComplex(GpuMemoryMutable(c)),
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
ldc);
}
@ -2646,10 +2711,12 @@ bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
const DeviceMemory<std::complex<double>> &a, int lda,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZsyrk, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&beta), GpuComplex(GpuMemoryMutable(c)),
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
ldc);
}
@ -2682,10 +2749,12 @@ bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta,
DeviceMemory<std::complex<float>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasCsyr2k, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2696,10 +2765,12 @@ bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) {
auto cb_alpha = GpuComplexValue(alpha);
auto cb_beta = GpuComplexValue(beta);
return DoBlasInternal(cublasZsyr2k, stream, true /* = pointer_mode_host */,
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&beta),
k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
GpuComplex(GpuMemoryMutable(c)), ldc);
}
@ -2733,10 +2804,11 @@ bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
DeviceMemory<std::complex<float>> *b, int ldb) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCtrmm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo),
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemoryMutable(b)), ldb,
GpuComplex(GpuMemoryMutable(b)), ldb);
}
@ -2747,10 +2819,11 @@ bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZtrmm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo),
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemoryMutable(b)), ldb,
GpuComplex(GpuMemoryMutable(b)), ldb);
}
@ -2783,10 +2856,11 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
DeviceMemory<std::complex<float>> *b, int ldb) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo),
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemoryMutable(b)), ldb);
}
@ -2796,10 +2870,11 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb) {
auto cb_alpha = GpuComplexValue(alpha);
return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
CUDABlasSide(side), CUDABlasUpperLower(uplo),
CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
GpuComplex(&alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
GpuComplex(GpuMemoryMutable(b)), ldb);
}

View File

@ -112,7 +112,10 @@ cc_library(
cc_library(
name = "gpu_helpers_header",
hdrs = if_gpu_is_configured(["gpu_helpers.h"]),
deps = [":gpu_types_header"],
deps = [
":gpu_types_header",
"//tensorflow/core/platform:logging",
],
)
cc_library(

View File

@ -22,8 +22,10 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_HELPERS_H_
#include <stddef.h>
#include <complex>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
namespace stream_executor {
@ -83,12 +85,18 @@ struct GpuComplexT<std::complex<double>> {
template <typename T>
inline const typename GpuComplexT<T>::type* GpuComplex(const T* p) {
return reinterpret_cast<const typename GpuComplexT<T>::type*>(p);
auto* result = reinterpret_cast<const typename GpuComplexT<T>::type*>(p);
CHECK_EQ(reinterpret_cast<uintptr_t>(p) % alignof(decltype(*result)), 0)
<< "Source pointer is not aligned by " << alignof(decltype(*result));
return result;
}
template <typename T>
inline typename GpuComplexT<T>::type* GpuComplex(T* p) {
return reinterpret_cast<typename GpuComplexT<T>::type*>(p);
auto* result = reinterpret_cast<typename GpuComplexT<T>::type*>(p);
CHECK_EQ(reinterpret_cast<uintptr_t>(p) % alignof(decltype(*result)), 0)
<< "Source pointer is not aligned by " << alignof(decltype(*result));
return result;
}
// Converts values of std::complex<float/double> to values of