[ROCm] Adding complex datatype support for StridedBatchGemm API calls

This commit is contained in:
Deven Desai 2020-04-15 21:32:43 +00:00
parent 2730e4b0bc
commit 2a267c30d7
1 changed files with 28 additions and 12 deletions

View File

@ -1519,7 +1519,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
blas_log("DoBlasGemm");
VLOG(1) << absl::StreamFormat(
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
"doing rocBLAS SGEMM<half>: at=%d bt=%d m=%u n=%u "
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
"c=%p ldc=%d",
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
@ -1565,7 +1565,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
DeviceMemory<float> *c, int ldc) {
blas_log("DoBlasGemm");
VLOG(1) << absl::StreamFormat(
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
"doing rocBLAS SGEMM<float>: at=%d bt=%d m=%u n=%u "
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
"c=%p ldc=%d",
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
@ -2473,7 +2473,12 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
int batch_count) {
blas_log("DoBlasGemmStridedBatched");
VLOG(1) << absl::StreamFormat(
"doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
"c=%p ldc=%d",
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
false, /* pointer_mode_host */
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
@ -2487,7 +2492,12 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
int batch_count) {
blas_log("DoBlasGemmStridedBatched");
VLOG(1) << absl::StreamFormat(
"doing rocBLAS SGEMM Strided Batched<double>: at=%d bt=%d m=%u n=%u "
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
"c=%p ldc=%d",
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
false, /* pointer_mode_host */
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
@ -2502,10 +2512,13 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
int64 stride_c, int batch_count) {
LOG(ERROR) << "rocBLAS does not currently support the "
"DoBlasGemmStridedBatched operation "
<< "for the \"complex<float>\" datatype";
return false;
return DoBlasInternal(wrap::rocblas_cgemm_strided_batched, stream,
false, /* pointer_mode_host */
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
n, k, complex_cast(alpha), complex_cast(a), lda,
stride_a, complex_cast(b), ldb, stride_b,
complex_cast(beta), complex_cast(c), ldc, stride_c,
batch_count);
}
bool ROCMBlas::DoBlasGemmStridedBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
@ -2514,10 +2527,13 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
int64 stride_c, int batch_count) {
LOG(ERROR) << "rocBLAS does not currently support the "
"DoBlasGemmStridedBatched operation "
<< "for the \"complex<double>\" datatype";
return false;
return DoBlasInternal(wrap::rocblas_zgemm_strided_batched, stream,
false, /* pointer_mode_host */
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
n, k, complex_cast(alpha), complex_cast(a), lda,
stride_a, complex_cast(b), ldb, stride_b,
complex_cast(beta), complex_cast(c), ldc, stride_c,
batch_count);
}
port::Status ROCMBlas::GetVersion(string *version) {