[ROCm] Adding complex datatype support for StridedBatchGemm API calls
This commit is contained in:
parent
2730e4b0bc
commit
2a267c30d7
|
@ -1519,7 +1519,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||||
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
|
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
|
||||||
blas_log("DoBlasGemm");
|
blas_log("DoBlasGemm");
|
||||||
VLOG(1) << absl::StreamFormat(
|
VLOG(1) << absl::StreamFormat(
|
||||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
"doing rocBLAS SGEMM<half>: at=%d bt=%d m=%u n=%u "
|
||||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||||
"c=%p ldc=%d",
|
"c=%p ldc=%d",
|
||||||
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||||
|
@ -1565,7 +1565,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||||
DeviceMemory<float> *c, int ldc) {
|
DeviceMemory<float> *c, int ldc) {
|
||||||
blas_log("DoBlasGemm");
|
blas_log("DoBlasGemm");
|
||||||
VLOG(1) << absl::StreamFormat(
|
VLOG(1) << absl::StreamFormat(
|
||||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
"doing rocBLAS SGEMM<float>: at=%d bt=%d m=%u n=%u "
|
||||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||||
"c=%p ldc=%d",
|
"c=%p ldc=%d",
|
||||||
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||||
|
@ -2473,7 +2473,12 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
|
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
|
||||||
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
|
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
|
||||||
int batch_count) {
|
int batch_count) {
|
||||||
blas_log("DoBlasGemmStridedBatched");
|
VLOG(1) << absl::StreamFormat(
|
||||||
|
"doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
|
||||||
|
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||||
|
"c=%p ldc=%d",
|
||||||
|
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||||
|
a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
|
||||||
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
|
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
|
||||||
false, /* pointer_mode_host */
|
false, /* pointer_mode_host */
|
||||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||||
|
@ -2487,7 +2492,12 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
|
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
|
||||||
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
|
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
|
||||||
int batch_count) {
|
int batch_count) {
|
||||||
blas_log("DoBlasGemmStridedBatched");
|
VLOG(1) << absl::StreamFormat(
|
||||||
|
"doing rocBLAS SGEMM Strided Batched<double>: at=%d bt=%d m=%u n=%u "
|
||||||
|
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||||
|
"c=%p ldc=%d",
|
||||||
|
static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
|
||||||
|
a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
|
||||||
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
|
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
|
||||||
false, /* pointer_mode_host */
|
false, /* pointer_mode_host */
|
||||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||||
|
@ -2502,10 +2512,13 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
|
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
||||||
int64 stride_c, int batch_count) {
|
int64 stride_c, int batch_count) {
|
||||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
return DoBlasInternal(wrap::rocblas_cgemm_strided_batched, stream,
|
||||||
"DoBlasGemmStridedBatched operation "
|
false, /* pointer_mode_host */
|
||||||
<< "for the \"complex<float>\" datatype";
|
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||||
return false;
|
n, k, complex_cast(alpha), complex_cast(a), lda,
|
||||||
|
stride_a, complex_cast(b), ldb, stride_b,
|
||||||
|
complex_cast(beta), complex_cast(c), ldc, stride_c,
|
||||||
|
batch_count);
|
||||||
}
|
}
|
||||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
|
@ -2514,10 +2527,13 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
|
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
||||||
int64 stride_c, int batch_count) {
|
int64 stride_c, int batch_count) {
|
||||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
return DoBlasInternal(wrap::rocblas_zgemm_strided_batched, stream,
|
||||||
"DoBlasGemmStridedBatched operation "
|
false, /* pointer_mode_host */
|
||||||
<< "for the \"complex<double>\" datatype";
|
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||||
return false;
|
n, k, complex_cast(alpha), complex_cast(a), lda,
|
||||||
|
stride_a, complex_cast(b), ldb, stride_b,
|
||||||
|
complex_cast(beta), complex_cast(c), ldc, stride_c,
|
||||||
|
batch_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
port::Status ROCMBlas::GetVersion(string *version) {
|
port::Status ROCMBlas::GetVersion(string *version) {
|
||||||
|
|
Loading…
Reference in New Issue