Use TensorFloat-32 instead of TF32 in comments and symbols.
The name TensorFloat-32 is clearer than TF32, as the "TF" in TF32 may be confused for TensorFlow. Also replace the word "allow" with "enable" in internal functions like "allow_tf32_execution" to match the Python API. PiperOrigin-RevId: 330621742 Change-Id: I3f5518aecf4f946d241b422e6a4d566fa199185f
This commit is contained in:
parent
097581d6a6
commit
443caf8a5b
tensorflow
compiler
core
common_runtime/gpu
kernels
platform
python
stream_executor/cuda
tools/def_file_filter
@ -33,7 +33,8 @@ from tensorflow.python.platform import test
|
||||
|
||||
@test_util.run_all_without_tensor_float_32(
|
||||
"XLA QR op calls matmul. Also, matmul used for verification. Also with "
|
||||
'TF32, mysterious "Unable to launch cuBLAS gemm" error occasionally occurs')
|
||||
'TensorFloat-32, mysterious "Unable to launch cuBLAS gemm" error '
|
||||
"occasionally occurs")
|
||||
# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error
|
||||
class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -305,7 +305,7 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/tests:test_macros_header",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -27,14 +27,15 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using QrTest = xla::ClientLibraryTestBase;
|
||||
|
||||
XLA_TEST_F(QrTest, Simple) {
|
||||
tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tensorflow::enable_tensor_float_32_execution(false);
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
xla::Array2D<float> a_vals({
|
||||
@ -63,7 +64,8 @@ XLA_TEST_F(QrTest, Simple) {
|
||||
}
|
||||
|
||||
XLA_TEST_F(QrTest, ZeroDiagonal) {
|
||||
tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tensorflow::enable_tensor_float_32_execution(false);
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
xla::Array2D<float> a_vals({
|
||||
@ -91,7 +93,8 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
|
||||
}
|
||||
|
||||
XLA_TEST_F(QrTest, SimpleBatched) {
|
||||
tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tensorflow::enable_tensor_float_32_execution(false);
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
xla::Array3D<float> a_vals({
|
||||
|
@ -2694,6 +2694,6 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
],
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
@ -182,7 +182,8 @@ class RandomCholeskyTest
|
||||
public ::testing::WithParamInterface<CholeskyTestCase> {};
|
||||
|
||||
XLA_TEST_P(RandomCholeskyTest, Random) {
|
||||
tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tensorflow::enable_tensor_float_32_execution(false);
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
auto test_params = GetParam();
|
||||
|
@ -159,7 +159,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
||||
"//tensorflow/core/profiler/lib:scoped_annotation",
|
||||
"//third_party/eigen3",
|
||||
|
@ -1672,7 +1672,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels/image",
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
],
|
||||
)
|
||||
|
@ -31,9 +31,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
@ -1045,7 +1045,8 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ExplicitPaddingConvolution) {
|
||||
#endif
|
||||
|
||||
TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) {
|
||||
tensorflow::allow_tf32_execution(false); // Requires full precision Conv2D op
|
||||
// Requires full precision Conv2D op
|
||||
tensorflow::enable_tensor_float_32_execution(false);
|
||||
const int filter_size = 1;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
|
@ -978,9 +978,9 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf32_utils",
|
||||
srcs = ["tf32_utils.cc"],
|
||||
hdrs = ["tf32_utils.h"],
|
||||
name = "tensor_float_32_utils",
|
||||
srcs = ["tensor_float_32_utils.cc"],
|
||||
hdrs = ["tensor_float_32_utils.h"],
|
||||
copts = tf_copts(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -1006,8 +1006,8 @@ cc_library(
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tf32_hdr",
|
||||
srcs = ["tf32_utils.h"],
|
||||
name = "tensor_float_32_hdr",
|
||||
srcs = ["tensor_float_32_utils.h"],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
|
@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Whether TensorFloat-32 should be used where supported.
|
||||
// TODO(reedwm): Change word "allow" to "enable" in all TensorFloat-32 functions
|
||||
static std::atomic<bool> tf32_allowed{true};
|
||||
static std::atomic<bool> tensor_float_32_enabled{true};
|
||||
|
||||
void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; }
|
||||
void enable_tensor_float_32_execution(bool enabled) {
|
||||
tensor_float_32_enabled = enabled;
|
||||
}
|
||||
|
||||
bool tf32_execution_allowed() { return tf32_allowed; }
|
||||
bool tensor_float_32_execution_enabled() { return tensor_float_32_enabled; }
|
||||
|
||||
} // namespace tensorflow
|
@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_TENSOR_FLOAT_32_UTILS_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_TENSOR_FLOAT_32_UTILS_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void allow_tf32_execution(bool allowed);
|
||||
void enable_tensor_float_32_execution(bool enabled);
|
||||
|
||||
bool tf32_execution_allowed();
|
||||
bool tensor_float_32_execution_enabled();
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_TF32_UTILS_H_
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_TENSOR_FLOAT_32_UTILS_H_
|
@ -795,10 +795,10 @@ tf_python_pybind_extension(
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tf32_execution",
|
||||
srcs = ["util/tf32.cc"],
|
||||
hdrs = ["//tensorflow/core/platform:tf32_hdr"],
|
||||
module_name = "_pywrap_tf32_execution",
|
||||
name = "_pywrap_tensor_float_32_execution",
|
||||
srcs = ["util/tensor_float_32.cc"],
|
||||
hdrs = ["//tensorflow/core/platform:tensor_float_32_hdr"],
|
||||
module_name = "_pywrap_tensor_float_32_execution",
|
||||
deps = [
|
||||
"@pybind11",
|
||||
],
|
||||
@ -5556,7 +5556,7 @@ py_library(
|
||||
"//tensorflow:composite_tensor_whitelist",
|
||||
],
|
||||
deps = [
|
||||
":_pywrap_tf32_execution",
|
||||
":_pywrap_tensor_float_32_execution",
|
||||
# global_test_configuration is added here because all major tests depend on this
|
||||
# library. It isn't possible to add these test dependencies via tensorflow.bzl's
|
||||
# py_test because not all tensorflow tests use tensorflow.bzl's py_test.
|
||||
@ -5972,7 +5972,7 @@ pywrap_tensorflow_macro(
|
||||
"@ngraph_tf//:ngraph_tf",
|
||||
]) + if_xla_available([
|
||||
"//tensorflow/compiler/aot:tfcompile_lib",
|
||||
]) + if_static(extra_deps = ["//tensorflow/core/platform:tf32_utils"]),
|
||||
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]),
|
||||
)
|
||||
|
||||
# ** Targets for Windows build (start) **
|
||||
@ -6025,7 +6025,7 @@ filegroup(
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer", # tf_optimizer
|
||||
"//tensorflow/core/grappler/utils:topological_sort", # tf_item
|
||||
"//tensorflow/core/platform:tf32_utils", # tf32
|
||||
"//tensorflow/core/platform:tensor_float_32_utils", # tensor_float_32
|
||||
"//tensorflow/core/profiler/internal:annotation_stack_impl", # profiler
|
||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||
"//tensorflow/core/profiler/internal:traceme_recorder_impl", # profiler
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import _pywrap_tf32_execution
|
||||
from tensorflow.python import _pywrap_tensor_float_32_execution
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -34,7 +34,7 @@ def tensor_float_32_execution_enabled():
|
||||
Returns:
|
||||
True if TensorFloat-32 is enabled (the default) and False otherwise
|
||||
"""
|
||||
return _pywrap_tf32_execution.is_allowed()
|
||||
return _pywrap_tensor_float_32_execution.is_enabled()
|
||||
|
||||
|
||||
@tf_export('config.experimental.enable_tensor_float_32_execution')
|
||||
@ -90,7 +90,7 @@ def enable_tensor_float_32_execution(enabled):
|
||||
Args:
|
||||
enabled: Bool indicating whether to enable TensorFloat-32 execution.
|
||||
"""
|
||||
_pywrap_tf32_execution.allow(enabled)
|
||||
_pywrap_tensor_float_32_execution.enable(enabled)
|
||||
|
||||
|
||||
@tf_export('config.threading.get_intra_op_parallelism_threads')
|
||||
|
@ -760,17 +760,18 @@ class TensorFloat32Test(test.TestCase):
|
||||
super(TensorFloat32Test, self).tearDown()
|
||||
config.enable_tensor_float_32_execution(True)
|
||||
|
||||
def test_tf32_enabled(self):
|
||||
def test_tensor_float_32_enabled(self):
|
||||
self.assertTrue(config.tensor_float_32_execution_enabled())
|
||||
|
||||
x = array_ops.fill((8, 8), 1 + 2**-20)
|
||||
y = array_ops.ones((8, 8))
|
||||
out = math_ops.matmul(x, y)
|
||||
# In tf32, each element of x is rounded to 1, so the output will be 8s.
|
||||
# In TensorFloat-32, each element of x is rounded to 1, so the output will
|
||||
# be 8s.
|
||||
expected = array_ops.fill((8, 8), 8)
|
||||
self.assertAllEqual(out, expected)
|
||||
|
||||
def test_tf32_disabled(self):
|
||||
def test_tensor_float_32_disabled(self):
|
||||
self.assertTrue(config.tensor_float_32_execution_enabled())
|
||||
config.enable_tensor_float_32_execution(False)
|
||||
self.assertFalse(config.tensor_float_32_execution_enabled())
|
||||
|
@ -29,7 +29,8 @@ from tensorflow.python.platform import test
|
||||
|
||||
@testing_utils.run_all_without_tensor_float_32(
|
||||
'Uses Dense layers, which call matmul. Even if Dense layers run in '
|
||||
'float64, the test sometimes fails with tf32 enabled for unknown reasons')
|
||||
'float64, the test sometimes fails with TensorFloat-32 enabled for unknown '
|
||||
'reasons')
|
||||
class DistributionStrategyCnnCorrectnessTest(
|
||||
keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
|
||||
|
||||
|
@ -50,8 +50,8 @@ def GetTestConfigs():
|
||||
|
||||
@test_util.run_all_without_tensor_float_32(
|
||||
"Tests Conv3d, which in some cases is implemented with a matmul. With "
|
||||
"tf32, tests fail in some of those cases (and as of August 13 2020, only "
|
||||
"those cases)")
|
||||
"TensorFloat-32, tests fail in some of those cases (and as of August 13 "
|
||||
"2020, only those cases)")
|
||||
class Conv3DTest(test.TestCase):
|
||||
|
||||
def _DtypesToTest(self, use_gpu):
|
||||
|
@ -108,13 +108,13 @@ def _GetMatrixBinaryFunctorGradientTest(functor_,
|
||||
@test_util.run_without_tensor_float_32(
|
||||
'Tests `tf.linalg.lstsq`, which call matmul. Additionally, calls ops '
|
||||
'which do matmul in their gradient, such as MatrixSolveLs.')
|
||||
# TODO(b/164254522): With tf32, some tests fails with extremely high absolute
|
||||
# and relative differences when calling assertAllClose. For example, the test
|
||||
# test_MatrixSolveLsGradient_float32_10_10_1e-06 of class
|
||||
# TODO(b/164254522): With TensorFloat-32, some tests fail with extremely high
|
||||
# absolute and relative differences when calling assertAllClose. For example,
|
||||
# the test test_MatrixSolveLsGradient_float32_10_10_1e-06 of class
|
||||
# MatrixBinaryFunctorGradientTest fails with a max absolute difference of
|
||||
# 0.883 and a max relative difference of 736892. We should consider disabling
|
||||
# tf32 within `tf.linalg.lstsq and perhaps other linear algebra functions,
|
||||
# even if tf32 is allowed globally.
|
||||
# TensorFloat-32 within `tf.linalg.lstsq` and perhaps other linear algebra
|
||||
# functions, even if TensorFloat-32 is allowed globally.
|
||||
def Test(self):
|
||||
|
||||
def RandomInput():
|
||||
|
@ -14,9 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
|
||||
PYBIND11_MODULE(_pywrap_tf32_execution, m) {
|
||||
m.def("allow", &tensorflow::allow_tf32_execution);
|
||||
m.def("is_allowed", &tensorflow::tf32_execution_allowed);
|
||||
PYBIND11_MODULE(_pywrap_tensor_float_32_execution, m) {
|
||||
m.def("enable", &tensorflow::enable_tensor_float_32_execution);
|
||||
m.def("is_enabled", &tensorflow::tensor_float_32_execution_enabled);
|
||||
}
|
@ -263,7 +263,7 @@ cc_library(
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"//tensorflow/stream_executor",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:host_or_device_scalar",
|
||||
@ -370,7 +370,7 @@ cc_library(
|
||||
"@local_config_cuda//cuda:cudnn_header",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"//tensorflow/stream_executor:dnn",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:plugin_registry",
|
||||
|
@ -49,7 +49,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
|
||||
@ -400,7 +400,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
|
||||
ScopedCublasMathMode math_mode{blas_};
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
|
||||
tensorflow::tf32_execution_allowed()) {
|
||||
tensorflow::tensor_float_32_execution_enabled()) {
|
||||
#else
|
||||
if (math_type == CUBLAS_TENSOR_OP_MATH) {
|
||||
#endif
|
||||
@ -1952,7 +1952,7 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
||||
<< " uses tensor ops, but tensor ops are not available in sm"
|
||||
<< cc_major << "X devices for float input types.";
|
||||
return false;
|
||||
} else if (!tensorflow::tf32_execution_allowed()) {
|
||||
} else if (!tensorflow::tensor_float_32_execution_enabled()) {
|
||||
VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
|
||||
<< algorithm
|
||||
<< " uses tensor ops, but tensor ops are disabled for fp32"
|
||||
@ -2294,10 +2294,11 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
} else if (data_type == CUDA_R_32F) {
|
||||
// DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
|
||||
// if TF32 is disabled.
|
||||
// if TensorFloat-32 is disabled.
|
||||
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
algo = tensorflow::tf32_execution_allowed() ? CUBLAS_GEMM_DFALT_TENSOR_OP
|
||||
: CUBLAS_GEMM_DFALT;
|
||||
algo = tensorflow::tensor_float_32_execution_enabled()
|
||||
? CUBLAS_GEMM_DFALT_TENSOR_OP
|
||||
: CUBLAS_GEMM_DFALT;
|
||||
#endif
|
||||
} else {
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
|
||||
@ -740,7 +740,7 @@ static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
|
||||
|
||||
static bool TensorOpMathAvailable(int cc_major) { return cc_major >= 7; }
|
||||
|
||||
static bool IsTensorMathAllowed(Stream* stream, dnn::DataType input_type) {
|
||||
static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
|
||||
int cc_major, cc_minor;
|
||||
std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
|
||||
if (!TensorOpMathAvailable(cc_major)) {
|
||||
@ -750,7 +750,7 @@ static bool IsTensorMathAllowed(Stream* stream, dnn::DataType input_type) {
|
||||
#if CUDNN_VERSION < 8000
|
||||
return false;
|
||||
#else
|
||||
if (!tensorflow::tf32_execution_allowed()) {
|
||||
if (!tensorflow::tensor_float_32_execution_enabled()) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
@ -1099,7 +1099,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
// TODO(csigg): Minimal support cuDNN version is 7.3, clean up.
|
||||
bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
|
||||
if (data_type == CUDNN_DATA_FLOAT)
|
||||
allow_tensor_ops = tensorflow::tf32_execution_allowed();
|
||||
allow_tensor_ops = tensorflow::tensor_float_32_execution_enabled();
|
||||
bool use_tensor_ops =
|
||||
algorithm_config.algorithm().has_value()
|
||||
? algorithm_config.algorithm()->tensor_ops_enabled()
|
||||
@ -2647,12 +2647,12 @@ port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
|
||||
bool use_tensor_ops;
|
||||
if (desc.has_value()) {
|
||||
use_tensor_ops = desc->tensor_ops_enabled();
|
||||
if (use_tensor_ops && !IsTensorMathAllowed(stream, type)) {
|
||||
if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
|
||||
return port::Status(port::error::INVALID_ARGUMENT,
|
||||
"Algo requests disallowed tensor op evaluation.");
|
||||
"Algo requests disabled tensor op evaluation.");
|
||||
}
|
||||
} else {
|
||||
use_tensor_ops = IsTensorMathAllowed(stream, type);
|
||||
use_tensor_ops = IsTensorMathEnabled(stream, type);
|
||||
}
|
||||
return use_tensor_ops;
|
||||
}
|
||||
|
@ -381,6 +381,6 @@ tensorflow::IsXlaEnabled
|
||||
tensorflow::GetMlirCommonFlags
|
||||
tensorflow::GetXlaDeviceFlags
|
||||
|
||||
[tf32_utils] # tf32
|
||||
tensorflow::allow_tf32_execution
|
||||
tensorflow::tf32_execution_allowed
|
||||
[tensor_float_32_utils] # tensor_float_32
|
||||
tensorflow::enable_tensor_float_32_execution
|
||||
tensorflow::tensor_float_32_execution_enabled
|
||||
|
Loading…
Reference in New Issue
Block a user