diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 4bd2dfd9244..41877d39381 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -61,7 +60,7 @@ class CholeskyOpTest(xla_test.XLATestCase): dtypes.as_dtype(x.dtype), shape=x.shape) with self.test_scope(): chol = linalg_ops.cholesky(placeholder) - verification = math_ops.matmul(chol, chol, adjoint_b=True) + verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True) self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) def testBasic(self): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 9d278cfbb28..08aad66abe1 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -65,7 +64,8 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): with self.test_scope(): x = linalg_ops.matrix_triangular_solve( placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) - verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) + verification = test_util.matmul_without_tf32( + placeholder_ca, x, adjoint_a=adjoint) self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol) diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 5fcf254db82..b2d5db8a3a8 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -24,12 +24,17 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32( + "XLA QR op calls matmul. Also, matmul used for verification. Also with " + 'TF32, mysterious "Unable to launch cuBLAS gemm" error occasionally occurs') +# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): def AdjustedNorm(self, x): @@ -73,7 +78,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): with self.session() as sess: x_tf = array_ops.placeholder(dtype) - with self.test_scope(): + with self.device_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 8c31629c234..de97c6ff210 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -237,8 +237,8 @@ class XLATestCase(test.TestCase): 'test_session not supported on XLATestCase, please use session') @contextlib.contextmanager - def test_scope(self): - """Test scope that runs tests on `self.device`. + def device_scope(self): + """Scope that runs tests on `self.device`. Yields: A scope to apply to the operators under test. @@ -246,6 +246,15 @@ class XLATestCase(test.TestCase): with ops.device('device:{}:0'.format(self.device)): yield + def test_scope(self): + """Deprecated alias of `device_scope`. + + This should be avoided as the name starts with `test`, so test runners + treat it as a test. This interferes with class decorators that operate on + each test method. + """ + return self.device_scope() + def Benchmark(tf_bench, builder_fn, diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d2f174eadb5..f0b4e5e6c79 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -305,6 +305,7 @@ xla_test( "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tf32_utils", ], ) diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc index a61f243e126..9752f844dfd 100644 --- a/tensorflow/compiler/xla/client/lib/qr_test.cc +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -27,12 +27,14 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tf32_utils.h" namespace { using QrTest = xla::ClientLibraryTestBase; XLA_TEST_F(QrTest, Simple) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed xla::XlaBuilder builder(TestName()); xla::Array2D<float> a_vals({ @@ -61,6 +63,7 @@ XLA_TEST_F(QrTest, Simple) { } XLA_TEST_F(QrTest, ZeroDiagonal) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed xla::XlaBuilder builder(TestName()); xla::Array2D<float> a_vals({ @@ -88,6 +91,7 @@ XLA_TEST_F(QrTest, ZeroDiagonal) { } XLA_TEST_F(QrTest, SimpleBatched) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed xla::XlaBuilder builder(TestName()); xla::Array3D<float> a_vals({ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 17444c042e7..734d2ed443c 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2699,5 +2699,6 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "//tensorflow/core/platform:tf32_utils", ], ) diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc index e7f5ca5ed8e..9a86852ce5c 100644 --- a/tensorflow/compiler/xla/tests/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/tf32_utils.h" namespace xla { namespace { @@ -181,6 +182,7 @@ class RandomCholeskyTest public ::testing::WithParamInterface<CholeskyTestCase> {}; XLA_TEST_P(RandomCholeskyTest, Random) { + tensorflow::allow_tf32_execution(false); // Test fails with tf32 allowed XlaBuilder builder(TestName()); auto test_params = GetParam(); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0981bf8d65b..649f9a593b3 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1671,6 +1671,7 @@ tf_cuda_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/kernels/image", + "//tensorflow/core/platform:tf32_utils", "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index 3e192b83c57..fba71d5da33 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/tf32_utils.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/public/session.h" @@ -1038,6 +1039,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ExplicitPaddingConvolution) { #endif TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) { + tensorflow::allow_tf32_execution(false); // Requires full precision Conv2D op const int filter_size = 1; const int filter_count = 12; for (const string& activation : {"Relu", "Relu6", "Elu"}) { diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 596b93227bf..298d41a995c 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1592,6 +1592,7 @@ class UnrollLSTMTest(test.TestCase): self.assertAllClose(mv0, mv2, rtol=1e-4) self.assertAllClose(mv0, mv3, rtol=1e-4) + @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function") def testUnrollLSTMGrad(self): # Run one step of the unrolled lstm graph. def RunForwardBackward(mode, cfg=None): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4d7b7746b9c..874ecee2ba5 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -54,6 +54,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import tape +from tensorflow.python.framework import config from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -70,6 +71,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables @@ -1908,6 +1910,68 @@ def xla_allow_fallback(description): # pylint: disable=unused-argument return xla_allow_fallback_impl +# The description is just for documentation purposes. +def run_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute test without TensorFloat-32 being allowed. + + While almost every real-world deep learning model runs fine with + TensorFloat-32 (TF32), many tests use assertAllClose or similar methods. TF32 + matmuls typically will cause such methods to fail with the default tolerances. + + Args: + description: A description used for documentation purposes, describing why + the test requires TensorFloat-32 to be disallowed. + """ + + def decorator(f): + + @functools.wraps(f) + def decorated(self, *args, **kwargs): + allowed = config.tensor_float_32_execution_allowed() + try: + config.allow_tensor_float_32_execution(False) + f(self, *args, **kwargs) + finally: + config.allow_tensor_float_32_execution(allowed) + + return decorated + + return decorator + + +# The description is just for documentation purposes. +def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute all tests in a class without TensorFloat-32 being allowed.""" + return for_all_test_methods(run_without_tensor_float_32, description) + + +def matmul_without_tf32(a, b, *args, **kwargs): + """Run matmul, but cast float32 inputs to float64 if TF32 is allowed. + + This effectively runs matmul without TensorFloat-32 (TF32). It should only be + used in tests when verifying some other op or functions works correctly, e.g. + to test `tf.linalg.sqrtm` by matrix multiplying the output of the op + by itself. In such cases, the matmul itself is not being tested so it's OK to + run it with higher precision. + + If a matmul itself is being tested, or some other op which uses matmul, use + `run_without_tensor_float_32` instead. + + Args: + a: First input to tf.linalg.matmul + b: Second input to tf.linalg.matmul + args: Other positional arguments to tf.linalg.matmul + **kwargs: Other keyword arguments to tf.linalg.matmul + """ + if config.tensor_float_32_execution_allowed() and a.dtype == "float32": + a = math_ops.cast(a, "float64") + b = math_ops.cast(b, "float64") + ret = math_ops.matmul(a, b, *args, **kwargs) + return math_ops.cast(ret, a.dtype) + else: + return math_ops.matmul(a, b, *args, **kwargs) + + class EagerSessionWarner(object): def __getattr__(self, attr): diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index d8eff0f2260..95a192af51d 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -307,6 +307,7 @@ py_library( deps = [ ":backend", ":models", + "//tensorflow/python:config", "//tensorflow/python:framework_test_lib", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 4ea53429195..b3e67b7d87c 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -1605,6 +1605,8 @@ class TestRegularizerLoss(test.TestCase, parameterized.TestCase): self.assertEqual(-1.0, v) +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class TestDistributionStrategyWithKerasModels(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py index 6ec7cc2bac5..e04b40e33be 100644 --- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py @@ -24,6 +24,7 @@ from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.keras import backend as K +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.platform import test @@ -47,6 +48,8 @@ def is_default_strategy(strategy): return not distribution_strategy_context.has_strategy() +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class TestDistributionStrategyDnnCorrectness( keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): @@ -240,6 +243,8 @@ class SubclassedModel(keras.Model): return self.dense4(x) +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class TestDistributionStrategyDnnCorrectnessWithSubclassedModel( TestDistributionStrategyDnnCorrectness): diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py index 7e6ae3cc719..57b9b718491 100644 --- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py @@ -21,11 +21,15 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.distribute import combinations from tensorflow.python.eager import context +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.platform import test +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul. Even if Dense layers run in ' + 'float64, the test sometimes fails with tf32 enabled for unknown reasons') class DistributionStrategyCnnCorrectnessTest( keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): diff --git a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py index aa7f0c20045..4e82b7db433 100644 --- a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py @@ -69,6 +69,8 @@ class _DistributionStrategyRnnModelCorrectnessTest( return model +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class DistributionStrategyGruModelCorrectnessTest( _DistributionStrategyRnnModelCorrectnessTest): @@ -88,6 +90,8 @@ class DistributionStrategyGruModelCorrectnessTest( self.run_correctness_test(distribution, use_numpy, use_validation_data) +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class DistributionStrategyLstmModelCorrectnessTest( _DistributionStrategyRnnModelCorrectnessTest): diff --git a/tensorflow/python/keras/distribute/keras_save_load_test.py b/tensorflow/python/keras/distribute/keras_save_load_test.py index 65877a0f869..fc2e2bd46ec 100644 --- a/tensorflow/python/keras/distribute/keras_save_load_test.py +++ b/tensorflow/python/keras/distribute/keras_save_load_test.py @@ -20,10 +20,13 @@ from __future__ import print_function from tensorflow.python.distribute import combinations from tensorflow.python.eager import test +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.keras.saving import save +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class KerasSaveLoadTest(test_base.TestSavedModelBase): def setUp(self): diff --git a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py index d303a4228b5..7815d7403fd 100644 --- a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py +++ b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py @@ -26,12 +26,15 @@ from __future__ import print_function from tensorflow.python.distribute import combinations from tensorflow.python.eager import test +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.keras.saving import save _DEFAULT_FUNCTION_KEY = 'serving_default' +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase): def setUp(self): diff --git a/tensorflow/python/keras/distribute/saved_model_save_load_test.py b/tensorflow/python/keras/distribute/saved_model_save_load_test.py index 39856af2a20..2174d39bae4 100644 --- a/tensorflow/python/keras/distribute/saved_model_save_load_test.py +++ b/tensorflow/python/keras/distribute/saved_model_save_load_test.py @@ -24,6 +24,7 @@ from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations from tensorflow.python.eager import test from tensorflow.python.framework import tensor_spec +from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import model_combinations from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.ops import array_ops @@ -32,6 +33,8 @@ from tensorflow.python.saved_model import save_options as save_options_lib from tensorflow.python.saved_model import saved_model +@testing_utils.run_all_without_tensor_float_32( + 'Uses Dense layers, which call matmul') class SavedModelKerasModelTest(test_base.TestSavedModelBase): def setUp(self): diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 550ff664823..e2abd8506e0 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.python import tf2 from tensorflow.python.eager import context +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -937,3 +938,65 @@ def use_gpu(): """Uses gpu when requested and available.""" with device(should_use_gpu=True): yield + + +def for_all_test_methods(decorator, *args, **kwargs): + """Generate class-level decorator from given method-level decorator. + + It is expected for the given decorator to take some arguments and return + a method that is then called on the test method to produce a decorated + method. + + Args: + decorator: The decorator to apply. + *args: Positional arguments + **kwargs: Keyword arguments + Returns: Function that will decorate a given classes test methods with the + decorator. + """ + + def all_test_methods_impl(cls): + """Apply decorator to all test methods in class.""" + for name in dir(cls): + value = getattr(cls, name) + if callable(value) and name.startswith('test') and (name != + 'test_session'): + setattr(cls, name, decorator(*args, **kwargs)(value)) + return cls + + return all_test_methods_impl + + +# The description is just for documentation purposes. +def run_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute test without TensorFloat-32 being allowed. + + While almost every real-world deep learning model runs fine with + TensorFloat-32 (TF32), many tests use assertAllClose or similar methods. TF32 + matmuls typically will cause such methods to fail with the default tolerances. + + Args: + description: A description used for documentation purposes, describing why + the test requires TensorFloat-32 to be disallowed. + """ + + def decorator(f): + + @functools.wraps(f) + def decorated(self, *args, **kwargs): + allowed = config.tensor_float_32_execution_allowed() + try: + config.allow_tensor_float_32_execution(False) + f(self, *args, **kwargs) + finally: + config.allow_tensor_float_32_execution(allowed) + + return decorated + + return decorator + + +# The description is just for documentation purposes. +def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument + """Execute all tests in a class without TensorFloat-32 being allowed.""" + return for_all_test_methods(run_without_tensor_float_32, description) diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py index 30b61027813..ac82a320bb6 100644 --- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py @@ -130,6 +130,7 @@ class BatchMatmulOpTest(test.TestCase): def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape): + @test_util.run_without_tensor_float_32("Tests batch matmul") def Test(self): np.random.seed(42) self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape) @@ -141,6 +142,7 @@ def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape): def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b, use_static_shape): + @test_util.run_without_tensor_float_32("Tests batch matmul") def Test(self): np.random.seed(42) self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape) diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index a9afca8bfe7..0697f7def1b 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -106,7 +106,7 @@ class CholeskyOpTest(test.TestCase): def _verifyCholesky(self, x): # Verify that LL^T == x. chol = linalg_ops.cholesky(x) - verification = math_ops.matmul(chol, chol, adjoint_b=True) + verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True) self._verifyCholeskyBase(x, chol, verification) @test_util.run_in_graph_and_eager_modes(use_gpu=True) @@ -271,8 +271,8 @@ class CholeskyGradTest(test.TestCase): def Compute(x): # Turn the random matrix x into a Hermitian matrix by # computing the quadratic form x * x^H. - a = math_ops.matmul(x, math_ops.conj( - array_ops.matrix_transpose(x))) / shape[0] + a = test_util.matmul_without_tf32( + x, math_ops.conj(array_ops.matrix_transpose(x))) / shape[0] if batch: a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1]) # Finally take the cholesky decomposition of the Hermitian matrix. diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py index 9bd962e75f3..3acc1fe03be 100644 --- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -48,6 +48,10 @@ def GetTestConfigs(): return test_configs +@test_util.run_all_without_tensor_float_32( + "Tests Conv3d, which in some cases is implemented with a matmul. With " + "tf32, tests fail in some of those cases (and as of August 13 2020, only " + "those cases)") class Conv3DTest(test.TestCase): def _DtypesToTest(self, use_gpu): diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py index a4c07daa940..7c8f389f178 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py @@ -268,6 +268,8 @@ class DirichletMultinomialTest(test.TestCase): self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) + @test_util.run_without_tensor_float_32( + "Tests DirichletMultinomial.covariance, which calls matmul") def testCovariance(self): # Shape [2] alpha = [1., 2] diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index 0f963824531..a0d8bef327d 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -200,6 +200,8 @@ class DirichletTest(test.TestCase): self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) + @test_util.run_without_tensor_float_32( + "Calls Dirichlet.covariance, which calls matmul") def testVariance(self): alpha = [1., 2, 3] denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py index aa9e356bea5..4236eb93278 100644 --- a/tensorflow/python/kernel_tests/einsum_op_test.py +++ b/tensorflow/python/kernel_tests/einsum_op_test.py @@ -35,6 +35,8 @@ from tensorflow.python.platform import benchmark from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32( + 'Tests einsum, which sometimes does a matmul with cuBLAS') class EinsumOpTest(test.TestCase): def _check(self, s, *input_shapes, **kwargs): @@ -285,6 +287,8 @@ class EinsumOpTest(test.TestCase): @test_util.run_all_in_graph_and_eager_modes +@test_util.run_all_without_tensor_float_32( + "Tests einsum's gradient, which sometimes does a matmul with cuBLAS") class EinsumGradTest(test.TestCase): def _check_gradient(self, s, *input_shapes): diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index e3268fad2d8..f2348c6c7ac 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -945,6 +945,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): self.assertAllClose(abs_value, count, rtol=tol, atol=tol) +@test_util.run_all_without_tensor_float_32( + "Tests convolutional_orthogonal_1d, which calls matmul") class ConvolutionOrthogonal1dInitializerTest(test.TestCase): @test_util.run_deprecated_v1 @@ -1174,6 +1176,8 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol) +@test_util.run_all_without_tensor_float_32( + "Tests convolutional_orthogonal_3d, which calls matmul") class ConvolutionOrthogonal3dInitializerTest(test.TestCase): @test_util.run_deprecated_v1 diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index ac82f190db0..f42600bd334 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -534,7 +534,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): c_value = self.evaluate(c) expected_c_value = self.evaluate( - math_ops.conj(math_ops.matmul(a_dense, b))) + math_ops.conj(test_util.matmul_without_tf32(a_dense, b))) self.assertAllClose(expected_c_value, c_value) @test_util.run_in_graph_and_eager_modes @@ -576,7 +576,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): transpose_b=transpose_b, adjoint_a=adjoint_a, adjoint_b=adjoint_b) - c_dense_t = math_ops.matmul( + c_dense_t = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -640,7 +640,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): adjoint_b=adjoint_b) # Example: t(adj(a) . b) = t(b) . conj(a) - c_dense_t = math_ops.matmul( + c_dense_t = test_util.matmul_without_tf32( math_ops.conj(b_mats) if adjoint_b else b_mats, math_ops.conj(a_mats) if adjoint_a else a_mats, transpose_a=not (transpose_b or adjoint_b), @@ -670,7 +670,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): c_t = sparse_csr_matrix_ops.sparse_matrix_mat_mul( a_sm, b_mats, conjugate_output=True) - c_dense_t = math_ops.conj(math_ops.matmul(a_mats, b_mats)) + c_dense_t = math_ops.conj(test_util.matmul_without_tf32(a_mats, b_mats)) self.assertAllEqual(c_t.shape, c_dense_t.shape) c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t)) @@ -772,7 +772,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): adjoint_b=adjoint_b) c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense( c_sm, dtypes.float32) - c_dense_t = math_ops.matmul( + c_dense_t = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -1143,7 +1143,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): dense_cholesky = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense( cholesky_sparse_matrices, dtype) # Compute L * Lh where L is the Sparse Cholesky factor. - verification = math_ops.matmul( + verification = test_util.matmul_without_tf32( dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True)) verification = twist_matrix(verification, ordering_amd) # Assert that input matrix A satisfies A = L * Lh. @@ -1197,7 +1197,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): cholesky_sparse_matrix, dtype) # Compute L * Lh. - verification = math_ops.matmul( + verification = test_util.matmul_without_tf32( dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True)) verification = twist_matrix(verification, ordering_amd) @@ -1238,7 +1238,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): cholesky_sparse_matrix, dtypes.float32) # Compute L * Lh. - verification = math_ops.matmul( + verification = test_util.matmul_without_tf32( dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1])) verification = twist_matrix(verification, ordering_amd) verification_values = self.evaluate(verification) diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py index 35c706cb36a..4aa3474ffbb 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py @@ -162,7 +162,7 @@ class SparseMatrixMatmulTest(test.TestCase): 1.j * np.random.randn(*dense_shape_b))).astype(dtype) a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats) b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats) - c_dense = math_ops.matmul( + c_dense = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -202,7 +202,7 @@ class SparseMatrixMatmulTest(test.TestCase): b_mats = (np.random.randn(*dense_shape_b) + 1.j * np.random.randn(*dense_shape_b)).astype(dtype) a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats) - c_dense = math_ops.matmul( + c_dense = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, @@ -240,7 +240,7 @@ class SparseMatrixMatmulTest(test.TestCase): b_mats = sparsify((np.random.randn(*dense_shape_b) + 1.j * np.random.randn(*dense_shape_b))).astype(dtype) b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats) - c_dense = math_ops.matmul( + c_dense = test_util.matmul_without_tf32( a_mats, b_mats, transpose_a=transpose_a, diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py index f1d885fd231..273aba4d94f 100644 --- a/tensorflow/python/kernel_tests/linalg_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg_grad_test.py @@ -63,6 +63,9 @@ def _GetMatrixUnaryFunctorGradientTest(functor_, dtype_, shape_, **kwargs_): @test_util.enable_control_flow_v2 @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32( + 'Tests `tf.linalg.expm`, which call matmul. Additionally, calls ops ' + 'which do matmul in their gradient, such as MatrixSolve.') def Test(self): def RandomInput(): @@ -102,6 +105,16 @@ def _GetMatrixBinaryFunctorGradientTest(functor_, **kwargs_): @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32( + 'Tests `tf.linalg.lstsq`, which call matmul. Additionally, calls ops ' + 'which do matmul in their gradient, such as MatrixSolveLs.') + # TODO(b/164254522): With tf32, some tests fails with extremely high absolute + # and relative differences when calling assertAllClose. For example, the test + # test_MatrixSolveLsGradient_float32_10_10_1e-06 of class + # MatrixBinaryFunctorGradientTest fails with a max absolute difference of + # 0.883 and a max relative difference of 736892. We should consider disabling + # tf32 within `tf.linalg.lstsq and perhaps other linear algebra functions, + # even if tf32 is allowed globally. def Test(self): def RandomInput(): diff --git a/tensorflow/python/kernel_tests/lu_op_test.py b/tensorflow/python/kernel_tests/lu_op_test.py index fee6aecb3b0..8d522e80a08 100644 --- a/tensorflow/python/kernel_tests/lu_op_test.py +++ b/tensorflow/python/kernel_tests/lu_op_test.py @@ -91,7 +91,7 @@ class LuOpTest(test.TestCase): # Prepare the upper factor. upper = array_ops.matrix_band_part(lu, 0, -1) - verification = math_ops.matmul(lower, upper) + verification = test_util.matmul_without_tf32(lower, upper) # Permute the rows of product of the Cholesky factors. if num_rows > 0: diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 712d7336b94..737ca777804 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -70,6 +70,7 @@ class MatMulTest(test_lib.TestCase): def _GetMatMulTest(a_np_, b_np_, use_static_shape_, **kwargs_): + @test_util.run_without_tensor_float_32("Tests matmul") def Test(self): np_val = np.matrix(a_np_) * np.matrix(b_np_) diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py index ffe0f595618..9a5a467a5a1 100644 --- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import benchmark @@ -41,7 +40,7 @@ class InverseOpTest(test.TestCase): with self.cached_session(use_gpu=True): # Verify that x^{-1} * x == Identity matrix. inv = linalg_ops.matrix_inverse(y, adjoint=adjoint) - tf_ans = math_ops.matmul(inv, y, adjoint_b=adjoint) + tf_ans = test_util.matmul_without_tf32(inv, y, adjoint_b=adjoint) np_ans = np.identity(y.shape[-1]) if x.ndim > 2: tiling = list(y.shape) diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py index 6cf330ed981..98796f256ab 100644 --- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import test +@test_util.run_all_without_tensor_float_32 class SquareRootOpTest(test.TestCase): def _verifySquareRoot(self, matrix, np_type): @@ -36,7 +37,7 @@ class SquareRootOpTest(test.TestCase): # Verify that matmul(sqrtm(A), sqrtm(A)) = A sqrt = gen_linalg_ops.matrix_square_root(matrix) - square = math_ops.matmul(sqrt, sqrt) + square = test_util.matmul_without_tf32(sqrt, sqrt) self.assertShapeEqual(matrix, square) self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3) diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index b895fe4ea99..0a618b7f555 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -200,6 +200,8 @@ def _GetQrGradOpTest(dtype_, shape_, full_matrices_): return a @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32("Tests Qr gradient, which calls matmul" + ) def Test(self): np.random.seed(42) # Optimal stepsize for central difference is O(epsilon^{1/3}). diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index 01b324f29fb..7fa31d14777 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -3062,6 +3062,8 @@ class RNNCellTest(test.TestCase, parameterized.TestCase): @test_util.run_all_in_graph_and_eager_modes +@test_util.run_all_without_tensor_float_32( + "Uses an LSTMCell, which calls matmul") class DropoutWrapperTest(test.TestCase, parameterized.TestCase): def _testDropoutWrapper(self, diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py index 5be7cb4dd3a..40f8b31b7c2 100644 --- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py +++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py @@ -38,6 +38,7 @@ def _AddTest(test_class, op_name, testcase_name, fn): setattr(test_class, test_name, fn) +@test_util.run_all_without_tensor_float_32 class SelfAdjointEigTest(test.TestCase): @test_util.run_deprecated_v1 @@ -160,8 +161,8 @@ def _GetSelfAdjointEigTest(dtype_, shape_, compute_v_): tf_e, tf_v = linalg_ops.self_adjoint_eig(constant_op.constant(a)) # Check that V*diag(E)*V^T is close to A. - a_ev = math_ops.matmul( - math_ops.matmul(tf_v, array_ops.matrix_diag(tf_e)), + a_ev = test_util.matmul_without_tf32( + test_util.matmul_without_tf32(tf_v, array_ops.matrix_diag(tf_e)), tf_v, adjoint_b=True) self.assertAllClose(self.evaluate(a_ev), a, atol=atol) diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index a031f9bca07..368a7f18f8b 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -165,6 +165,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): return a, b, a_dims, b_dims @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul") def test_tensordot(self): if dynamic_shape_ and context.executing_eagerly(): self.skipTest("Placeholders not support in eager mode") @@ -196,6 +197,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): self.assertAllEqual(tf_ans.shape, np_ans.shape) @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul") def test_tensordot_scalar_axes(self): if dynamic_shape_ and context.executing_eagerly(): self.skipTest("Placeholders not support in eager mode") diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 751a8a00758..1c8d8d69b38 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -97,6 +97,8 @@ class RGBToHSVTest(test_util.TensorFlowTestCase): class RGBToYIQTest(test_util.TensorFlowTestCase): + @test_util.run_without_tensor_float_32( + "Calls rgb_to_yiq and yiq_to_rgb, which use matmul") def testBatch(self): # Build an arbitrary RGB image np.random.seed(7) @@ -127,6 +129,8 @@ class RGBToYIQTest(test_util.TensorFlowTestCase): class RGBToYUVTest(test_util.TensorFlowTestCase): + @test_util.run_without_tensor_float_32( + "Calls rgb_to_yuv and yuv_to_rgb, which use matmul") def testBatch(self): # Build an arbitrary RGB image np.random.seed(7) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 9b864be39a2..7f3d9f6e286 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -541,6 +541,8 @@ class DropoutTest(test_lib.TestCase): _ = nn_ops.dropout(x, 0.5) +@test_util.run_all_without_tensor_float_32( + "Tests _compute_sampled_logits and related functions, which call matmul") class ComputeSampledLogitsTest(test_lib.TestCase): def setUp(self): diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py index 85b58055d8f..30e724413f4 100644 --- a/tensorflow/python/ops/parallel_for/math_test.py +++ b/tensorflow/python/ops/parallel_for/math_test.py @@ -261,6 +261,9 @@ class MathTest(PForTestCase, parameterized.TestCase): self._test_loop_fn(loop_fn, 4) + @test_util.run_without_tensor_float_32( + "Calls matmul in parallel for-loop and compares result to calling matmul " + "in sequential for-loop") def test_matmul(self): for tr_a in (True, False): for tr_b in (True, False): @@ -745,6 +748,9 @@ class LinalgTest(PForTestCase): self._test_loop_fn(loop_fn, 2) + @test_util.run_without_tensor_float_32( + "Calls einsum in parallel for-loop and compares result to calling einsum " + "in sequential for-loop") def test_einsum(self): b = 10 x_series = random_ops.random_uniform([b, 9, 9]) diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 623f5063c7d..ba184b222ca 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -635,6 +635,8 @@ class BesselTest(test.TestCase, parameterized.TestCase): @test_util.run_all_in_graph_and_eager_modes +@test_util.run_all_without_tensor_float_32( + 'Tests einsum, which sometimes does a matmul with cuBLAS') class EinsumTest(test.TestCase): def _check(self, s, *input_shapes, **kwargs):