[ROCm] Adding complex datatype support for StridedBatchGemm API calls
This commit is contained in:
parent
2730e4b0bc
commit
2a267c30d7
|
@ -1519,7 +1519,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
|
||||
blas_log("DoBlasGemm");
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
||||
"doing rocBLAS SGEMM<half>: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||
"c=%p ldc=%d",
|
||||
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||
|
@ -1565,7 +1565,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
DeviceMemory<float> *c, int ldc) {
|
||||
blas_log("DoBlasGemm");
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
||||
"doing rocBLAS SGEMM<float>: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||
"c=%p ldc=%d",
|
||||
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||
|
@ -2473,7 +2473,12 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
|
||||
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
|
||||
int batch_count) {
|
||||
blas_log("DoBlasGemmStridedBatched");
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||
"c=%p ldc=%d",
|
||||
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||
a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
|
||||
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||
|
@ -2487,7 +2492,12 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
|
||||
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
|
||||
int batch_count) {
|
||||
blas_log("DoBlasGemmStridedBatched");
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM Strided Batched<double>: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||
"c=%p ldc=%d",
|
||||
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||
a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
|
||||
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||
|
@ -2502,10 +2512,13 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
|
||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
||||
"DoBlasGemmStridedBatched operation "
|
||||
<< "for the \"complex<float>\" datatype";
|
||||
return false;
|
||||
return DoBlasInternal(wrap::rocblas_cgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||
n, k, complex_cast(alpha), complex_cast(a), lda,
|
||||
stride_a, complex_cast(b), ldb, stride_b,
|
||||
complex_cast(beta), complex_cast(c), ldc, stride_c,
|
||||
batch_count);
|
||||
}
|
||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
|
@ -2514,10 +2527,13 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
|
||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
||||
"DoBlasGemmStridedBatched operation "
|
||||
<< "for the \"complex<double>\" datatype";
|
||||
return false;
|
||||
return DoBlasInternal(wrap::rocblas_zgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||
n, k, complex_cast(alpha), complex_cast(a), lda,
|
||||
stride_a, complex_cast(b), ldb, stride_b,
|
||||
complex_cast(beta), complex_cast(c), ldc, stride_c,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
port::Status ROCMBlas::GetVersion(string *version) {
|
||||
|
|
Loading…
Reference in New Issue