Move ScopedCublasMathMode inside DoBlasInternalImpl

Nathan Luehr 2020-06-22 12:29:33 -05:00
parent 9afaf559d9
commit 3f2c98610e
2 changed files with 76 additions and 61 deletions

tensorflow/stream_executor/cuda/cuda_blas.cc

@@ -386,9 +386,9 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
 }  // namespace
 
 template <typename FuncT, typename... Args>
-bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
+bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream* stream,
                                   bool pointer_mode_host, bool err_on_failure,
-                                  Args... args) {
+                                  cublasMath_t math_type, Args... args) {
   absl::MutexLock lock(&mu_);
   CHECK(blas_ != nullptr);
@@ -396,6 +396,20 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
     return false;
   }
 
+#if CUDA_VERSION >= 9000
+  ScopedCublasMathMode math_mode{blas_};
+#if CUBLAS_VER_MAJOR >= 11
+  if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
+      tensorflow::tf32_execution_allowed()) {
+#else
+  if (math_type == CUBLAS_TENSOR_OP_MATH) {
+#endif
+    if (!math_mode.Init(math_type)) {
+      return false;
+    }
+  }
+#endif
+
   gpu::ScopedActivateExecutorContext sac{parent_};
   ScopedCublasPointerMode pointer_mode{blas_};
   if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
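
For readers unfamiliar with the pattern, ScopedCublasMathMode is an RAII guard: Init() records the handle's current math mode and installs the requested one, and the destructor restores the saved mode when DoBlasInternalImpl returns, so one call's tensor-op setting cannot leak into later calls on the same handle. A minimal standalone sketch of such a guard (hypothetical class, not the actual StreamExecutor implementation):

#include <cublas_v2.h>

// Hypothetical sketch of an RAII math-mode guard in the spirit of
// ScopedCublasMathMode; the real class holds the handle via blas_.
class MathModeGuard {
 public:
  explicit MathModeGuard(cublasHandle_t handle) : handle_(handle) {}

  // Save the current mode, then install `new_mode`. Returns false (and arms
  // no restore) if either cuBLAS call fails.
  bool Init(cublasMath_t new_mode) {
    if (cublasGetMathMode(handle_, &old_mode_) != CUBLAS_STATUS_SUCCESS) {
      return false;
    }
    if (cublasSetMathMode(handle_, new_mode) != CUBLAS_STATUS_SUCCESS) {
      return false;
    }
    needs_restore_ = true;
    return true;
  }

  // Restore the saved mode on scope exit.
  ~MathModeGuard() {
    if (needs_restore_) cublasSetMathMode(handle_, old_mode_);
  }

 private:
  cublasHandle_t handle_;
  cublasMath_t old_mode_ = CUBLAS_DEFAULT_MATH;
  bool needs_restore_ = false;
};
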
@@ -1615,15 +1629,14 @@ bool CUDABlas::DoBlasGemm(
   }
 
 #if CUDA_VERSION < 11000
-  ScopedCublasMathMode math_mode{blas_};
-  if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
-    return false;
-  }
+  cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
+#else
+  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
 #endif
 
   return DoBlasInternalImpl(
       cublasSgemmEx, stream, true /* = pointer_mode_host */,
-      true /* = err_on_failure= */, CUDABlasTranspose(transa),
+      true /* = err_on_failure= */, math_type, CUDABlasTranspose(transa),
       CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a),
       SE_CUDA_DATA_HALF, lda, GpuMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
       GpuMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
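
The version split above tracks a cuBLAS behavior change: before CUDA 11, fp16 GEMMs use tensor cores only when CUBLAS_TENSOR_OP_MATH is set explicitly, while cuBLAS 11 enables fp16 tensor cores under CUBLAS_DEFAULT_MATH and deprecates CUBLAS_TENSOR_OP_MATH. The same choice restated as a standalone helper (hypothetical name):

#include <cuda.h>
#include <cublas_v2.h>

// Hypothetical helper restating the #if above: the math mode to request for
// an fp16 GEMM so tensor cores are used on both old and new toolkits.
inline cublasMath_t HalfGemmMathType() {
#if CUDA_VERSION < 11000
  return CUBLAS_TENSOR_OP_MATH;  // must opt in explicitly before CUDA 11
#else
  return CUBLAS_DEFAULT_MATH;  // fp16 tensor cores are on by default
#endif
}
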
@@ -1669,19 +1682,17 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
     }
   }
 
-#if CUBLAS_VER_MAJOR >= 11
-  ScopedCublasMathMode math_mode{blas_};
-  if (tensorflow::tf32_execution_allowed()) {
-    if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
-      return false;
-    }
-  }
+#if CUDA_VERSION < 11000
+  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
+#else
+  cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
 #endif
-  return DoBlasInternal(cublasSgemm, stream, true /* = pointer_mode_host */,
-                        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
-                        n, k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb,
-                        &beta, GpuMemoryMutable(c), ldc);
+
+  return DoBlasInternalImpl(
+      cublasSgemm, stream, true /* = pointer_mode_host */,
+      true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
+      CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda,
+      GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
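
Note the division of labor for fp32: on cuBLAS 11+ the caller now always requests CUBLAS_TF32_TENSOR_OP_MATH, and DoBlasInternalImpl (first hunk above) installs it only when tensorflow::tf32_execution_allowed() returns true; otherwise the handle stays in CUBLAS_DEFAULT_MATH and the GEMM runs in full fp32. The gating condition restated as a standalone predicate (hypothetical names):

#include <cublas_v2.h>

// Stand-in for tensorflow::tf32_execution_allowed(), the process-wide switch.
bool Tf32ExecutionAllowed();

// Hypothetical restatement of the check inside DoBlasInternalImpl: decide
// whether the requested math mode should actually be set on the handle.
inline bool ShouldSetMathMode(cublasMath_t requested) {
#if CUBLAS_VER_MAJOR >= 11
  // cuBLAS 11+: only a TF32 request matters, and only if globally allowed.
  return requested == CUBLAS_TF32_TENSOR_OP_MATH && Tf32ExecutionAllowed();
#else
  // Older cuBLAS: tensor ops must be requested explicitly.
  return requested == CUBLAS_TENSOR_OP_MATH;
#endif
}
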
@@ -1704,16 +1715,6 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                           DeviceMemory<std::complex<float>> *c, int ldc) {
   auto cb_alpha = GpuComplexValue(alpha);
   auto cb_beta = GpuComplexValue(beta);
-#if CUBLAS_VER_MAJOR >= 11
-  ScopedCublasMathMode math_mode{blas_};
-  if (tensorflow::tf32_execution_allowed()) {
-    if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
-      return false;
-    }
-  }
-#endif
   return DoBlasInternal(cublasCgemm, stream, true /* = pointer_mode_host */,
                         CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
                         n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
@@ -2286,10 +2287,27 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
   if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
           &cc_major, &cc_minor) &&
       cc_major >= 5) {
-    bool use_tensor_ops =
-        data_type == CUDA_R_16F || tensorflow::tf32_execution_allowed();
-    cublasGemmAlgo_t algo =
-        (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+    cublasMath_t math_type;
+    cublasGemmAlgo_t algo;
+    if (data_type == CUDA_R_16F) {
+#if CUDA_VERSION < 11000
+      math_type = CUBLAS_TENSOR_OP_MATH;
+#else
+      math_type = CUBLAS_DEFAULT_MATH;
+#endif
+      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+#if CUBLAS_VER_MAJOR >= 11
+    } else if (data_type == CUDA_R_32F) {
+      // DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
+      // if TF32 is disabled.
+      math_type = CUBLAS_TF32_TENSOR_OP_MATH;
+      algo = tensorflow::tf32_execution_allowed() ? CUBLAS_GEMM_DFALT_TENSOR_OP
+                                                  : CUBLAS_GEMM_DFALT;
+#endif
+    } else {
+      math_type = CUBLAS_DEFAULT_MATH;
+      algo = CUBLAS_GEMM_DFALT;
+    }
     cudaDataType_t compute_type =
         (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
     const void **a_void_ptrs = reinterpret_cast<const void **>(
@@ -2301,7 +2319,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
     bool ok;
     ok = DoBlasInternalImpl(
         AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
-        true /* = err_on_failure */, CUDABlasTranspose(transa),
+        true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
        CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
         b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
         batch_count, compute_type, algo);
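
The branch above picks a (math_type, algo) pair per data type; the same decision table as a standalone helper (hypothetical, and omitting the compute-capability checks that surround it):

#include <cuda.h>
#include <cublas_v2.h>

// Hypothetical helper mirroring the batched-GEMM branch above.
struct GemmConfig {
  cublasMath_t math_type;
  cublasGemmAlgo_t algo;
};

inline GemmConfig PickGemmConfig(cudaDataType_t data_type, bool tf32_allowed) {
  if (data_type == CUDA_R_16F) {
#if CUDA_VERSION < 11000
    return {CUBLAS_TENSOR_OP_MATH, CUBLAS_GEMM_DFALT_TENSOR_OP};
#else
    return {CUBLAS_DEFAULT_MATH, CUBLAS_GEMM_DFALT_TENSOR_OP};
#endif
  }
#if CUBLAS_VER_MAJOR >= 11
  if (data_type == CUDA_R_32F) {
    // The TF32 request is still gated inside DoBlasInternalImpl; the algo
    // choice only selects tensor-op kernels when TF32 is allowed.
    return {CUBLAS_TF32_TENSOR_OP_MATH,
            tf32_allowed ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT};
  }
#endif
  return {CUBLAS_DEFAULT_MATH, CUBLAS_GEMM_DFALT};
}
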
@@ -2475,9 +2493,14 @@ bool CUDABlas::DoBlasGemmStridedBatched(
       cc_major >= 5) {
     cublasGemmAlgo_t algo =
         (cc_major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+#if CUDA_VERSION < 11000
+    cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
+#else
+    cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
+#endif
     bool ok = DoBlasInternalImpl(
         AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
-        true /* = pointer_mode_host */, true /* = err_on_failure */,
+        true /* = pointer_mode_host */, true /* = err_on_failure */, math_type,
         CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
         GpuMemory(a), CUDA_R_16F, lda, stride_a, GpuMemory(b), CUDA_R_16F, ldb,
         stride_b, &beta, GpuMemoryMutable(c), CUDA_R_16F, ldc, stride_c,
@@ -2499,10 +2522,10 @@ bool CUDABlas::DoBlasGemmStridedBatched(
           reinterpret_cast<__half *>(GpuMemoryMutable(c) + batch * stride_c);
       bool ok = DoBlasInternalImpl(
           cublasSgemmEx, stream, true /* = pointer_mode_host */,
-          true /* = err_on_failure= */, CUDABlasTranspose(transa),
-          CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
-          lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
-          SE_CUDA_DATA_HALF, ldc);
+          true /* = err_on_failure= */, CUBLAS_DEFAULT_MATH,
+          CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+          a_matrix, SE_CUDA_DATA_HALF, lda, b_matrix, SE_CUDA_DATA_HALF, ldb,
+          &beta, c_matrix, SE_CUDA_DATA_HALF, ldc);
       if (!ok) {
         LOG(ERROR) << "failed BLAS call, see log for details";
         return false;
@@ -2517,19 +2540,17 @@ bool CUDABlas::DoBlasGemmStridedBatched(
     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
     int batch_count) {
-#if CUBLAS_VER_MAJOR >= 11
-  ScopedCublasMathMode math_mode{blas_};
-  if (tensorflow::tf32_execution_allowed()) {
-    if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
-      return false;
-    }
-  }
+#if CUDA_VERSION < 11000
+  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
+#else
+  cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
 #endif
-  return DoBlasInternal(
+  return DoBlasInternalImpl(
       cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
-      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
-      GpuMemory(a), lda, stride_a, GpuMemory(b), ldb, stride_b, &beta,
-      GpuMemoryMutable(c), ldc, stride_c, batch_count);
+      true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
+      CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda, stride_a,
+      GpuMemory(b), ldb, stride_b, &beta, GpuMemoryMutable(c), ldc, stride_c,
+      batch_count);
 }
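
Outside of StreamExecutor, the net effect of the guard around a single fp32 GEMM looks like this (illustrative sketch, not the library's code; error handling elided):

#include <cublas_v2.h>

// Set the math mode, run the GEMM, restore the previous mode.
void SgemmWithTf32Request(cublasHandle_t handle, int n, const float* a,
                          const float* b, float* c) {
  cublasMath_t saved_mode;
  cublasGetMathMode(handle, &saved_mode);
#if CUBLAS_VER_MAJOR >= 11
  cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH);  // request TF32
#endif
  const float alpha = 1.0f, beta = 0.0f;
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha, a, n, b, n,
              &beta, c, n);
  cublasSetMathMode(handle, saved_mode);  // restore
}
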
bool CUDABlas::DoBlasGemmStridedBatched(
@@ -2552,14 +2573,6 @@ bool CUDABlas::DoBlasGemmStridedBatched(
     const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
     int64 stride_c, int batch_count) {
-#if CUBLAS_VER_MAJOR >= 11
-  ScopedCublasMathMode math_mode{blas_};
-  if (tensorflow::tf32_execution_allowed()) {
-    if (!math_mode.Init(CUBLAS_TF32_TENSOR_OP_MATH)) {
-      return false;
-    }
-  }
-#endif
   auto cb_alpha = GpuComplexValue(alpha);
   auto cb_beta = GpuComplexValue(beta);
   return DoBlasInternal(

tensorflow/stream_executor/cuda/cuda_blas.h

@@ -81,9 +81,9 @@ class CUDABlas : public blas::BlasSupport {
   //   err_on_failure: Whether to print an error if the cublas function fails.
   //   args: Arguments of cuBLAS function.
   template <typename FuncT, typename... Args>
-  bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
+  bool DoBlasInternalImpl(FuncT cublas_func, Stream* stream,
                           bool pointer_mode_host, bool err_on_failure,
-                          Args... args);
+                          cublasMath_t math_type, Args... args);
 
   // Convenience functions that call DoBlasInternalImpl with different values
   // for err_on_failure.
@@ -91,7 +91,8 @@ class CUDABlas : public blas::BlasSupport {
   bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
                       Args... args) {
     return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
-                              /*err_on_failure=*/true, args...);
+                              /*err_on_failure=*/true, CUBLAS_DEFAULT_MATH,
+                              args...);
   }
 
   template <typename FuncT, typename... Args>
   bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
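
Because only DoBlasInternalImpl gained the new parameter, the two convenience wrappers pin it to CUBLAS_DEFAULT_MATH and every existing call site compiles unchanged. The forwarding pattern in the abstract (sketch, no cuBLAS types):

// Generic sketch: the *Impl function grows an extra policy argument, and
// thin wrappers pin it to a default so callers are untouched.
enum class MathPolicy { kDefault, kTensorOp };

template <typename FuncT, typename... Args>
bool CallImpl(FuncT func, bool err_on_failure, MathPolicy policy,
              Args... args) {
  (void)err_on_failure;  // a real implementation would log on failure
  (void)policy;          // ...and set/restore the math mode around the call
  return func(args...);
}

template <typename FuncT, typename... Args>
bool Call(FuncT func, Args... args) {
  return CallImpl(func, /*err_on_failure=*/true, MathPolicy::kDefault,
                  args...);
}
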
@@ -99,7 +100,8 @@ class CUDABlas : public blas::BlasSupport {
     // Tensor ops are hard-coded off in this path, but can still be enabled with
     // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
     return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
-                              /*err_on_failure=*/false, args...);
+                              /*err_on_failure=*/false, CUBLAS_DEFAULT_MATH,
+                              args...);
   }
 
   // A helper function to implement DoBlasGemmBatched interfaces for generic
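
Finally, tensorflow::tf32_execution_allowed(), consulted throughout the change, is a process-wide switch; a minimal equivalent is an atomic flag (sketch only; the actual TensorFlow implementation may differ):

#include <atomic>

// Sketch of a process-wide TF32 switch in the spirit of
// tensorflow::tf32_execution_allowed().
namespace sketch {

std::atomic<bool> tf32_allowed{true};

void allow_tf32_execution(bool allowed) { tf32_allowed.store(allowed); }

bool tf32_execution_allowed() { return tf32_allowed.load(); }

}  // namespace sketch
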