diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 70bb8a89417..33a1e7cfe0a 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -938,6 +938,13 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "tf32_utils",
+    srcs = ["tf32_utils.cc"],
+    hdrs = ["tf32_utils.h"],
+    copts = tf_copts(),
+)
+
 tf_cc_tests(
     name = "low_level_library_tests",
     size = "small",
diff --git a/tensorflow/core/platform/tf32_utils.cc b/tensorflow/core/platform/tf32_utils.cc
new file mode 100644
index 00000000000..d2f40ea161a
--- /dev/null
+++ b/tensorflow/core/platform/tf32_utils.cc
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/tf32_utils.h"
+
+#include <atomic>
+
+namespace tensorflow {
+
+// Whether TensorFloat-32 should be used where supported.
+// TODO(nluehr): Maybe enable by default after TF32 Ampere testing.
+static std::atomic<bool> tf32_allowed{false};
+
+void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; }
+
+bool tf32_execution_allowed() { return tf32_allowed; }
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/tf32_utils.h b/tensorflow/core/platform/tf32_utils.h
new file mode 100644
index 00000000000..7a158d00ad3
--- /dev/null
+++ b/tensorflow/core/platform/tf32_utils.h
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
+#define TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
+
+namespace tensorflow {
+
+void allow_tf32_execution(bool allowed);
+
+bool tf32_execution_allowed();
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index de9cf9a24c7..5f9e2dfb1ff 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -788,6 +788,16 @@ tf_python_pybind_extension(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_tf32_execution",
+    srcs = ["util/tf32.cc"],
+    module_name = "_pywrap_tf32_execution",
+    deps = [
+        "//tensorflow/core/platform:tf32_utils",
+        "@pybind11",
+    ],
+)
+
 tf_python_pybind_extension(
     name = "_pywrap_util_port",
     srcs = ["util/port_wrapper.cc"],
@@ -5678,6 +5688,7 @@ py_library(
         "//tensorflow:composite_tensor_whitelist",
     ],
     deps = [
+        ":_pywrap_tf32_execution",
        ":tf_decorator",
        ":tf_export",
        ":tf_stack",
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index 9ff16f2a327..0962b9a8a70 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -18,11 +18,42 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python import _pywrap_tf32_execution
 from tensorflow.python.eager import context
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
+# No tf_export until TF is built against CUDA11 which is required for TF32.
+def tensor_float_32_execution_allowed():
+  """Get if TensorFloat-32 operations are enabled on supported hardware.
+
+  Returns:
+    True if TensorFloat-32 execution is enabled and False otherwise.
+  """
+  return _pywrap_tf32_execution.is_allowed()
+
+
+# No tf_export until TF is built against CUDA11 which is required for TF32.
+def allow_tensor_float_32_execution(allowed):
+  """Allow use of TensorFloat-32 with float32 ops on supported hardware.
+
+  TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
+  TensorFloat-32 kernels take float32 inputs and produce float32 outputs.
+  Internally, the inputs are cast to a custom representation with 10-bit
+  mantissa (similar to float16) and 8-bit exponent (similar to float32) and are
+  executed using TensorCores with float32 accumulation. For more information,
+  see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/.
+
+  TensorFloat-32 execution is disabled by default, but this may change in a
+  future version.
+
+  Args:
+    allowed: whether to allow TensorFloat-32 execution
+  """
+  _pywrap_tf32_execution.allow(allowed)
+
+
 @tf_export('config.threading.get_intra_op_parallelism_threads')
 def get_intra_op_parallelism_threads():
   """Get number of threads used within an individual op for parallelism.
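Usage sketch for the two functions added to config.py above, assuming a TensorFlow build that already contains this change; they are intentionally not exported under tf.config yet, so they are only reachable through tensorflow.python.framework.config:

    from tensorflow.python.framework import config

    print(config.tensor_float_32_execution_allowed())  # False: TF32 is off by default
    config.allow_tensor_float_32_execution(True)       # opt in to TF32 kernels
    print(config.tensor_float_32_execution_allowed())  # True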
diff --git a/tensorflow/python/util/tf32.cc b/tensorflow/python/util/tf32.cc
new file mode 100644
index 00000000000..7dece6ccdae
--- /dev/null
+++ b/tensorflow/python/util/tf32.cc
@@ -0,0 +1,22 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/core/platform/tf32_utils.h"
+
+PYBIND11_MODULE(_pywrap_tf32_execution, m) {
+  m.def("allow", &tensorflow::allow_tf32_execution);
+  m.def("is_allowed", &tensorflow::tf32_execution_allowed);
+}
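The pybind11 module above is the same switch one layer down; config.py is only a thin wrapper over it. A hedged sketch of calling it directly, with the module path taken from the _pywrap_tf32_execution BUILD rule earlier in this patch:

    from tensorflow.python import _pywrap_tf32_execution

    _pywrap_tf32_execution.allow(True)
    assert _pywrap_tf32_execution.is_allowed()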
diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD
index c3cf9f5db15..3a14be9ad50 100644
--- a/tensorflow/stream_executor/cuda/BUILD
+++ b/tensorflow/stream_executor/cuda/BUILD
@@ -251,6 +251,7 @@ cc_library(
         "@local_config_cuda//cuda:cuda_headers",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:tf32_utils",
         "//tensorflow/stream_executor",
         "//tensorflow/stream_executor:event",
         "//tensorflow/stream_executor:host_or_device_scalar",
@@ -356,6 +357,7 @@ cc_library(
         "@local_config_cuda//cuda:cudnn_header",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/platform:tf32_utils",
         "//tensorflow/stream_executor:dnn",
         "//tensorflow/stream_executor:event",
         "//tensorflow/stream_executor:plugin_registry",
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index c9f0fc462c9..f32c8b3e81e 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -49,6 +49,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/platform/tf32_utils.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
@@ -101,18 +102,6 @@ static std::string ToString(cublasStatus_t status) {
   }
 }
 
-// Decide whether to enable TENSOR_OP_MATH
-static bool TensorOpMathEnabled() {
-  static bool is_enabled = [] {
-    bool is_disabled;
-    TF_CHECK_OK(
-        tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUBLAS_TENSOR_OP_MATH",
-                                       /*default_val=*/false, &is_disabled));
-    return !is_disabled;
-  }();
-  return is_enabled;
-}
-
 // cuBLAS has interfaces that permit pointers to be passed from either the host
 // memory space or the device memory space; however, you must instruct it as to
 // which address space those pointers are in with cublasSetPointerMode.
@@ -399,7 +388,7 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
 template <typename FuncT, typename... Args>
 bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                   bool pointer_mode_host, bool err_on_failure,
-                                  bool use_tensor_op_math, Args... args) {
+                                  cublasMath_t math_type, Args... args) {
   absl::MutexLock lock(&mu_);
 
   CHECK(blas_ != nullptr);
@@ -407,20 +396,26 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
     return false;
   }
 
+#if CUDA_VERSION >= 9000
+  ScopedCublasMathMode math_mode{blas_};
+#if CUBLAS_VER_MAJOR >= 11
+  if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
+      tensorflow::tf32_execution_allowed()) {
+#else
+  if (math_type == CUBLAS_TENSOR_OP_MATH) {
+#endif
+    if (!math_mode.Init(math_type)) {
+      return false;
+    }
+  }
+#endif
+
   gpu::ScopedActivateExecutorContext sac{parent_};
   ScopedCublasPointerMode pointer_mode{blas_};
   if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
                                            : CUBLAS_POINTER_MODE_DEVICE)) {
     return false;
   }
-#if CUDA_VERSION >= 9000
-  ScopedCublasMathMode math_mode{blas_};
-  if (use_tensor_op_math) {
-    if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
-      return false;
-    }
-  }
-#endif
   cublasStatus_t ret = cublas_func(blas_, args...);
   if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to run cuBLAS routine: " << ToString(ret);
@@ -1633,21 +1628,15 @@ bool CUDABlas::DoBlasGemm(
     }
   }
 
-  bool use_tensor_ops = false;
-#if CUDA_VERSION >= 9000
-  int cc_major, cc_minor;
-  stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
-                                                                   &cc_minor);
-
-  // GPUs < sm_70 don't support tensor ops.
-  if (cc_major >= 7 && TensorOpMathEnabled()) {
-    use_tensor_ops = true;
-  }
-#endif
+#if CUDA_VERSION < 11000
+  cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
+#else
+  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
+#endif
 
   return DoBlasInternalImpl(
       cublasSgemmEx, stream, true /* = pointer_mode_host */,
-      true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
+      true /* = err_on_failure= */, math_type, CUDABlasTranspose(transa),
       CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a),
       SE_CUDA_DATA_HALF, lda, GpuMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
       GpuMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
@@ -1692,10 +1681,18 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
           "precondition violation";
     }
   }
-  return DoBlasInternal(cublasSgemm, stream, true /* = pointer_mode_host */,
-                        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
-                        n, k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb,
-                        &beta, GpuMemoryMutable(c), ldc);
+
+#if CUDA_VERSION < 11000
+  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
+#else
+  cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
+#endif
+
+  return DoBlasInternalImpl(
+      cublasSgemm, stream, true /* = pointer_mode_host */,
+      true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
+      CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda,
+      GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
 }
 
 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
@@ -1914,21 +1911,6 @@ static bool UsesTensorOps(blas::AlgorithmType algo) {
 #endif
 }
 
-template <typename InT>
-static bool TensorOpsAvailable(int cc_major) {
-#if CUDA_VERSION >= 9000
-  // cublas *does* allow tensor ops on inputs that are not fp16, so this is not
-  // strictly correct.  We can't simply enable it, though, as that would change
-  // clients' behavior significantly: Using tensor ops on fp32 inputs cause them
-  // to be rounded to fp16.
-  if (cc_major >= 7 && TensorOpMathEnabled() &&
-      std::is_same<InT, Eigen::half>::value) {
-    return true;
-  }
-#endif
-  return false;
-}
-
 template <typename InT, typename OutT, typename CompT>
 bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
@@ -1947,18 +1929,48 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
     return false;
   }
 
-  if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
-    if (std::is_same<InT, Eigen::half>::value) {
+  bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
+  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
+  if (algo_uses_tensor_ops) {
+    if (cc_major < 7) {
       VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
               << algorithm
               << " uses tensor ops, but tensor ops are not available in sm"
               << cc_major << "X devices.";
+      return false;
+    } else if (std::is_same<InT, float>::value) {
+#if CUDA_VERSION < 11000
+      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+              << algorithm
+              << " uses tensor ops, but tensor ops are not available for fp32"
+              << " inputs.";
+      return false;
+#else
+      if (cc_major < 8) {
+        VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+                << algorithm
+                << " uses tensor ops, but tensor ops are not available in sm"
+                << cc_major << "X devices for float input types.";
+        return false;
+      } else if (!tensorflow::tf32_execution_allowed()) {
+        VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+                << algorithm
+                << " uses tensor ops, but tensor ops are disabled for fp32"
+                << " inputs.";
+        return false;
+      }
+      math_type = CUBLAS_TF32_TENSOR_OP_MATH;
+#endif
+    } else if (std::is_same<InT, Eigen::half>::value) {
+#if CUDA_VERSION < 11000
+      math_type = CUBLAS_TENSOR_OP_MATH;
+#endif
     } else {
       VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
               << algorithm
-              << " uses tensor ops, but the input data type is not fp16.";
+              << " uses tensor ops, which are not supported for InT.";
+      return false;
     }
-    return false;
   }
 
   // Either both 'alpha' and 'beta' need to be pointers to device memory, or
@@ -1998,10 +2010,10 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
   // If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we
   // essentially reinterpet_cast to __half, which is safe because Eigen::half
   // inherits from __half.
-  bool result = DoBlasInternalFailureOK(
+  bool result = DoBlasInternalImpl(
       AS_LAMBDA(cublasGemmEx), stream,
-      /* pointer_mode_host = */ !alpha.is_pointer(), CUDABlasTranspose(transa),
-      CUDABlasTranspose(transb), m, n, k,
+      /* pointer_mode_host = */ !alpha.is_pointer(), /*err_on_failure=*/false,
+      math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
       alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value(),
       GpuMemory(a), cuda_in_type, lda, GpuMemory(b), cuda_in_type, ldb,
       beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value(),
@@ -2270,9 +2282,27 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
   if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
           &cc_major, &cc_minor) &&
       cc_major >= 5) {
-    bool use_tensor_ops = TensorOpMathEnabled() && data_type == CUDA_R_16F;
-    cublasGemmAlgo_t algo =
-        (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+    cublasMath_t math_type;
+    cublasGemmAlgo_t algo;
+    if (data_type == CUDA_R_16F) {
+#if CUDA_VERSION < 11000
+      math_type = CUBLAS_TENSOR_OP_MATH;
+#else
+      math_type = CUBLAS_DEFAULT_MATH;
+#endif
+      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+#if CUBLAS_VER_MAJOR >= 11
+    } else if (data_type == CUDA_R_32F) {
+      // DoBlassInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
+      // if TF32 is disabled.
+      math_type = CUBLAS_TF32_TENSOR_OP_MATH;
+      algo = tensorflow::tf32_execution_allowed() ? CUBLAS_GEMM_DFALT_TENSOR_OP
+                                                  : CUBLAS_GEMM_DFALT;
+#endif
+    } else {
+      math_type = CUBLAS_DEFAULT_MATH;
+      algo = CUBLAS_GEMM_DFALT;
+    }
     cudaDataType_t compute_type =
         (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
     const void **a_void_ptrs = reinterpret_cast<const void **>(
@@ -2284,7 +2314,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
     bool ok;
     ok = DoBlasInternalImpl(
         AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
-        true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+        true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
        CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
        b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
        batch_count, compute_type, algo);
@@ -2419,33 +2449,30 @@ bool CUDABlas::DoBlasGemmStridedBatched(
     int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
     int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
     int64 stride_c, int batch_count) {
-  bool use_tensor_ops = false;
-#if CUDA_VERSION >= 9000
+#if CUDA_VERSION >= 9010
   int cc_major, cc_minor;
   if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
-          &cc_major, &cc_minor)) {
-    // GPUs < sm_70 don't support tensor ops.
-    if (cc_major >= 7 && TensorOpMathEnabled()) {
-      use_tensor_ops = true;
-    }
-#if CUDA_VERSION >= 9010
-    if (cc_major >= 5) {
-      cublasGemmAlgo_t algo =
-          (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
-      bool ok = DoBlasInternalImpl(
-          AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
-          true /* = pointer_mode_host */, true /* = err_on_failure */,
-          use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
-          m, n, k, &alpha, GpuMemory(a), CUDA_R_16F, lda, stride_a,
-          GpuMemory(b), CUDA_R_16F, ldb, stride_b, &beta, GpuMemoryMutable(c),
-          CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
-      if (ok) {
-        return true;
-      }
-      LOG(ERROR) << "failed BLAS call, see log for details";
-      return false;
-    }
+          &cc_major, &cc_minor) &&
+      cc_major >= 5) {
+    cublasGemmAlgo_t algo =
+        (cc_major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+#if CUDA_VERSION < 11000
+    cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
+#else
+    cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
 #endif
+    bool ok = DoBlasInternalImpl(
+        AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
+        true /* = pointer_mode_host */, true /* = err_on_failure */, math_type,
+        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+        GpuMemory(a), CUDA_R_16F, lda, stride_a, GpuMemory(b), CUDA_R_16F, ldb,
+        stride_b, &beta, GpuMemoryMutable(c), CUDA_R_16F, ldc, stride_c,
+        batch_count, CUDA_R_32F, algo);
+    if (ok) {
+      return true;
+    }
+    LOG(ERROR) << "failed BLAS call, see log for details";
+    return false;
   }
 #endif
   // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
@@ -2458,10 +2485,10 @@ bool CUDABlas::DoBlasGemmStridedBatched(
         reinterpret_cast<__half *>(GpuMemoryMutable(c) + batch * stride_c);
     bool ok = DoBlasInternalImpl(
         cublasSgemmEx, stream, true /* = pointer_mode_host */,
-        true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
-        CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
-        lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
-        SE_CUDA_DATA_HALF, ldc);
+        true /* = err_on_failure= */, CUBLAS_DEFAULT_MATH,
+        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+        a_matrix, SE_CUDA_DATA_HALF, lda, b_matrix, SE_CUDA_DATA_HALF, ldb,
+        &beta, c_matrix, SE_CUDA_DATA_HALF, ldc);
     if (!ok) {
       LOG(ERROR) << "failed BLAS call, see log for details";
       return false;
@@ -2476,11 +2503,17 @@ bool CUDABlas::DoBlasGemmStridedBatched(
     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
     int batch_count) {
-  return DoBlasInternal(
+#if CUDA_VERSION < 11000
+  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
+#else
+  cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
+#endif
+  return DoBlasInternalImpl(
       cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
-      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
-      GpuMemory(a), lda, stride_a, GpuMemory(b), ldb, stride_b, &beta,
-      GpuMemoryMutable(c), ldc, stride_c, batch_count);
+      true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
+      CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda, stride_a,
+      GpuMemory(b), ldb, stride_b, &beta, GpuMemoryMutable(c), ldc, stride_c,
+      batch_count);
 }
 
 bool CUDABlas::DoBlasGemmStridedBatched(
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 817bdb72777..9ff63102aaa 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -21,6 +21,7 @@ limitations under the License.
 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
 
 #include "absl/synchronization/mutex.h"
+#include "third_party/gpus/cuda/include/cublas_v2.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/blas.h"
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
@@ -83,26 +84,17 @@ class CUDABlas : public blas::BlasSupport {
   template <typename FuncT, typename... Args>
   bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                           bool pointer_mode_host, bool err_on_failure,
-                          bool use_tensor_op_math, Args... args);
+                          cublasMath_t math_type, Args... args);
 
-  // Convenience functions that call DoBlasInternalImpl with different values
-  // for err_on_failure.
+  // Convenience functions that call DoBlasInternalImpl with err_on_failure=true
+  // and math_type=CUBLAS_DEFAULT_MATH.
   template <typename FuncT, typename... Args>
   bool DoBlasInternal(FuncT cublas_func, Stream *stream,
                       bool pointer_mode_host, Args... args) {
     return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
-                              /*err_on_failure=*/true, /*use_tensor_ops=*/false,
+                              /*err_on_failure=*/true, CUBLAS_DEFAULT_MATH,
                               args...);
   }
-  template <typename FuncT, typename... Args>
-  bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
-                               bool pointer_mode_host, Args... args) {
-    // Tensor ops are hard-coded off in this path, but can still be enabled with
-    // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
-    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
-                              /*err_on_failure=*/false,
-                              /*use_tensor_ops=*/false, args...);
-  }
 
   // A helper function to implement DoBlasGemmBatched interfaces for generic
   // types.
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index be18c989861..a97850bd8d5 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/tf32_utils.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@@ -601,31 +602,6 @@ class CudnnFilterDescriptor {
   SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
 };
 
-// A helper function to decide whether to enable the TENSOR_OP_MATH math type
-bool TensorOpMathEnabled() {
-  static bool is_enabled = [] {
-    bool is_disabled = false;
-    TF_CHECK_OK(
-        tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
-                                       /*default_val=*/false, &is_disabled));
-    return !is_disabled;
-  }();
-  return is_enabled;
-}
-
-// A helper function to decide whether to enable the TENSOR_OP_MATH math type
-// for RNNs.
-bool RnnTensorOpMathEnabled() {
-  static bool is_enabled = [] {
-    bool is_disabled = false;
-    TF_CHECK_OK(
-        tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH",
-                                       /*default_val=*/false, &is_disabled));
-    return !is_disabled;
-  }();
-  return is_enabled;
-}
-
 // A helper function to decide whether to use
 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
 // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
@@ -730,10 +706,6 @@ class CudnnConvolutionDescriptor {
                                        : CUDNN_CROSS_CORRELATION,
         data_type));
 
-    // NOTE(benbarsdell): This only applies if tensor op math is enabled
-    //                    and algo selection is set to Default.
-    this->set_use_tensor_op_math(true);
-
 #if CUDNN_MAJOR >= 7
     VLOG(2) << "Requesting grouped convolution: "
             << convolution_descriptor.group_count();
@@ -745,13 +717,15 @@ class CudnnConvolutionDescriptor {
 #endif
   }
 
-  void set_use_tensor_op_math(bool use_tensor_op_math) const {
+  void set_use_tensor_op_math(bool use_tensor_op_math) {
 #if CUDNN_VERSION >= 7000
     cudnnMathType_t math_type =
+#if CUDNN_VERSION >= 8000
+        (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
+#else
         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
-    if (TensorOpMathEnabled()) {
-      CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
-    }
+#endif
+    CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
 #endif
   }
 
@@ -763,6 +737,40 @@ class CudnnConvolutionDescriptor {
   SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
 };
 
+// A helper function to query if a CudnnConvolutionDescriptor has tensor_op_math
+// set
+static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
+  cudnnMathType_t math_type;
+  CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
+#if CUDNN_VERSION >= 8000
+  return math_type != CUDNN_FMA_MATH;
+#else
+  return math_type == CUDNN_TENSOR_OP_MATH;
+#endif
+}
+
+static bool TensorOpMathAvailable(int cc_major) {
+  return cc_major >= 7 && CUDNN_VERSION >= 7000;
+}
+
+static bool IsTensorMathAllowed(Stream* stream, dnn::DataType input_type) {
+  int cc_major, cc_minor;
+  std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
+  if (!TensorOpMathAvailable(cc_major)) {
+    return false;
+  }
+  if (input_type == dnn::DataType::kFloat) {
+#if CUDNN_VERSION < 8000
+    return false;
+#else
+    if (!tensorflow::tf32_execution_allowed()) {
+      return false;
+    }
+#endif
+  }
+  return true;
+}
+
 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
 // within a scope.
 class CudnnPoolingDescriptor {
@@ -1155,21 +1163,27 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
     // in profile mode, which is run with algorithms returned from
     // GetRnnAlgorithms() (which are non-default and explicitly set whether to
    // use tensor ops). CuDNN 7.2.1 fixed this issue
-    if (RnnTensorOpMathEnabled()) {
-      cudnnMathType_t math_type;
-      if (algorithm_config.algorithm().has_value()) {
-        math_type = algorithm_config.algorithm()->tensor_ops_enabled()
-                        ? CUDNN_TENSOR_OP_MATH
-                        : CUDNN_DEFAULT_MATH;
-      } else {
-#if CUDNN_VERSION >= 7201
-        math_type = CUDNN_TENSOR_OP_MATH;
-#else
-        math_type = CUDNN_DEFAULT_MATH;
-#endif  // CUDNN_VERSION >= 7201
-      }
-      CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
+    bool allow_tensor_ops =
+        data_type != CUDNN_DATA_FLOAT || tensorflow::tf32_execution_allowed();
+    bool use_tensor_ops;
+    if (algorithm_config.algorithm().has_value()) {
+      use_tensor_ops = algorithm_config.algorithm()->tensor_ops_enabled();
+    } else {
+      use_tensor_ops = CUDNN_VERSION >= 7201 && allow_tensor_ops;
     }
+
+    if (use_tensor_ops && !allow_tensor_ops) {
+      return port::Status(port::error::INVALID_ARGUMENT,
+                          "Algo requests disallowed tensor op evaluation.");
+    }
+
+    cudnnMathType_t math_type;
+#if CUDNN_VERSION >= 8000
+    math_type = use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
+#else
+    math_type = use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
+#endif
+    CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
 #endif  // CUDNN_VERSION >= 7000
 
     return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
@@ -2560,10 +2574,11 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
     const CudnnTensorDescriptor& output_nd,
     const dnn::AlgorithmDesc& algorithm_desc,
     ScratchAllocator* scratch_allocator) {
-  // TODO(csigg): This has side effects on the convolution descriptor. It is
-  // functionally correct because the convolution is run with the algorithm of
-  // the last call to this function, but should be fixed anyway.
-  conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+  if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
+    return port::Status(
+        port::error::INTERNAL,
+        "Mismatch between cudnn conv and algorithm descriptors.");
+  }
 
   // Query the size of the workspace and allocate it.
   size_t size_in_bytes;
@@ -2603,10 +2618,11 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
     const CudnnTensorDescriptor& output_nd,
     const dnn::AlgorithmDesc& algorithm_desc,
     ScratchAllocator* scratch_allocator) {
-  // TODO(csigg): This has side effects on the convolution descriptor. It is
-  // functionally correct because the convolution is run with the algorithm of
-  // the last call to this function, but should be fixed anyway.
-  conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+  if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
+    return port::Status(
+        port::error::INTERNAL,
+        "Mismatch between cudnn conv and algorithm descriptors.");
+  }
 
   // Query the size of the workspace and allocate it.
   size_t size_in_bytes;
@@ -2648,10 +2664,11 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
     const CudnnTensorDescriptor& output_nd,
     const dnn::AlgorithmDesc& algorithm_desc,
     ScratchAllocator* scratch_allocator) {
-  // TODO(csigg): This has side effects on the convolution descriptor. It is
-  // functionally correct because the convolution is run with the algorithm of
-  // the last call to this function, but should be fixed anyway.
-  conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+  if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
+    return port::Status(
+        port::error::INTERNAL,
+        "Mismatch between cudnn conv and algorithm descriptors.");
+  }
 
   // Query the size of the workspace and allocate it.
   size_t size_in_bytes;
@@ -2685,18 +2702,42 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
   return scratch_allocator->AllocateBytes(size_in_bytes);
 }
 
-static bool TensorOpMathAvailable(int cc_major) {
-  return cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled();
+port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
+                                  absl::optional<dnn::AlgorithmDesc> desc) {
+  bool use_tensor_ops;
+  if (desc.has_value()) {
+    use_tensor_ops = desc->tensor_ops_enabled();
+    if (use_tensor_ops && !IsTensorMathAllowed(stream, type)) {
+      return port::Status(port::error::INVALID_ARGUMENT,
+                          "Algo requests disallowed tensor op evaluation.");
+    }
+  } else {
+    use_tensor_ops = IsTensorMathAllowed(stream, type);
+  }
+  return use_tensor_ops;
 }
 
+cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
+dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);
+
 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
     Stream* stream, const CudnnHandle& cudnn,
     const dnn::AlgorithmConfig& algorithm_config,
     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
-    const CudnnConvolutionDescriptor& conv,
+    dnn::DataType element_type,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
     DeviceMemory<uint8>* scratch) {
   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
+
+  CudnnConvolutionDescriptor conv(
+      convolution_descriptor,
+      ToCudnnDataType(GetConvAccumulatorType(element_type)));
+  bool use_tensor_ops;
+  SE_ASSIGN_OR_RETURN(use_tensor_ops,
+                      UseTensorOps(stream, element_type, algo_desc));
+  conv.set_use_tensor_op_math(use_tensor_ops);
+
   if (!algo_desc.has_value()) {
     // Pick fastest algorithm within memory limit according to cuDNN's
     // heuristics.
@@ -2709,10 +2750,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
         GetCudnnConvolutionForwardAlgo(
             cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit,
             memory_limit_bytes));
-    int cc_major, cc_minor;
-    std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
-    algo_desc = dnn::AlgorithmDesc(
-        algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
+    algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
   }
 
   const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
@@ -2736,6 +2774,9 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
         "Returned status: ", scratch_or.status().ToString()));
   }
 
+  SE_ASSIGN_OR_RETURN(use_tensor_ops,
+                      UseTensorOps(stream, element_type, algo_desc));
+  conv.set_use_tensor_op_math(use_tensor_ops);
   SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
                                     stream, cudnn, input_nd, filter, conv,
                                     output_nd, *algo_desc, scratch_allocator));
@@ -2746,10 +2787,19 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
     Stream* stream, const CudnnHandle& cudnn,
     const dnn::AlgorithmConfig& algorithm_config,
     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
-    const CudnnConvolutionDescriptor& conv,
+    dnn::DataType element_type,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
     DeviceMemory<uint8>* scratch) {
   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
+  CudnnConvolutionDescriptor conv(
+      convolution_descriptor,
+      ToCudnnDataType(GetConvAccumulatorType(element_type)));
+  bool use_tensor_ops;
+  SE_ASSIGN_OR_RETURN(use_tensor_ops,
+                      UseTensorOps(stream, element_type, algo_desc));
+  conv.set_use_tensor_op_math(use_tensor_ops);
+
   if (!algo_desc.has_value()) {
     // Pick fastest algorithm within memory limit according to cuDNN's
     // heuristics.
@@ -2762,10 +2812,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
         GetCudnnConvolutionBackwardDataAlgo(
             cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit,
             memory_limit_bytes));
-    int cc_major, cc_minor;
-    std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
-    algo_desc = dnn::AlgorithmDesc(
-        algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
+    algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
   }
 
   const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
@@ -2788,6 +2835,9 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
         "while a secondary algorithm is not provided.");
   }
 
+  SE_ASSIGN_OR_RETURN(use_tensor_ops,
+                      UseTensorOps(stream, element_type, algo_desc));
+  conv.set_use_tensor_op_math(use_tensor_ops);
   SE_ASSIGN_OR_RETURN(*scratch,
                       AllocateCudnnConvolutionBackwardDataWorkspace(
                           stream, cudnn, input_nd, filter, conv, output_nd,
                           *algo_desc, scratch_allocator));
@@ -2798,10 +2848,19 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
     Stream* stream, const CudnnHandle& cudnn,
     const dnn::AlgorithmConfig& algorithm_config,
     const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
-    const CudnnConvolutionDescriptor& conv,
+    dnn::DataType element_type,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
     const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
     DeviceMemory<uint8>* scratch) {
   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
+  CudnnConvolutionDescriptor conv(
+      convolution_descriptor,
+      ToCudnnDataType(GetConvAccumulatorType(element_type)));
+  bool use_tensor_ops;
+  SE_ASSIGN_OR_RETURN(use_tensor_ops,
+                      UseTensorOps(stream, element_type, algo_desc));
+  conv.set_use_tensor_op_math(use_tensor_ops);
+
   if (!algo_desc.has_value()) {
     // Pick fastest algorithm within memory limit according to cuDNN's
     // heuristics.
@@ -2814,10 +2873,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
         GetCudnnConvolutionBackwardFilterAlgo(
             cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit,
             memory_limit_bytes));
-    int cc_major, cc_minor;
-    std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
-    algo_desc = dnn::AlgorithmDesc(
-        algo, /*use_tensor_ops=*/TensorOpMathAvailable(cc_major));
+    algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
   }
 
   auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
@@ -2840,6 +2896,9 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
         "while a secondary algorithm is not provided.");
   }
 
+  SE_ASSIGN_OR_RETURN(use_tensor_ops,
+                      UseTensorOps(stream, element_type, algo_desc));
+  conv.set_use_tensor_op_math(use_tensor_ops);
   SE_ASSIGN_OR_RETURN(*scratch,
                       AllocateCudnnConvolutionBackwardFilterWorkspace(
                           stream, cudnn, input_nd, filter, conv, output_nd,
                           *algo_desc, scratch_allocator));
@@ -3004,35 +3063,32 @@ port::Status CudnnSupport::DoPrepareForConvolution(
   CudnnTensorDescriptor output_nd(
       output_descriptor,
       ToCudnnDataType(element_type, output_descriptor.layout()));
-  CudnnConvolutionDescriptor conv(
-      convolution_descriptor,
-      ToCudnnDataType(GetConvAccumulatorType(element_type)));
 
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
   switch (kind) {
     case dnn::ConvolutionKind::FORWARD: {
-      SE_ASSIGN_OR_RETURN(
-          *algorithm_desc,
-          GetCudnnConvolutionForwardAlgorithm(
-              stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
-              output_nd, scratch_allocator, scratch_memory));
+      SE_ASSIGN_OR_RETURN(*algorithm_desc,
+                          GetCudnnConvolutionForwardAlgorithm(
+                              stream, cudnn, algorithm_config, input_nd,
+                              filter_nd, element_type, convolution_descriptor,
+                              output_nd, scratch_allocator, scratch_memory));
       break;
     }
     case dnn::ConvolutionKind::BACKWARD_DATA: {
-      SE_ASSIGN_OR_RETURN(
-          *algorithm_desc,
-          GetCudnnConvolutionBackwardDataAlgorithm(
-              stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
-              output_nd, scratch_allocator, scratch_memory));
+      SE_ASSIGN_OR_RETURN(*algorithm_desc,
+                          GetCudnnConvolutionBackwardDataAlgorithm(
+                              stream, cudnn, algorithm_config, input_nd,
+                              filter_nd, element_type, convolution_descriptor,
+                              output_nd, scratch_allocator, scratch_memory));
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_FILTER: {
-      SE_ASSIGN_OR_RETURN(
-          *algorithm_desc,
-          GetCudnnConvolutionBackwardFilterAlgorithm(
-              stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
-              output_nd, scratch_allocator, scratch_memory));
+      SE_ASSIGN_OR_RETURN(*algorithm_desc,
+                          GetCudnnConvolutionBackwardFilterAlgorithm(
+                              stream, cudnn, algorithm_config, input_nd,
+                              filter_nd, element_type, convolution_descriptor,
+                              output_nd, scratch_allocator, scratch_memory));
      break;
    }
    default:
@@ -3061,8 +3117,9 @@ port::Status CudnnSupport::DoConvolve(
   auto accumulator_type = GetConvAccumulatorType(element_type);
   CudnnConvolutionDescriptor conv(convolution_descriptor,
                                   ToCudnnDataType(accumulator_type));
-  // Set use_tensor_math param to correct value
-  conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+  SE_ASSIGN_OR_RETURN(bool use_tensor_ops,
+                      UseTensorOps(stream, element_type, algorithm_desc));
+  conv.set_use_tensor_op_math(use_tensor_ops);
 
   auto cudnn = cudnn_->GetHandle(parent_, stream);
   // Alpha is the scaling factor for input.
@@ -3295,14 +3352,6 @@ port::Status CudnnSupport::DoConvolve(
   return port::Status::OK();
 }
 
-// A helper function to query if a CudnnConvolutionDescriptor has tensor_op_math
-// set
-static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
-  cudnnMathType_t math_type;
-  CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
-  return math_type == CUDNN_TENSOR_OP_MATH;
-}
-
 template <typename ElementType, typename BiasType, typename ScaleType,
           typename OutputType>
 port::Status CudnnSupport::DoFusedConvolveImpl(
@@ -3336,8 +3385,6 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
       filter_descriptor, GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
   CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
-  CudnnConvolutionDescriptor conv(convolution_descriptor,
-                                  ToCudnnDataType(accumulator_type));
 
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
@@ -3347,9 +3394,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
   SE_ASSIGN_OR_RETURN(
       dnn::AlgorithmDesc algo_desc,
       GetCudnnConvolutionForwardAlgorithm(
-          stream, cudnn, algorithm_config, conv_input_nd, filter, conv,
+          stream, cudnn, algorithm_config, conv_input_nd, filter,
+          dnn::ToDataType<ElementType>::value, convolution_descriptor,
           output_nd, scratch_allocator, &scratch));
 
+  CudnnConvolutionDescriptor conv(convolution_descriptor,
+                                  ToCudnnDataType(accumulator_type));
+  conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
+
   std::unique_ptr<GpuTimer> timer;
   if (is_profiling) {
     timer.reset(new GpuTimer(parent_));  // NOLINT
@@ -3480,9 +3532,7 @@ bool CudnnSupport::GetRnnAlgorithms(
   for (auto i : algo_types) {
     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
 #if CUDNN_VERSION >= 7100
-    if (RnnTensorOpMathEnabled()) {
-      out_algorithms->push_back({i, /*use_tensor_ops=*/true});
-    }
+    out_algorithms->push_back({i, /*use_tensor_ops=*/true});
 #endif
   }
   return true;
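An illustrative end-to-end sketch of the intended effect, assuming an Ampere (sm_80+) GPU and a CUDA 11 / cuDNN 8 build that includes this patch: once TF32 execution is allowed, float32 matmuls and convolutions become eligible for the CUBLAS_TF32_TENSOR_OP_MATH and cuDNN tensor-op paths wired up above, while inputs and outputs stay float32.

    import tensorflow as tf
    from tensorflow.python.framework import config

    config.allow_tensor_float_32_execution(True)

    a = tf.random.normal([1024, 1024])
    b = tf.random.normal([1024, 1024])
    c = tf.matmul(a, b)  # may run on TF32 tensor cores on sm_80+
    print(c.dtype)       # float32; only the internal math mode changes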