Move ScopedCublasMathMode inside DoBlasInternalImpl
This commit is contained in:
parent
9afaf559d9
commit
3f2c98610e
@ -386,9 +386,9 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
|
||||
} // namespace
|
||||
|
||||
template <typename FuncT, typename... Args>
|
||||
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
|
||||
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream* stream,
|
||||
bool pointer_mode_host, bool err_on_failure,
|
||||
Args... args) {
|
||||
cublasMath_t math_type, Args... args) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
CHECK(blas_ != nullptr);
|
||||
@ -396,6 +396,20 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
|
||||
return false;
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 9000
|
||||
ScopedCublasMathMode math_mode{blas_};
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
|
||||
tensorflow::tf32_execution_allowed()) {
|
||||
#else
|
||||
if (math_type == CUBLAS_TENSOR_OP_MATH) {
|
||||
#endif
|
||||
if (!math_mode.Init(math_type)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
ScopedCublasPointerMode pointer_mode{blas_};
|
||||
if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
|
||||
@ -1615,15 +1629,14 @@ bool CUDABlas::DoBlasGemm(
|
||||
}
|
||||
|
||||
#if CUDA_VERSION < 11000
|
||||
ScopedCublasMathMode math_mode{blas_};
|
||||
if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
|
||||
return false;
|
||||
}
|
||||
cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
|
||||
#else
|
||||
cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
|
||||
#endif
|
||||
|
||||
return DoBlasInternalImpl(
|
||||
cublasSgemmEx, stream, true /* = pointer_mode_host */,
|
||||
true /* = err_on_failure= */, CUDABlasTranspose(transa),
|
||||
true /* = err_on_failure= */, math_type, CUDABlasTranspose(transa),
|
||||
CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a),
|
||||
SE_CUDA_DATA_HALF, lda, GpuMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
|
||||
GpuMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
|
||||
@ -1669,19 +1682,17 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
}
|
||||
}
|
||||
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
ScopedCublasMathMode math_mode{blas_};
|
||||
if (tensorflow::tf32_execution_allowed()) {
|
||||
if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#if CUDA_VERSION < 11000
|
||||
cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
|
||||
#else
|
||||
cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
#endif
|
||||
|
||||
return DoBlasInternal(cublasSgemm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
|
||||
n, k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb,
|
||||
&beta, GpuMemoryMutable(c), ldc);
|
||||
return DoBlasInternalImpl(
|
||||
cublasSgemm, stream, true /* = pointer_mode_host */,
|
||||
true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
|
||||
CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda,
|
||||
GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
@ -1704,16 +1715,6 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
DeviceMemory<std::complex<float>> *c, int ldc) {
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
ScopedCublasMathMode math_mode{blas_};
|
||||
if (tensorflow::tf32_execution_allowed()) {
|
||||
if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return DoBlasInternal(cublasCgemm, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
|
||||
n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
|
||||
@ -2286,10 +2287,27 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
|
||||
if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
|
||||
&cc_major, &cc_minor) &&
|
||||
cc_major >= 5) {
|
||||
bool use_tensor_ops =
|
||||
data_type == CUDA_R_16F || tensorflow::tf32_execution_allowed();
|
||||
cublasGemmAlgo_t algo =
|
||||
(use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
|
||||
cublasMath_t math_type;
|
||||
cublasGemmAlgo_t algo;
|
||||
if (data_type == CUDA_R_16F) {
|
||||
#if CUDA_VERSION < 11000
|
||||
math_type = CUBLAS_TENSOR_OP_MATH;
|
||||
#else
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
#endif
|
||||
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
} else if (data_type == CUDA_R_32F) {
|
||||
// DoBlassInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
|
||||
// if TF32 is disabled.
|
||||
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
algo = tensorflow::tf32_execution_allowed() ? CUBLAS_GEMM_DFALT_TENSOR_OP
|
||||
: CUBLAS_GEMM_DFALT;
|
||||
#endif
|
||||
} else {
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
algo = CUBLAS_GEMM_DFALT;
|
||||
}
|
||||
cudaDataType_t compute_type =
|
||||
(data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
|
||||
const void **a_void_ptrs = reinterpret_cast<const void **>(
|
||||
@ -2301,7 +2319,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
|
||||
bool ok;
|
||||
ok = DoBlasInternalImpl(
|
||||
AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
|
||||
true /* = err_on_failure */, CUDABlasTranspose(transa),
|
||||
true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
|
||||
CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
|
||||
b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
|
||||
batch_count, compute_type, algo);
|
||||
@ -2475,9 +2493,14 @@ bool CUDABlas::DoBlasGemmStridedBatched(
|
||||
cc_major >= 5) {
|
||||
cublasGemmAlgo_t algo =
|
||||
(cc_major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
|
||||
#if CUDA_VERSION < 11000
|
||||
cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
|
||||
#else
|
||||
cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
|
||||
#endif
|
||||
bool ok = DoBlasInternalImpl(
|
||||
AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
|
||||
true /* = pointer_mode_host */, true /* = err_on_failure */,
|
||||
true /* = pointer_mode_host */, true /* = err_on_failure */, math_type,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
|
||||
GpuMemory(a), CUDA_R_16F, lda, stride_a, GpuMemory(b), CUDA_R_16F, ldb,
|
||||
stride_b, &beta, GpuMemoryMutable(c), CUDA_R_16F, ldc, stride_c,
|
||||
@ -2499,10 +2522,10 @@ bool CUDABlas::DoBlasGemmStridedBatched(
|
||||
reinterpret_cast<__half *>(GpuMemoryMutable(c) + batch * stride_c);
|
||||
bool ok = DoBlasInternalImpl(
|
||||
cublasSgemmEx, stream, true /* = pointer_mode_host */,
|
||||
true /* = err_on_failure= */, CUDABlasTranspose(transa),
|
||||
CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
|
||||
lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
|
||||
SE_CUDA_DATA_HALF, ldc);
|
||||
true /* = err_on_failure= */, CUBLAS_DEFAULT_MATH,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
|
||||
a_matrix, SE_CUDA_DATA_HALF, lda, b_matrix, SE_CUDA_DATA_HALF, ldb,
|
||||
&beta, c_matrix, SE_CUDA_DATA_HALF, ldc);
|
||||
if (!ok) {
|
||||
LOG(ERROR) << "failed BLAS call, see log for details";
|
||||
return false;
|
||||
@ -2517,19 +2540,17 @@ bool CUDABlas::DoBlasGemmStridedBatched(
|
||||
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
|
||||
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
|
||||
int batch_count) {
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
ScopedCublasMathMode math_mode{blas_};
|
||||
if (tensorflow::tf32_execution_allowed()) {
|
||||
if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#if CUDA_VERSION < 11000
|
||||
cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
|
||||
#else
|
||||
cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
#endif
|
||||
return DoBlasInternal(
|
||||
return DoBlasInternalImpl(
|
||||
cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
|
||||
GpuMemory(a), lda, stride_a, GpuMemory(b), ldb, stride_b, &beta,
|
||||
GpuMemoryMutable(c), ldc, stride_c, batch_count);
|
||||
true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
|
||||
CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda, stride_a,
|
||||
GpuMemory(b), ldb, stride_b, &beta, GpuMemoryMutable(c), ldc, stride_c,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasGemmStridedBatched(
|
||||
@ -2552,14 +2573,6 @@ bool CUDABlas::DoBlasGemmStridedBatched(
|
||||
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
|
||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
ScopedCublasMathMode math_mode{blas_};
|
||||
if (tensorflow::tf32_execution_allowed()) {
|
||||
if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
auto cb_alpha = GpuComplexValue(alpha);
|
||||
auto cb_beta = GpuComplexValue(beta);
|
||||
return DoBlasInternal(
|
||||
|
@ -81,9 +81,9 @@ class CUDABlas : public blas::BlasSupport {
|
||||
// err_on_failure: Whether to print an error if the cublas function fails.
|
||||
// args: Arguments of cuBLAS function.
|
||||
template <typename FuncT, typename... Args>
|
||||
bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
|
||||
bool DoBlasInternalImpl(FuncT cublas_func, Stream* stream,
|
||||
bool pointer_mode_host, bool err_on_failure,
|
||||
Args... args);
|
||||
cublasMath_t math_type, Args... args);
|
||||
|
||||
// Convenience functions that call DoBlasInternalImpl with different values
|
||||
// for err_on_failure.
|
||||
@ -91,7 +91,8 @@ class CUDABlas : public blas::BlasSupport {
|
||||
bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
|
||||
Args... args) {
|
||||
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
|
||||
/*err_on_failure=*/true, args...);
|
||||
/*err_on_failure=*/true, CUBLAS_DEFAULT_MATH,
|
||||
args...);
|
||||
}
|
||||
template <typename FuncT, typename... Args>
|
||||
bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
|
||||
@ -99,7 +100,8 @@ class CUDABlas : public blas::BlasSupport {
|
||||
// Tensor ops are hard-coded off in this path, but can still be enabled with
|
||||
// a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
|
||||
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
|
||||
/*err_on_failure=*/false, args...);
|
||||
/*err_on_failure=*/false, CUBLAS_DEFAULT_MATH,
|
||||
args...);
|
||||
}
|
||||
|
||||
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
||||
|
Loading…
Reference in New Issue
Block a user