Merge pull request #40624 from nluehr:TF32-v2

PiperOrigin-RevId: 317757557
Change-Id: I0a0f0cc9025db7d1fbc7975b07d7d934c6fa8c2f
This commit is contained in:
TensorFlower Gardener 2020-06-22 17:23:41 -07:00
commit 7c38468051
10 changed files with 422 additions and 217 deletions

View File

@ -938,6 +938,13 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "tf32_utils",
srcs = ["tf32_utils.cc"],
hdrs = ["tf32_utils.h"],
copts = tf_copts(),
)
tf_cc_tests( tf_cc_tests(
name = "low_level_library_tests", name = "low_level_library_tests",
size = "small", size = "small",

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/tf32_utils.h"
#include <atomic>
namespace tensorflow {
// Whether TensorFloat-32 should be used where supported.
// TODO(nluehr): Maybe enable by default after TF32 Ampere testing.
static std::atomic<bool> tf32_allowed{false};
void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; }
bool tf32_execution_allowed() { return tf32_allowed; }
} // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
#define TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
namespace tensorflow {
void allow_tf32_execution(bool allowed);
bool tf32_execution_allowed();
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_

View File

@ -788,6 +788,16 @@ tf_python_pybind_extension(
], ],
) )
tf_python_pybind_extension(
name = "_pywrap_tf32_execution",
srcs = ["util/tf32.cc"],
module_name = "_pywrap_tf32_execution",
deps = [
"//tensorflow/core/platform:tf32_utils",
"@pybind11",
],
)
tf_python_pybind_extension( tf_python_pybind_extension(
name = "_pywrap_util_port", name = "_pywrap_util_port",
srcs = ["util/port_wrapper.cc"], srcs = ["util/port_wrapper.cc"],
@ -5678,6 +5688,7 @@ py_library(
"//tensorflow:composite_tensor_whitelist", "//tensorflow:composite_tensor_whitelist",
], ],
deps = [ deps = [
":_pywrap_tf32_execution",
":tf_decorator", ":tf_decorator",
":tf_export", ":tf_export",
":tf_stack", ":tf_stack",

View File

@ -18,11 +18,42 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import _pywrap_tf32_execution
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# No tf_export until TF is built against CUDA11 which is required for TF32.
def tensor_float_32_execution_allowed():
"""Get if TensorFloat-32 operations are enabled on supported hardware.
Returns:
True if TensorFloat-32 execution is enabled and False otherwise.
"""
return _pywrap_tf32_execution.is_allowed()
# No tf_export until TF is built against CUDA11 which is required for TF32.
def allow_tensor_float_32_execution(allowed):
"""Allow use of TensorFloat-32 with float32 ops on supported hardware.
TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
TensorFloat-32 kernels take float32 inputs and produce float32 outputs.
Internally, the inputs are cast to a custom representation with 10-bit
mantissa (similar to float16) and 8-bit exponent (similar to float32) and are
executed using TensorCores with float32 accumulation. For more information,
see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/.
TensorFloat-32 execution is disabled by default, but this may change in a
future version.
Args:
allowed: whether to allow TensorFloat-32 execution
"""
_pywrap_tf32_execution.allow(allowed)
@tf_export('config.threading.get_intra_op_parallelism_threads') @tf_export('config.threading.get_intra_op_parallelism_threads')
def get_intra_op_parallelism_threads(): def get_intra_op_parallelism_threads():
"""Get number of threads used within an individual op for parallelism. """Get number of threads used within an individual op for parallelism.

View File

@ -0,0 +1,22 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "tensorflow/core/platform/tf32_utils.h"
PYBIND11_MODULE(_pywrap_tf32_execution, m) {
m.def("allow", &tensorflow::allow_tf32_execution);
m.def("is_allowed", &tensorflow::tf32_execution_allowed);
}

View File

@ -251,6 +251,7 @@ cc_library(
"@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cuda_headers",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core/platform:tf32_utils",
"//tensorflow/stream_executor", "//tensorflow/stream_executor",
"//tensorflow/stream_executor:event", "//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:host_or_device_scalar", "//tensorflow/stream_executor:host_or_device_scalar",
@ -356,6 +357,7 @@ cc_library(
"@local_config_cuda//cuda:cudnn_header", "@local_config_cuda//cuda:cudnn_header",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core/platform:tf32_utils",
"//tensorflow/stream_executor:dnn", "//tensorflow/stream_executor:dnn",
"//tensorflow/stream_executor:event", "//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:plugin_registry", "//tensorflow/stream_executor:plugin_registry",

View File

@ -49,6 +49,7 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/tf32_utils.h"
#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h" #include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
@ -101,18 +102,6 @@ static std::string ToString(cublasStatus_t status) {
} }
} }
// Decide whether to enable TENSOR_OP_MATH
static bool TensorOpMathEnabled() {
static bool is_enabled = [] {
bool is_disabled;
TF_CHECK_OK(
tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUBLAS_TENSOR_OP_MATH",
/*default_val=*/false, &is_disabled));
return !is_disabled;
}();
return is_enabled;
}
// cuBLAS has interfaces that permit pointers to be passed from either the host // cuBLAS has interfaces that permit pointers to be passed from either the host
// memory space or the device memory space; however, you must instruct it as to // memory space or the device memory space; however, you must instruct it as to
// which address space those pointers are in with cublasSetPointerMode. // which address space those pointers are in with cublasSetPointerMode.
@ -399,7 +388,7 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
template <typename FuncT, typename... Args> template <typename FuncT, typename... Args>
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure, bool pointer_mode_host, bool err_on_failure,
bool use_tensor_op_math, Args... args) { cublasMath_t math_type, Args... args) {
absl::MutexLock lock(&mu_); absl::MutexLock lock(&mu_);
CHECK(blas_ != nullptr); CHECK(blas_ != nullptr);
@ -407,20 +396,26 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
return false; return false;
} }
#if CUDA_VERSION >= 9000
ScopedCublasMathMode math_mode{blas_};
#if CUBLAS_VER_MAJOR >= 11
if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
tensorflow::tf32_execution_allowed()) {
#else
if (math_type == CUBLAS_TENSOR_OP_MATH) {
#endif
if (!math_mode.Init(math_type)) {
return false;
}
}
#endif
gpu::ScopedActivateExecutorContext sac{parent_}; gpu::ScopedActivateExecutorContext sac{parent_};
ScopedCublasPointerMode pointer_mode{blas_}; ScopedCublasPointerMode pointer_mode{blas_};
if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
: CUBLAS_POINTER_MODE_DEVICE)) { : CUBLAS_POINTER_MODE_DEVICE)) {
return false; return false;
} }
#if CUDA_VERSION >= 9000
ScopedCublasMathMode math_mode{blas_};
if (use_tensor_op_math) {
if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
return false;
}
}
#endif
cublasStatus_t ret = cublas_func(blas_, args...); cublasStatus_t ret = cublas_func(blas_, args...);
if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) { if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine: " << ToString(ret); LOG(ERROR) << "failed to run cuBLAS routine: " << ToString(ret);
@ -1633,21 +1628,15 @@ bool CUDABlas::DoBlasGemm(
} }
} }
bool use_tensor_ops = false; #if CUDA_VERSION < 11000
#if CUDA_VERSION >= 9000 cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
int cc_major, cc_minor; #else
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
&cc_minor);
// GPUs < sm_70 don't support tensor ops.
if (cc_major >= 7 && TensorOpMathEnabled()) {
use_tensor_ops = true;
}
#endif #endif
return DoBlasInternalImpl( return DoBlasInternalImpl(
cublasSgemmEx, stream, true /* = pointer_mode_host */, cublasSgemmEx, stream, true /* = pointer_mode_host */,
true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa), true /* = err_on_failure= */, math_type, CUDABlasTranspose(transa),
CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a),
SE_CUDA_DATA_HALF, lda, GpuMemory(b), SE_CUDA_DATA_HALF, ldb, &beta, SE_CUDA_DATA_HALF, lda, GpuMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
GpuMemoryMutable(c), SE_CUDA_DATA_HALF, ldc); GpuMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
@ -1692,10 +1681,18 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
"precondition violation"; "precondition violation";
} }
} }
return DoBlasInternal(cublasSgemm, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, #if CUDA_VERSION < 11000
n, k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
&beta, GpuMemoryMutable(c), ldc); #else
cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
#endif
return DoBlasInternalImpl(
cublasSgemm, stream, true /* = pointer_mode_host */,
true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda,
GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
} }
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
@ -1914,21 +1911,6 @@ static bool UsesTensorOps(blas::AlgorithmType algo) {
#endif #endif
} }
template <typename InType>
static bool TensorOpsAvailable(int cc_major) {
#if CUDA_VERSION >= 9000
// cublas *does* allow tensor ops on inputs that are not fp16, so this is not
// strictly correct. We can't simply enable it, though, as that would change
// clients' behavior significantly: Using tensor ops on fp32 inputs cause them
// to be rounded to fp16.
if (cc_major >= 7 && TensorOpMathEnabled() &&
std::is_same<InType, Eigen::half>::value) {
return true;
}
#endif
return false;
}
template <typename InT, typename OutT, typename CompT> template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl( bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
@ -1947,18 +1929,48 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
return false; return false;
} }
if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) { bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
if (std::is_same<InT, Eigen::half>::value) { cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
if (algo_uses_tensor_ops) {
if (cc_major < 7) {
VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
<< algorithm << algorithm
<< " uses tensor ops, but tensor ops are not available in sm" << " uses tensor ops, but tensor ops are not available in sm"
<< cc_major << "X devices."; << cc_major << "X devices.";
return false;
} else if (std::is_same<InT, float>::value) {
#if CUDA_VERSION < 11000
VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
<< algorithm
<< " uses tensor ops, but tensor ops are not available for fp32"
<< " inputs.";
return false;
#else
if (cc_major < 8) {
VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
<< algorithm
<< " uses tensor ops, but tensor ops are not available in sm"
<< cc_major << "X devices for float input types.";
return false;
} else if (!tensorflow::tf32_execution_allowed()) {
VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
<< algorithm
<< " uses tensor ops, but tensor ops are disabled for fp32"
<< " inputs.";
return false;
}
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
#endif
} else if (std::is_same<InT, Eigen::half>::value) {
#if CUDA_VERSION < 11000
math_type = CUBLAS_TENSOR_OP_MATH;
#endif
} else { } else {
VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
<< algorithm << algorithm
<< " uses tensor ops, but the input data type is not fp16."; << " uses tensor ops, which are not supported for InT.";
return false;
} }
return false;
} }
// Either both 'alpha' and 'beta' need to be pointers to device memory, or // Either both 'alpha' and 'beta' need to be pointers to device memory, or
@ -1998,10 +2010,10 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
// If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we // If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we
// essentially reinterpet_cast to __half, which is safe because Eigen::half // essentially reinterpet_cast to __half, which is safe because Eigen::half
// inherits from __half. // inherits from __half.
bool result = DoBlasInternalFailureOK( bool result = DoBlasInternalImpl(
AS_LAMBDA(cublasGemmEx), stream, AS_LAMBDA(cublasGemmEx), stream,
/* pointer_mode_host = */ !alpha.is_pointer(), CUDABlasTranspose(transa), /* pointer_mode_host = */ !alpha.is_pointer(), /*err_on_failure=*/false,
CUDABlasTranspose(transb), m, n, k, math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value(), alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value(),
GpuMemory(a), cuda_in_type, lda, GpuMemory(b), cuda_in_type, ldb, GpuMemory(a), cuda_in_type, lda, GpuMemory(b), cuda_in_type, ldb,
beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value(), beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value(),
@ -2270,9 +2282,27 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
if (stream->parent()->GetDeviceDescription().cuda_compute_capability( if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
&cc_major, &cc_minor) && &cc_major, &cc_minor) &&
cc_major >= 5) { cc_major >= 5) {
bool use_tensor_ops = TensorOpMathEnabled() && data_type == CUDA_R_16F; cublasMath_t math_type;
cublasGemmAlgo_t algo = cublasGemmAlgo_t algo;
(use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); if (data_type == CUDA_R_16F) {
#if CUDA_VERSION < 11000
math_type = CUBLAS_TENSOR_OP_MATH;
#else
math_type = CUBLAS_DEFAULT_MATH;
#endif
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#if CUBLAS_VER_MAJOR >= 11
} else if (data_type == CUDA_R_32F) {
// DoBlassInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
// if TF32 is disabled.
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
algo = tensorflow::tf32_execution_allowed() ? CUBLAS_GEMM_DFALT_TENSOR_OP
: CUBLAS_GEMM_DFALT;
#endif
} else {
math_type = CUBLAS_DEFAULT_MATH;
algo = CUBLAS_GEMM_DFALT;
}
cudaDataType_t compute_type = cudaDataType_t compute_type =
(data_type == CUDA_R_16F ? CUDA_R_32F : data_type); (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
const void **a_void_ptrs = reinterpret_cast<const void **>( const void **a_void_ptrs = reinterpret_cast<const void **>(
@ -2284,7 +2314,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
bool ok; bool ok;
ok = DoBlasInternalImpl( ok = DoBlasInternalImpl(
AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */, AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa), true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda, CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc, b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
batch_count, compute_type, algo); batch_count, compute_type, algo);
@ -2419,33 +2449,30 @@ bool CUDABlas::DoBlasGemmStridedBatched(
int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
int64 stride_c, int batch_count) { int64 stride_c, int batch_count) {
bool use_tensor_ops = false; #if CUDA_VERSION >= 9010
#if CUDA_VERSION >= 9000
int cc_major, cc_minor; int cc_major, cc_minor;
if (stream->parent()->GetDeviceDescription().cuda_compute_capability( if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
&cc_major, &cc_minor)) { &cc_major, &cc_minor) &&
// GPUs < sm_70 don't support tensor ops. cc_major >= 5) {
if (cc_major >= 7 && TensorOpMathEnabled()) { cublasGemmAlgo_t algo =
use_tensor_ops = true; (cc_major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
} #if CUDA_VERSION < 11000
#if CUDA_VERSION >= 9010 cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
if (cc_major >= 5) { #else
cublasGemmAlgo_t algo = cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
(use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
bool ok = DoBlasInternalImpl(
AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
true /* = pointer_mode_host */, true /* = err_on_failure */,
use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
m, n, k, &alpha, GpuMemory(a), CUDA_R_16F, lda, stride_a,
GpuMemory(b), CUDA_R_16F, ldb, stride_b, &beta, GpuMemoryMutable(c),
CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
if (ok) {
return true;
}
LOG(ERROR) << "failed BLAS call, see log for details";
return false;
}
#endif #endif
bool ok = DoBlasInternalImpl(
AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
true /* = pointer_mode_host */, true /* = err_on_failure */, math_type,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
GpuMemory(a), CUDA_R_16F, lda, stride_a, GpuMemory(b), CUDA_R_16F, ldb,
stride_b, &beta, GpuMemoryMutable(c), CUDA_R_16F, ldc, stride_c,
batch_count, CUDA_R_32F, algo);
if (ok) {
return true;
}
LOG(ERROR) << "failed BLAS call, see log for details";
return false;
} }
#endif #endif
// Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop. // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
@ -2458,10 +2485,10 @@ bool CUDABlas::DoBlasGemmStridedBatched(
reinterpret_cast<__half *>(GpuMemoryMutable(c) + batch * stride_c); reinterpret_cast<__half *>(GpuMemoryMutable(c) + batch * stride_c);
bool ok = DoBlasInternalImpl( bool ok = DoBlasInternalImpl(
cublasSgemmEx, stream, true /* = pointer_mode_host */, cublasSgemmEx, stream, true /* = pointer_mode_host */,
true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa), true /* = err_on_failure= */, CUBLAS_DEFAULT_MATH,
CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix, a_matrix, SE_CUDA_DATA_HALF, lda, b_matrix, SE_CUDA_DATA_HALF, ldb,
SE_CUDA_DATA_HALF, ldc); &beta, c_matrix, SE_CUDA_DATA_HALF, ldc);
if (!ok) { if (!ok) {
LOG(ERROR) << "failed BLAS call, see log for details"; LOG(ERROR) << "failed BLAS call, see log for details";
return false; return false;
@ -2476,11 +2503,17 @@ bool CUDABlas::DoBlasGemmStridedBatched(
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
int batch_count) { int batch_count) {
return DoBlasInternal( #if CUDA_VERSION < 11000
cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
#else
cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
#endif
return DoBlasInternalImpl(
cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */, cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
GpuMemory(a), lda, stride_a, GpuMemory(b), ldb, stride_b, &beta, CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda, stride_a,
GpuMemoryMutable(c), ldc, stride_c, batch_count); GpuMemory(b), ldb, stride_b, &beta, GpuMemoryMutable(c), ldc, stride_c,
batch_count);
} }
bool CUDABlas::DoBlasGemmStridedBatched( bool CUDABlas::DoBlasGemmStridedBatched(

View File

@ -21,6 +21,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h" #include "tensorflow/stream_executor/host_or_device_scalar.h"
@ -83,26 +84,17 @@ class CUDABlas : public blas::BlasSupport {
template <typename FuncT, typename... Args> template <typename FuncT, typename... Args>
bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream, bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure, bool pointer_mode_host, bool err_on_failure,
bool use_tensor_op_math, Args... args); cublasMath_t math_type, Args... args);
// Convenience functions that call DoBlasInternalImpl with different values // Convenience functions that call DoBlasInternalImpl with err_on_failure=true
// for err_on_failure. // and math_type=CUBLAS_DEFAULT_MATH.
template <typename FuncT, typename... Args> template <typename FuncT, typename... Args>
bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host, bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
Args... args) { Args... args) {
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host, return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
/*err_on_failure=*/true, /*use_tensor_ops=*/false, /*err_on_failure=*/true, CUBLAS_DEFAULT_MATH,
args...); args...);
} }
template <typename FuncT, typename... Args>
bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, Args... args) {
// Tensor ops are hard-coded off in this path, but can still be enabled with
// a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
/*err_on_failure=*/false,
/*use_tensor_ops=*/false, args...);
}
// A helper function to implement DoBlasGemmBatched interfaces for generic // A helper function to implement DoBlasGemmBatched interfaces for generic
// types. // types.

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/tf32_utils.h"
#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h" #include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@ -601,31 +602,6 @@ class CudnnFilterDescriptor {
SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor); SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
}; };
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
bool TensorOpMathEnabled() {
static bool is_enabled = [] {
bool is_disabled = false;
TF_CHECK_OK(
tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
/*default_val=*/false, &is_disabled));
return !is_disabled;
}();
return is_enabled;
}
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
// for RNNs.
bool RnnTensorOpMathEnabled() {
static bool is_enabled = [] {
bool is_disabled = false;
TF_CHECK_OK(
tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH",
/*default_val=*/false, &is_disabled));
return !is_disabled;
}();
return is_enabled;
}
// A helper function to decide whether to use // A helper function to decide whether to use
// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
@ -730,10 +706,6 @@ class CudnnConvolutionDescriptor {
: CUDNN_CROSS_CORRELATION, : CUDNN_CROSS_CORRELATION,
data_type)); data_type));
// NOTE(benbarsdell): This only applies if tensor op math is enabled
// and algo selection is set to Default.
this->set_use_tensor_op_math(true);
#if CUDNN_MAJOR >= 7 #if CUDNN_MAJOR >= 7
VLOG(2) << "Requesting grouped convolution: " VLOG(2) << "Requesting grouped convolution: "
<< convolution_descriptor.group_count(); << convolution_descriptor.group_count();
@ -745,13 +717,15 @@ class CudnnConvolutionDescriptor {
#endif #endif
} }
void set_use_tensor_op_math(bool use_tensor_op_math) const { void set_use_tensor_op_math(bool use_tensor_op_math) {
#if CUDNN_VERSION >= 7000 #if CUDNN_VERSION >= 7000
cudnnMathType_t math_type = cudnnMathType_t math_type =
#if CUDNN_VERSION >= 8000
(use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
#else
(use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
if (TensorOpMathEnabled()) { #endif
CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type)); CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
}
#endif #endif
} }
@ -763,6 +737,40 @@ class CudnnConvolutionDescriptor {
SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor); SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
}; };
// A helper function to query if a CudnnConvolutionDescriptor has tensor_op_math
// set
static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
cudnnMathType_t math_type;
CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
#if CUDNN_VERSION >= 8000
return math_type != CUDNN_FMA_MATH;
#else
return math_type == CUDNN_TENSOR_OP_MATH;
#endif
}
static bool TensorOpMathAvailable(int cc_major) {
return cc_major >= 7 && CUDNN_VERSION >= 7000;
}
static bool IsTensorMathAllowed(Stream* stream, dnn::DataType input_type) {
int cc_major, cc_minor;
std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
if (!TensorOpMathAvailable(cc_major)) {
return false;
}
if (input_type == dnn::DataType::kFloat) {
#if CUDNN_VERSION < 8000
return false;
#else
if (!tensorflow::tf32_execution_allowed()) {
return false;
}
#endif
}
return true;
}
// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
// within a scope. // within a scope.
class CudnnPoolingDescriptor { class CudnnPoolingDescriptor {
@ -1155,21 +1163,27 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
// in profile mode, which is run with algorithms returned from // in profile mode, which is run with algorithms returned from
// GetRnnAlgorithms() (which are non-default and explicitly set whether to // GetRnnAlgorithms() (which are non-default and explicitly set whether to
// use tensor ops). CuDNN 7.2.1 fixed this issue // use tensor ops). CuDNN 7.2.1 fixed this issue
if (RnnTensorOpMathEnabled()) { bool allow_tensor_ops =
cudnnMathType_t math_type; data_type != CUDNN_DATA_FLOAT || tensorflow::tf32_execution_allowed();
if (algorithm_config.algorithm().has_value()) { bool use_tensor_ops;
math_type = algorithm_config.algorithm()->tensor_ops_enabled() if (algorithm_config.algorithm().has_value()) {
? CUDNN_TENSOR_OP_MATH use_tensor_ops = algorithm_config.algorithm()->tensor_ops_enabled();
: CUDNN_DEFAULT_MATH; } else {
} else { use_tensor_ops = CUDNN_VERSION >= 7201 && allow_tensor_ops;
#if CUDNN_VERSION >= 7201
math_type = CUDNN_TENSOR_OP_MATH;
#else
math_type = CUDNN_DEFAULT_MATH;
#endif // CUDNN_VERSION >= 7201
}
CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
} }
if (use_tensor_ops && !allow_tensor_ops) {
return port::Status(port::error::INVALID_ARGUMENT,
"Algo requests disallowed tensor op evaluation.");
}
cudnnMathType_t math_type;
#if CUDNN_VERSION >= 8000
math_type = use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
#else
math_type = use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
#endif
CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
#endif // CUDNN_VERSION >= 7000 #endif // CUDNN_VERSION >= 7000
return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan), return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
@ -2560,10 +2574,11 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
const CudnnTensorDescriptor& output_nd, const CudnnTensorDescriptor& output_nd,
const dnn::AlgorithmDesc& algorithm_desc, const dnn::AlgorithmDesc& algorithm_desc,
ScratchAllocator* scratch_allocator) { ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
// functionally correct because the convolution is run with the algorithm of return port::Status(
// the last call to this function, but should be fixed anyway. port::error::INTERNAL,
conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); "Mismatch between cudnn conv and algorithm descriptors.");
}
// Query the size of the workspace and allocate it. // Query the size of the workspace and allocate it.
size_t size_in_bytes; size_t size_in_bytes;
@ -2603,10 +2618,11 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
const CudnnTensorDescriptor& output_nd, const CudnnTensorDescriptor& output_nd,
const dnn::AlgorithmDesc& algorithm_desc, const dnn::AlgorithmDesc& algorithm_desc,
ScratchAllocator* scratch_allocator) { ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
// functionally correct because the convolution is run with the algorithm of return port::Status(
// the last call to this function, but should be fixed anyway. port::error::INTERNAL,
conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); "Mismatch between cudnn conv and algorithm descriptors.");
}
// Query the size of the workspace and allocate it. // Query the size of the workspace and allocate it.
size_t size_in_bytes; size_t size_in_bytes;
@ -2648,10 +2664,11 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
const CudnnTensorDescriptor& output_nd, const CudnnTensorDescriptor& output_nd,
const dnn::AlgorithmDesc& algorithm_desc, const dnn::AlgorithmDesc& algorithm_desc,
ScratchAllocator* scratch_allocator) { ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
// functionally correct because the convolution is run with the algorithm of return port::Status(
// the last call to this function, but should be fixed anyway. port::error::INTERNAL,
conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); "Mismatch between cudnn conv and algorithm descriptors.");
}
// Query the size of the workspace and allocate it. // Query the size of the workspace and allocate it.
size_t size_in_bytes; size_t size_in_bytes;
@ -2685,18 +2702,42 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
return scratch_allocator->AllocateBytes(size_in_bytes); return scratch_allocator->AllocateBytes(size_in_bytes);
} }
static bool TensorOpMathAvailable(int cc_major) { port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
return cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled(); absl::optional<dnn::AlgorithmDesc> desc) {
bool use_tensor_ops;
if (desc.has_value()) {
use_tensor_ops = desc->tensor_ops_enabled();
if (use_tensor_ops && !IsTensorMathAllowed(stream, type)) {
return port::Status(port::error::INVALID_ARGUMENT,
"Algo requests disallowed tensor op evaluation.");
}
} else {
use_tensor_ops = IsTensorMathAllowed(stream, type);
}
return use_tensor_ops;
} }
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm( port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
Stream* stream, const CudnnHandle& cudnn, Stream* stream, const CudnnHandle& cudnn,
const dnn::AlgorithmConfig& algorithm_config, const dnn::AlgorithmConfig& algorithm_config,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv, dnn::DataType element_type,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch) { DeviceMemory<uint8>* scratch) {
absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm(); absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
CudnnConvolutionDescriptor conv(
convolution_descriptor,
ToCudnnDataType(GetConvAccumulatorType(element_type)));
bool use_tensor_ops;
SE_ASSIGN_OR_RETURN(use_tensor_ops,
UseTensorOps(stream, element_type, algo_desc));
conv.set_use_tensor_op_math(use_tensor_ops);
if (!algo_desc.has_value()) { if (!algo_desc.has_value()) {
// Pick fastest algorithm within memory limit according to cuDNN's // Pick fastest algorithm within memory limit according to cuDNN's
// heuristics. // heuristics.
@ -2709,10 +2750,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
GetCudnnConvolutionForwardAlgo( GetCudnnConvolutionForwardAlgo(
cudnn, input_nd, filter, conv, output_nd, cudnn, input_nd, filter, conv, output_nd,
specify_workspace_limit, memory_limit_bytes)); specify_workspace_limit, memory_limit_bytes));
int cc_major, cc_minor; algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
algo_desc = dnn::AlgorithmDesc(
algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
} }
const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace( const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
@ -2736,6 +2774,9 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
"Returned status: ", scratch_or.status().ToString())); "Returned status: ", scratch_or.status().ToString()));
} }
SE_ASSIGN_OR_RETURN(use_tensor_ops,
UseTensorOps(stream, element_type, algo_desc));
conv.set_use_tensor_op_math(use_tensor_ops);
SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
stream, cudnn, input_nd, filter, conv, stream, cudnn, input_nd, filter, conv,
output_nd, *algo_desc, scratch_allocator)); output_nd, *algo_desc, scratch_allocator));
@ -2746,10 +2787,19 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
Stream* stream, const CudnnHandle& cudnn, Stream* stream, const CudnnHandle& cudnn,
const dnn::AlgorithmConfig& algorithm_config, const dnn::AlgorithmConfig& algorithm_config,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv, dnn::DataType element_type,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch) { DeviceMemory<uint8>* scratch) {
absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm(); absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
CudnnConvolutionDescriptor conv(
convolution_descriptor,
ToCudnnDataType(GetConvAccumulatorType(element_type)));
bool use_tensor_ops;
SE_ASSIGN_OR_RETURN(use_tensor_ops,
UseTensorOps(stream, element_type, algo_desc));
conv.set_use_tensor_op_math(use_tensor_ops);
if (!algo_desc.has_value()) { if (!algo_desc.has_value()) {
// Pick fastest algorithm within memory limit according to cuDNN's // Pick fastest algorithm within memory limit according to cuDNN's
// heuristics. // heuristics.
@ -2762,10 +2812,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
GetCudnnConvolutionBackwardDataAlgo( GetCudnnConvolutionBackwardDataAlgo(
cudnn, input_nd, filter, conv, output_nd, cudnn, input_nd, filter, conv, output_nd,
specify_workspace_limit, memory_limit_bytes)); specify_workspace_limit, memory_limit_bytes));
int cc_major, cc_minor; algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
algo_desc = dnn::AlgorithmDesc(
algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
} }
const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace( const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
@ -2788,6 +2835,9 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
"while a secondary algorithm is not provided."); "while a secondary algorithm is not provided.");
} }
SE_ASSIGN_OR_RETURN(use_tensor_ops,
UseTensorOps(stream, element_type, algo_desc));
conv.set_use_tensor_op_math(use_tensor_ops);
SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
stream, cudnn, input_nd, filter, conv, stream, cudnn, input_nd, filter, conv,
output_nd, *algo_desc, scratch_allocator)); output_nd, *algo_desc, scratch_allocator));
@ -2798,10 +2848,19 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
Stream* stream, const CudnnHandle& cudnn, Stream* stream, const CudnnHandle& cudnn,
const dnn::AlgorithmConfig& algorithm_config, const dnn::AlgorithmConfig& algorithm_config,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv, dnn::DataType element_type,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch) { DeviceMemory<uint8>* scratch) {
absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm(); absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
CudnnConvolutionDescriptor conv(
convolution_descriptor,
ToCudnnDataType(GetConvAccumulatorType(element_type)));
bool use_tensor_ops;
SE_ASSIGN_OR_RETURN(use_tensor_ops,
UseTensorOps(stream, element_type, algo_desc));
conv.set_use_tensor_op_math(use_tensor_ops);
if (!algo_desc.has_value()) { if (!algo_desc.has_value()) {
// Pick fastest algorithm within memory limit according to cuDNN's // Pick fastest algorithm within memory limit according to cuDNN's
// heuristics. // heuristics.
@ -2814,10 +2873,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
GetCudnnConvolutionBackwardFilterAlgo( GetCudnnConvolutionBackwardFilterAlgo(
cudnn, input_nd, filter, conv, output_nd, cudnn, input_nd, filter, conv, output_nd,
specify_workspace_limit, memory_limit_bytes)); specify_workspace_limit, memory_limit_bytes));
int cc_major, cc_minor; algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
algo_desc = dnn::AlgorithmDesc(
algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
} }
auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
@ -2840,6 +2896,9 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
"while a secondary algorithm is not provided."); "while a secondary algorithm is not provided.");
} }
SE_ASSIGN_OR_RETURN(use_tensor_ops,
UseTensorOps(stream, element_type, algo_desc));
conv.set_use_tensor_op_math(use_tensor_ops);
SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace( SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
stream, cudnn, input_nd, filter, conv, stream, cudnn, input_nd, filter, conv,
output_nd, *algo_desc, scratch_allocator)); output_nd, *algo_desc, scratch_allocator));
@ -3004,35 +3063,32 @@ port::Status CudnnSupport::DoPrepareForConvolution(
CudnnTensorDescriptor output_nd( CudnnTensorDescriptor output_nd(
output_descriptor, output_descriptor,
ToCudnnDataType(element_type, output_descriptor.layout())); ToCudnnDataType(element_type, output_descriptor.layout()));
CudnnConvolutionDescriptor conv(
convolution_descriptor,
ToCudnnDataType(GetConvAccumulatorType(element_type)));
auto cudnn = cudnn_->GetHandle(parent_, stream); auto cudnn = cudnn_->GetHandle(parent_, stream);
switch (kind) { switch (kind) {
case dnn::ConvolutionKind::FORWARD: { case dnn::ConvolutionKind::FORWARD: {
SE_ASSIGN_OR_RETURN( SE_ASSIGN_OR_RETURN(*algorithm_desc,
*algorithm_desc, GetCudnnConvolutionForwardAlgorithm(
GetCudnnConvolutionForwardAlgorithm( stream, cudnn, algorithm_config, input_nd,
stream, cudnn, algorithm_config, input_nd, filter_nd, conv, filter_nd, element_type, convolution_descriptor,
output_nd, scratch_allocator, scratch_memory)); output_nd, scratch_allocator, scratch_memory));
break; break;
} }
case dnn::ConvolutionKind::BACKWARD_DATA: { case dnn::ConvolutionKind::BACKWARD_DATA: {
SE_ASSIGN_OR_RETURN( SE_ASSIGN_OR_RETURN(*algorithm_desc,
*algorithm_desc, GetCudnnConvolutionBackwardDataAlgorithm(
GetCudnnConvolutionBackwardDataAlgorithm( stream, cudnn, algorithm_config, input_nd,
stream, cudnn, algorithm_config, input_nd, filter_nd, conv, filter_nd, element_type, convolution_descriptor,
output_nd, scratch_allocator, scratch_memory)); output_nd, scratch_allocator, scratch_memory));
break; break;
} }
case dnn::ConvolutionKind::BACKWARD_FILTER: { case dnn::ConvolutionKind::BACKWARD_FILTER: {
SE_ASSIGN_OR_RETURN( SE_ASSIGN_OR_RETURN(*algorithm_desc,
*algorithm_desc, GetCudnnConvolutionBackwardFilterAlgorithm(
GetCudnnConvolutionBackwardFilterAlgorithm( stream, cudnn, algorithm_config, input_nd,
stream, cudnn, algorithm_config, input_nd, filter_nd, conv, filter_nd, element_type, convolution_descriptor,
output_nd, scratch_allocator, scratch_memory)); output_nd, scratch_allocator, scratch_memory));
break; break;
} }
default: default:
@ -3061,8 +3117,9 @@ port::Status CudnnSupport::DoConvolve(
auto accumulator_type = GetConvAccumulatorType(element_type); auto accumulator_type = GetConvAccumulatorType(element_type);
CudnnConvolutionDescriptor conv(convolution_descriptor, CudnnConvolutionDescriptor conv(convolution_descriptor,
ToCudnnDataType(accumulator_type)); ToCudnnDataType(accumulator_type));
// Set use_tensor_math param to correct value SE_ASSIGN_OR_RETURN(bool use_tensor_ops,
conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); UseTensorOps(stream, element_type, algorithm_desc));
conv.set_use_tensor_op_math(use_tensor_ops);
auto cudnn = cudnn_->GetHandle(parent_, stream); auto cudnn = cudnn_->GetHandle(parent_, stream);
// Alpha is the scaling factor for input. // Alpha is the scaling factor for input.
@ -3295,14 +3352,6 @@ port::Status CudnnSupport::DoConvolve(
return port::Status::OK(); return port::Status::OK();
} }
// A helper function to query if a CudnnConvolutionDescriptor has tensor_op_math
// set
static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
cudnnMathType_t math_type;
CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
return math_type == CUDNN_TENSOR_OP_MATH;
}
template <typename ElementType, typename BiasType, typename ScaleType, template <typename ElementType, typename BiasType, typename ScaleType,
typename OutputType> typename OutputType>
port::Status CudnnSupport::DoFusedConvolveImpl( port::Status CudnnSupport::DoFusedConvolveImpl(
@ -3336,8 +3385,6 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
filter_descriptor, filter_descriptor,
GetCudnnDataType<ElementType>(conv_input_descriptor.layout())); GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>()); CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
CudnnConvolutionDescriptor conv(convolution_descriptor,
ToCudnnDataType(accumulator_type));
auto cudnn = cudnn_->GetHandle(parent_, stream); auto cudnn = cudnn_->GetHandle(parent_, stream);
@ -3347,9 +3394,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
SE_ASSIGN_OR_RETURN( SE_ASSIGN_OR_RETURN(
dnn::AlgorithmDesc algo_desc, dnn::AlgorithmDesc algo_desc,
GetCudnnConvolutionForwardAlgorithm( GetCudnnConvolutionForwardAlgorithm(
stream, cudnn, algorithm_config, conv_input_nd, filter, conv, stream, cudnn, algorithm_config, conv_input_nd, filter,
dnn::ToDataType<ElementType>::value, convolution_descriptor,
output_nd, scratch_allocator, &scratch)); output_nd, scratch_allocator, &scratch));
CudnnConvolutionDescriptor conv(convolution_descriptor,
ToCudnnDataType(accumulator_type));
conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
std::unique_ptr<GpuTimer, GpuTimerDeleter> timer; std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
if (is_profiling) { if (is_profiling) {
timer.reset(new GpuTimer(parent_)); // NOLINT timer.reset(new GpuTimer(parent_)); // NOLINT
@ -3480,9 +3532,7 @@ bool CudnnSupport::GetRnnAlgorithms(
for (auto i : algo_types) { for (auto i : algo_types) {
out_algorithms->push_back({i, /*use_tensor_ops=*/false}); out_algorithms->push_back({i, /*use_tensor_ops=*/false});
#if CUDNN_VERSION >= 7100 #if CUDNN_VERSION >= 7100
if (RnnTensorOpMathEnabled()) { out_algorithms->push_back({i, /*use_tensor_ops=*/true});
out_algorithms->push_back({i, /*use_tensor_ops=*/true});
}
#endif #endif
} }
return true; return true;