From f1c1a294f194d29686cc061cf146f8758912a4b5 Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Thu, 13 Aug 2020 21:13:24 -0700
Subject: [PATCH 1/6] Fix tests when TF32 is enabled

---
 tensorflow/compiler/tests/cholesky_op_test.py |  2 +-
 .../tests/matrix_triangular_solve_op_test.py  |  3 +-
 tensorflow/compiler/tests/qr_op_test.py       |  7 ++-
 tensorflow/compiler/tests/xla_test.py         | 14 ++++-
 tensorflow/compiler/xla/client/lib/BUILD      |  1 +
 tensorflow/compiler/xla/client/lib/qr_test.cc |  4 ++
 tensorflow/compiler/xla/tests/BUILD           |  1 +
 .../compiler/xla/tests/cholesky_test.cc       |  7 +++
 tensorflow/core/kernels/conv_ops_test.cc      |  2 +
 tensorflow/python/framework/function_test.py  |  1 +
 tensorflow/python/framework/test_util.py      | 60 +++++++++++++++++++
 tensorflow/python/keras/BUILD                 |  1 +
 .../distribute/distribute_strategy_test.py    |  2 +
 .../distribute/keras_dnn_correctness_test.py  |  5 ++
 .../keras_image_model_correctness_test.py     |  4 ++
 .../keras_rnn_model_correctness_test.py       |  4 ++
 .../keras/distribute/keras_save_load_test.py  |  3 +
 .../distribute/saved_model_mixed_api_test.py  |  3 +
 .../distribute/saved_model_save_load_test.py  |  3 +
 tensorflow/python/keras/testing_utils.py      | 60 +++++++++++++++++++
 .../kernel_tests/batch_matmul_op_test.py      |  2 +
 .../python/kernel_tests/cholesky_op_test.py   |  4 +-
 .../python/kernel_tests/conv_ops_3d_test.py   |  4 ++
 .../dirichlet_multinomial_test.py             |  2 +
 .../distributions/dirichlet_test.py           |  2 +
 .../python/kernel_tests/einsum_op_test.py     |  4 ++
 .../python/kernel_tests/init_ops_test.py      |  4 ++
 .../sparse/csr_sparse_matrix_ops_test.py      | 16 ++---
 .../linalg/sparse/csr_sparse_matrix_test.py   |  6 +-
 .../python/kernel_tests/linalg_grad_test.py   | 13 ++++
 tensorflow/python/kernel_tests/lu_op_test.py  |  2 +-
 .../python/kernel_tests/matmul_op_test.py     |  1 +
 .../kernel_tests/matrix_inverse_op_test.py    |  2 +-
 .../matrix_square_root_op_test.py             |  3 +-
 tensorflow/python/kernel_tests/qr_op_test.py  |  2 +
 .../python/kernel_tests/rnn_cell_test.py      |  2 +
 .../kernel_tests/self_adjoint_eig_op_test.py  |  5 +-
 .../python/kernel_tests/tensordot_op_test.py  |  2 +
 tensorflow/python/ops/image_ops_test.py       |  4 ++
 tensorflow/python/ops/nn_test.py              |  2 +
 .../python/ops/parallel_for/math_test.py      |  6 ++
 .../python/ops/special_math_ops_test.py       |  2 +
 42 files changed, 254 insertions(+), 23 deletions(-)

diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index 4bd2dfd9244..8a2966a0466 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -61,7 +61,7 @@ class CholeskyOpTest(xla_test.XLATestCase):
           dtypes.as_dtype(x.dtype), shape=x.shape)
       with self.test_scope():
         chol = linalg_ops.cholesky(placeholder)
-      verification = math_ops.matmul(chol, chol, adjoint_b=True)
+      verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True)
       self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol)
 
   def testBasic(self):
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 0202c582ef3..8f123dfb809 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -65,7 +65,8 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
       with self.test_scope():
         x = linalg_ops.matrix_triangular_solve(
             placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
-      verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint)
+      verification = test_util.matmul_without_tf32(placeholder_ca, x,
+                                                   adjoint_a=adjoint)
       self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
                                       placeholder_b, a, clean_a, b,
                                       verification, atol)
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index 5fcf254db82..e16f0be7c64 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -24,12 +24,17 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_without_tensor_float_32(
+    "It's unknown why this test requires TF32 to be disabled")
+# TODO(reedwm): Determine why this test requires TF32 disabled. Debugging is
+# difficult due to this test's flakiness
 class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def AdjustedNorm(self, x):
@@ -73,7 +78,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
     with self.session() as sess:
       x_tf = array_ops.placeholder(dtype)
-      with self.test_scope():
+      with self.device_scope():
         q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
       q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
 
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 3b057ed8b17..792667b6074 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -235,8 +235,8 @@ class XLATestCase(test.TestCase):
         'test_session not supported on XLATestCase, please use session')
 
   @contextlib.contextmanager
-  def test_scope(self):
-    """Test scope that runs tests on `self.device`.
+  def device_scope(self):
+    """Scope that runs tests on `self.device`.
 
     Yields:
       A scope to apply to the operators under test.
@@ -244,6 +244,16 @@ class XLATestCase(test.TestCase):
     with ops.device('device:{}:0'.format(self.device)):
       yield
 
+  def test_scope(self):
+    """Deprecated alias of `device_scope`.
+
+    This should be avoided as the name starts with `test`, so test runners
+    treat it as a test. This interferes with class decorators that operate on
+    each test method.
+    """
+    return self.device_scope()
+
+
 
 def Benchmark(tf_bench,
               builder_fn,
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index a3c7c39e3ff..367743f2b6e 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -292,6 +292,7 @@ xla_test(
     deps = [
         ":matrix",
         ":qr",
+        "//tensorflow/core/platform:tf32_utils",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
         "//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc
index a61f243e126..1e004a59961 100644
--- a/tensorflow/compiler/xla/client/lib/qr_test.cc
+++ b/tensorflow/compiler/xla/client/lib/qr_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/lib/qr.h"
 
+#include "tensorflow/core/platform/tf32_utils.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -33,6 +34,7 @@ namespace {
 using QrTest = xla::ClientLibraryTestBase;
 
 XLA_TEST_F(QrTest, Simple) {
+  tensorflow::allow_tf32_execution(false);
   xla::XlaBuilder builder(TestName());
 
   xla::Array2D<float> a_vals({
@@ -61,6 +63,7 @@ XLA_TEST_F(QrTest, Simple) {
 }
 
 XLA_TEST_F(QrTest, ZeroDiagonal) {
+  tensorflow::allow_tf32_execution(false);
   xla::XlaBuilder builder(TestName());
 
   xla::Array2D<float> a_vals({
@@ -88,6 +91,7 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
 }
 
 XLA_TEST_F(QrTest, SimpleBatched) {
+  tensorflow::allow_tf32_execution(false);
   xla::XlaBuilder builder(TestName());
 
   xla::Array3D<float> a_vals({
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 17444c042e7..7253ee413c5 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2687,6 +2687,7 @@ xla_test(
     ],
     deps = [
         ":test_macros_header",
+        "//tensorflow/core/platform:tf32_utils",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc
index e7f5ca5ed8e..e64c69bdbc8 100644
--- a/tensorflow/compiler/xla/tests/cholesky_test.cc
+++ b/tensorflow/compiler/xla/tests/cholesky_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <numeric>
 #include <vector>
 
+#include "tensorflow/core/platform/tf32_utils.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -37,6 +38,7 @@ namespace {
 using CholeskyTest = ClientLibraryTestBase;
 
 XLA_TEST_F(CholeskyTest, NonPSDInput) {
+  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   Array2D<float> a_vals({
@@ -61,6 +63,7 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) {
 }
 
 XLA_TEST_F(CholeskyTest, Lower) {
+  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   float nan = std::numeric_limits<float>::quiet_NaN();
@@ -87,6 +90,7 @@ XLA_TEST_F(CholeskyTest, Lower) {
 }
 
 XLA_TEST_F(CholeskyTest, Upper) {
+  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   float nan = std::numeric_limits<float>::quiet_NaN();
@@ -113,6 +117,7 @@ XLA_TEST_F(CholeskyTest, Upper) {
 }
 
 XLA_TEST_F(CholeskyTest, Simple2) {
+  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   Array2D<float> a_vals({
@@ -136,6 +141,7 @@ XLA_TEST_F(CholeskyTest, Simple2) {
 }
 
 XLA_TEST_F(CholeskyTest, SimpleBatched) {
+  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   Array3D<float> a_vals({
@@ -181,6 +187,7 @@ class RandomCholeskyTest
       public ::testing::WithParamInterface<CholeskyTestCase> {};
 
 XLA_TEST_P(RandomCholeskyTest, Random) {
+  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   auto test_params = GetParam();
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index 3e192b83c57..fba71d5da33 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tf32_utils.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 #include "tensorflow/core/public/session.h"
 
@@ -1038,6 +1039,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ExplicitPaddingConvolution) {
 #endif
 
 TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) {
+  tensorflow::allow_tf32_execution(false);  // Requires full precision Conv2D op
   const int filter_size = 1;
   const int filter_count = 12;
   for (const string& activation : {"Relu", "Relu6", "Elu"}) {
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 596b93227bf..298d41a995c 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1592,6 +1592,7 @@ class UnrollLSTMTest(test.TestCase):
       self.assertAllClose(mv0, mv2, rtol=1e-4)
       self.assertAllClose(mv0, mv3, rtol=1e-4)
 
+  @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function")
   def testUnrollLSTMGrad(self):
     # Run one step of the unrolled lstm graph.
     def RunForwardBackward(mode, cfg=None):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4d7b7746b9c..d4017ebc03f 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -54,6 +54,7 @@ from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import tape
+from tensorflow.python.framework import config
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -70,6 +71,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.ops import variables
@@ -1907,6 +1909,64 @@ def xla_allow_fallback(description):  # pylint: disable=unused-argument
 
   return xla_allow_fallback_impl
 
+# The description is just for documentation purposes.
+def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute test without TensorFloat-32 being allowed.
+
+  While almost every real-world deep learning model runs fine with
+  TensorFloat-32 (TF32), many tests use assertAllClose or similar methods. TF32
+  matmuls typically will cause such methods to fail with the default tolerances.
+
+  Args:
+    description: A description used for documentation purposes, describing why
+    the test requires TensorFloat-32 to be disallowed.
+  """
+  def decorator(f):
+    @functools.wraps(f)
+    def decorated(self, *args, **kwargs):
+      allowed = config.tensor_float_32_execution_allowed()
+      try:
+        config.allow_tensor_float_32_execution(False)
+        f(self, *args, **kwargs)
+      finally:
+        config.allow_tensor_float_32_execution(allowed)
+    return decorated
+
+  return decorator
+
+
+# The description is just for documentation purposes.
+def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute all tests in a class without TensorFloat-32 being allowed."""
+  return for_all_test_methods(run_without_tensor_float_32, description)
+
+
+def matmul_without_tf32(a, b, *args, **kwargs):
+  """Run matmul, but cast float32 inputs to float64 if TF32 is allowed.
+
+  This effectively runs matmul without TensorFloat-32 (TF32). It should only be
+  used in tests when verifying some other op or functions works correctly, e.g.
+  to test `tf.linalg.sqrtm` by matrix multiplying the output of the op
+  by itself. In such cases, the matmul itself is not being tested so it's OK to
+  run it with higher precision.
+
+  If a matmul itself is being tested, or some other op which uses matmul, use
+  `run_without_tensor_float_32` instead.
+
+  Args:
+    a: First input to tf.linalg.matmul
+    b: Second input to tf.linalg.matmul
+    args: Other positional arguments to tf.linalg.matmul
+    **kwargs: Other keyword arguments to tf.linalg.matmul
+  """
+  if config.tensor_float_32_execution_allowed() and a.dtype == "float32":
+    a = math_ops.cast(a, "float64")
+    b = math_ops.cast(b, "float64")
+    ret = math_ops.matmul(a, b, *args, **kwargs)
+    return math_ops.cast(ret, a.dtype)
+  else:
+    return math_ops.matmul(a, b, *args, **kwargs)
+
 
 class EagerSessionWarner(object):
 
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 24c5b9de8ca..7479515d60d 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -307,6 +307,7 @@ py_library(
     deps = [
         ":backend",
         ":models",
+        "//tensorflow/python:config",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index 4b6d3a80730..d94489387ad 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -1604,6 +1604,8 @@ class TestRegularizerLoss(test.TestCase, parameterized.TestCase):
       self.assertEqual(-1.0, v)
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class TestDistributionStrategyWithKerasModels(test.TestCase,
                                               parameterized.TestCase):
 
diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
index 6ec7cc2bac5..e04b40e33be 100644
--- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import keras_correctness_test_base
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
 from tensorflow.python.platform import test
@@ -47,6 +48,8 @@ def is_default_strategy(strategy):
     return not distribution_strategy_context.has_strategy()
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class TestDistributionStrategyDnnCorrectness(
     keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
 
@@ -240,6 +243,8 @@ class SubclassedModel(keras.Model):
     return self.dense4(x)
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
     TestDistributionStrategyDnnCorrectness):
 
diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
index 7e6ae3cc719..57b9b718491 100644
--- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
@@ -21,11 +21,15 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.distribute import combinations
 from tensorflow.python.eager import context
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import keras_correctness_test_base
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.platform import test
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul. Even if Dense layers run in '
+    'float64, the test sometimes fails with tf32 enabled for unknown reasons')
 class DistributionStrategyCnnCorrectnessTest(
     keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
 
diff --git a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
index aa7f0c20045..b1113c3ffac 100644
--- a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
@@ -69,6 +69,8 @@ class _DistributionStrategyRnnModelCorrectnessTest(
     return model
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    "Uses Dense layers, which call matmul")
 class DistributionStrategyGruModelCorrectnessTest(
     _DistributionStrategyRnnModelCorrectnessTest):
 
@@ -88,6 +90,8 @@ class DistributionStrategyGruModelCorrectnessTest(
     self.run_correctness_test(distribution, use_numpy, use_validation_data)
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    "Uses Dense layers, which call matmul")
 class DistributionStrategyLstmModelCorrectnessTest(
     _DistributionStrategyRnnModelCorrectnessTest):
 
diff --git a/tensorflow/python/keras/distribute/keras_save_load_test.py b/tensorflow/python/keras/distribute/keras_save_load_test.py
index 65877a0f869..fc2e2bd46ec 100644
--- a/tensorflow/python/keras/distribute/keras_save_load_test.py
+++ b/tensorflow/python/keras/distribute/keras_save_load_test.py
@@ -20,10 +20,13 @@ from __future__ import print_function
 
 from tensorflow.python.distribute import combinations
 from tensorflow.python.eager import test
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import saved_model_test_base as test_base
 from tensorflow.python.keras.saving import save
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class KerasSaveLoadTest(test_base.TestSavedModelBase):
 
   def setUp(self):
diff --git a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py
index d303a4228b5..7815d7403fd 100644
--- a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py
+++ b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py
@@ -26,12 +26,15 @@ from __future__ import print_function
 
 from tensorflow.python.distribute import combinations
 from tensorflow.python.eager import test
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import saved_model_test_base as test_base
 from tensorflow.python.keras.saving import save
 
 _DEFAULT_FUNCTION_KEY = 'serving_default'
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
 
   def setUp(self):
diff --git a/tensorflow/python/keras/distribute/saved_model_save_load_test.py b/tensorflow/python/keras/distribute/saved_model_save_load_test.py
index 39856af2a20..2174d39bae4 100644
--- a/tensorflow/python/keras/distribute/saved_model_save_load_test.py
+++ b/tensorflow/python/keras/distribute/saved_model_save_load_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import test
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import model_combinations
 from tensorflow.python.keras.distribute import saved_model_test_base as test_base
 from tensorflow.python.ops import array_ops
@@ -32,6 +33,8 @@ from tensorflow.python.saved_model import save_options as save_options_lib
 from tensorflow.python.saved_model import saved_model
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class SavedModelKerasModelTest(test_base.TestSavedModelBase):
 
   def setUp(self):
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 550ff664823..da7f880f975 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -26,6 +26,7 @@ import numpy as np
 
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -937,3 +938,62 @@ def use_gpu():
   """Uses gpu when requested and available."""
   with device(should_use_gpu=True):
     yield
+
+
+def for_all_test_methods(decorator, *args, **kwargs):
+  """Generate class-level decorator from given method-level decorator.
+
+  It is expected for the given decorator to take some arguments and return
+  a method that is then called on the test method to produce a decorated
+  method.
+
+  Args:
+    decorator: The decorator to apply.
+    *args: Positional arguments
+    **kwargs: Keyword arguments
+  Returns: Function that will decorate a given classes test methods with the
+    decorator.
+  """
+
+  def all_test_methods_impl(cls):
+    """Apply decorator to all test methods in class."""
+    for name in dir(cls):
+      value = getattr(cls, name)
+      if callable(value) and name.startswith(
+          'test') and (name != 'test_session'):
+        setattr(cls, name, decorator(*args, **kwargs)(value))
+    return cls
+
+  return all_test_methods_impl
+
+
+# The description is just for documentation purposes.
+def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute test without TensorFloat-32 being allowed.
+
+  While almost every real-world deep learning model runs fine with
+  TensorFloat-32 (TF32), many tests use assertAllClose or similar methods. TF32
+  matmuls typically will cause such methods to fail with the default tolerances.
+
+  Args:
+    description: A description used for documentation purposes, describing why
+    the test requires TensorFloat-32 to be disallowed.
+  """
+  def decorator(f):
+    @functools.wraps(f)
+    def decorated(self, *args, **kwargs):
+      allowed = config.tensor_float_32_execution_allowed()
+      try:
+        config.allow_tensor_float_32_execution(False)
+        f(self, *args, **kwargs)
+      finally:
+        config.allow_tensor_float_32_execution(allowed)
+    return decorated
+
+  return decorator
+
+
+# The description is just for documentation purposes.
+def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute all tests in a class without TensorFloat-32 being allowed."""
+  return for_all_test_methods(run_without_tensor_float_32, description)
diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
index 30b61027813..ac82a320bb6 100644
--- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
@@ -130,6 +130,7 @@ class BatchMatmulOpTest(test.TestCase):
 
 def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):
 
+  @test_util.run_without_tensor_float_32("Tests batch matmul")
   def Test(self):
     np.random.seed(42)
     self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)
@@ -141,6 +142,7 @@ def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):
 def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
                                       use_static_shape):
 
+  @test_util.run_without_tensor_float_32("Tests batch matmul")
   def Test(self):
     np.random.seed(42)
     self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index a9afca8bfe7..608ff4449f4 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -106,7 +106,7 @@ class CholeskyOpTest(test.TestCase):
   def _verifyCholesky(self, x):
     # Verify that LL^T == x.
     chol = linalg_ops.cholesky(x)
-    verification = math_ops.matmul(chol, chol, adjoint_b=True)
+    verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True)
     self._verifyCholeskyBase(x, chol, verification)
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
@@ -271,7 +271,7 @@ class CholeskyGradTest(test.TestCase):
     def Compute(x):
       # Turn the random matrix x into a Hermitian matrix by
       # computing the quadratic form x * x^H.
-      a = math_ops.matmul(x, math_ops.conj(
+      a = test_util.matmul_without_tf32(x, math_ops.conj(
           array_ops.matrix_transpose(x))) / shape[0]
       if batch:
         a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1])
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index ff4da3afc9f..04cb9d4de4f 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -48,6 +48,10 @@ def GetTestConfigs():
   return test_configs
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests Conv3d, which in some cases is implemented with a matmul. With "
+    "tf32, tests fail in some of those cases (and as of August 13 2020, only "
+    "those cases)")
 class Conv3DTest(test.TestCase):
 
   def _DtypesToTest(self, use_gpu):
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
index a4c07daa940..7c8f389f178 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
@@ -268,6 +268,8 @@ class DirichletMultinomialTest(test.TestCase):
       self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.)
       self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
 
+  @test_util.run_without_tensor_float_32(
+      "Tests DirichletMultinomial.covariance, which calls matmul")
   def testCovariance(self):
     # Shape [2]
     alpha = [1., 2]
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index 0f963824531..a0d8bef327d 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -200,6 +200,8 @@ class DirichletTest(test.TestCase):
     self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.)
     self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
 
+  @test_util.run_without_tensor_float_32(
+      "Calls Dirichlet.covariance, which calls matmul")
   def testVariance(self):
     alpha = [1., 2, 3]
     denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py
index 10b96716580..928789d695e 100644
--- a/tensorflow/python/kernel_tests/einsum_op_test.py
+++ b/tensorflow/python/kernel_tests/einsum_op_test.py
@@ -35,6 +35,8 @@ from tensorflow.python.platform import benchmark
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_without_tensor_float_32(
+    'Tests einsum, which sometimes does a matmul with cuBLAS')
 class EinsumOpTest(test.TestCase):
 
   def _check(self, s, *input_shapes, **kwargs):
@@ -287,6 +289,8 @@ class EinsumOpTest(test.TestCase):
 
 
 @test_util.run_all_in_graph_and_eager_modes
+@test_util.run_all_without_tensor_float_32(
+    "Tests einsum's gradient, which sometimes does a matmul with cuBLAS")
 class EinsumGradTest(test.TestCase):
 
   def _check_gradient(self, s, *input_shapes):
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index e3268fad2d8..f2348c6c7ac 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -945,6 +945,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
       self.assertAllClose(abs_value, count, rtol=tol, atol=tol)
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests convolutional_orthogonal_1d, which calls matmul")
 class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
 
   @test_util.run_deprecated_v1
@@ -1174,6 +1176,8 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
         self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol)
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests convolutional_orthogonal_3d, which calls matmul")
 class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
 
   @test_util.run_deprecated_v1
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py
index ac82f190db0..f42600bd334 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py
@@ -534,7 +534,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
       c_value = self.evaluate(c)
 
       expected_c_value = self.evaluate(
-          math_ops.conj(math_ops.matmul(a_dense, b)))
+          math_ops.conj(test_util.matmul_without_tf32(a_dense, b)))
       self.assertAllClose(expected_c_value, c_value)
 
   @test_util.run_in_graph_and_eager_modes
@@ -576,7 +576,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
                 transpose_b=transpose_b,
                 adjoint_a=adjoint_a,
                 adjoint_b=adjoint_b)
-            c_dense_t = math_ops.matmul(
+            c_dense_t = test_util.matmul_without_tf32(
                 a_mats,
                 b_mats,
                 transpose_a=transpose_a,
@@ -640,7 +640,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
                 adjoint_b=adjoint_b)
 
             # Example: t(adj(a) . b) = t(b) . conj(a)
-            c_dense_t = math_ops.matmul(
+            c_dense_t = test_util.matmul_without_tf32(
                 math_ops.conj(b_mats) if adjoint_b else b_mats,
                 math_ops.conj(a_mats) if adjoint_a else a_mats,
                 transpose_a=not (transpose_b or adjoint_b),
@@ -670,7 +670,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
     c_t = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
         a_sm, b_mats, conjugate_output=True)
 
-    c_dense_t = math_ops.conj(math_ops.matmul(a_mats, b_mats))
+    c_dense_t = math_ops.conj(test_util.matmul_without_tf32(a_mats, b_mats))
     self.assertAllEqual(c_t.shape, c_dense_t.shape)
     c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t))
 
@@ -772,7 +772,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
             adjoint_b=adjoint_b)
         c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
             c_sm, dtypes.float32)
-        c_dense_t = math_ops.matmul(
+        c_dense_t = test_util.matmul_without_tf32(
             a_mats,
             b_mats,
             transpose_a=transpose_a,
@@ -1143,7 +1143,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
         dense_cholesky = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
             cholesky_sparse_matrices, dtype)
         # Compute L * Lh where L is the Sparse Cholesky factor.
-        verification = math_ops.matmul(
+        verification = test_util.matmul_without_tf32(
             dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True))
         verification = twist_matrix(verification, ordering_amd)
         # Assert that input matrix A satisfies A = L * Lh.
@@ -1197,7 +1197,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
           cholesky_sparse_matrix, dtype)
 
       # Compute L * Lh.
-      verification = math_ops.matmul(
+      verification = test_util.matmul_without_tf32(
           dense_cholesky,
           array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True))
       verification = twist_matrix(verification, ordering_amd)
@@ -1238,7 +1238,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
         cholesky_sparse_matrix, dtypes.float32)
 
     # Compute L * Lh.
-    verification = math_ops.matmul(
+    verification = test_util.matmul_without_tf32(
         dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1]))
     verification = twist_matrix(verification, ordering_amd)
     verification_values = self.evaluate(verification)
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py
index 35c706cb36a..4aa3474ffbb 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py
+++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py
@@ -162,7 +162,7 @@ class SparseMatrixMatmulTest(test.TestCase):
                          1.j * np.random.randn(*dense_shape_b))).astype(dtype)
       a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
       b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats)
-      c_dense = math_ops.matmul(
+      c_dense = test_util.matmul_without_tf32(
           a_mats,
           b_mats,
           transpose_a=transpose_a,
@@ -202,7 +202,7 @@ class SparseMatrixMatmulTest(test.TestCase):
       b_mats = (np.random.randn(*dense_shape_b) +
                 1.j * np.random.randn(*dense_shape_b)).astype(dtype)
       a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
-      c_dense = math_ops.matmul(
+      c_dense = test_util.matmul_without_tf32(
           a_mats,
           b_mats,
           transpose_a=transpose_a,
@@ -240,7 +240,7 @@ class SparseMatrixMatmulTest(test.TestCase):
       b_mats = sparsify((np.random.randn(*dense_shape_b) +
                          1.j * np.random.randn(*dense_shape_b))).astype(dtype)
       b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats)
-      c_dense = math_ops.matmul(
+      c_dense = test_util.matmul_without_tf32(
           a_mats,
           b_mats,
           transpose_a=transpose_a,
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index f1d885fd231..273aba4d94f 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -63,6 +63,9 @@ def _GetMatrixUnaryFunctorGradientTest(functor_, dtype_, shape_, **kwargs_):
 
   @test_util.enable_control_flow_v2
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32(
+      'Tests `tf.linalg.expm`, which call matmul. Additionally, calls ops '
+      'which do matmul in their gradient, such as MatrixSolve.')
   def Test(self):
 
     def RandomInput():
@@ -102,6 +105,16 @@ def _GetMatrixBinaryFunctorGradientTest(functor_,
                                         **kwargs_):
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32(
+      'Tests `tf.linalg.lstsq`, which call matmul. Additionally, calls ops '
+      'which do matmul in their gradient, such as MatrixSolveLs.')
+  # TODO(b/164254522): With tf32, some tests fails with extremely high absolute
+  # and relative differences when calling assertAllClose. For example, the test
+  # test_MatrixSolveLsGradient_float32_10_10_1e-06 of class
+  # MatrixBinaryFunctorGradientTest fails with a max absolute difference of
+  # 0.883 and a max relative difference of 736892. We should consider disabling
+  # tf32 within `tf.linalg.lstsq and perhaps other linear algebra functions,
+  # even if tf32 is allowed globally.
   def Test(self):
 
     def RandomInput():
diff --git a/tensorflow/python/kernel_tests/lu_op_test.py b/tensorflow/python/kernel_tests/lu_op_test.py
index fee6aecb3b0..8d522e80a08 100644
--- a/tensorflow/python/kernel_tests/lu_op_test.py
+++ b/tensorflow/python/kernel_tests/lu_op_test.py
@@ -91,7 +91,7 @@ class LuOpTest(test.TestCase):
     # Prepare the upper factor.
     upper = array_ops.matrix_band_part(lu, 0, -1)
 
-    verification = math_ops.matmul(lower, upper)
+    verification = test_util.matmul_without_tf32(lower, upper)
 
     # Permute the rows of product of the Cholesky factors.
     if num_rows > 0:
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 712d7336b94..df194f0322a 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -70,6 +70,7 @@ class MatMulTest(test_lib.TestCase):
 
 def _GetMatMulTest(a_np_, b_np_, use_static_shape_, **kwargs_):
 
+  @test_util.run_without_tensor_float_32('Tests matmul')
   def Test(self):
     np_val = np.matrix(a_np_) * np.matrix(b_np_)
 
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index ffe0f595618..5ce6b52afba 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -41,7 +41,7 @@ class InverseOpTest(test.TestCase):
       with self.cached_session(use_gpu=True):
         # Verify that x^{-1} * x == Identity matrix.
         inv = linalg_ops.matrix_inverse(y, adjoint=adjoint)
-        tf_ans = math_ops.matmul(inv, y, adjoint_b=adjoint)
+        tf_ans = test_util.matmul_without_tf32(inv, y, adjoint_b=adjoint)
         np_ans = np.identity(y.shape[-1])
         if x.ndim > 2:
           tiling = list(y.shape)
diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
index 6cf330ed981..98796f256ab 100644
--- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_without_tensor_float_32
 class SquareRootOpTest(test.TestCase):
 
   def _verifySquareRoot(self, matrix, np_type):
@@ -36,7 +37,7 @@ class SquareRootOpTest(test.TestCase):
 
     # Verify that matmul(sqrtm(A), sqrtm(A)) = A
     sqrt = gen_linalg_ops.matrix_square_root(matrix)
-    square = math_ops.matmul(sqrt, sqrt)
+    square = test_util.matmul_without_tf32(sqrt, sqrt)
     self.assertShapeEqual(matrix, square)
     self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3)
 
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index b895fe4ea99..b4cb2f3ed27 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -200,6 +200,8 @@ def _GetQrGradOpTest(dtype_, shape_, full_matrices_):
     return a
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32(
+      "Tests Qr gradient, which calls matmul")
   def Test(self):
     np.random.seed(42)
     # Optimal stepsize for central difference is O(epsilon^{1/3}).
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index 01b324f29fb..7fa31d14777 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -3062,6 +3062,8 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
 
 
 @test_util.run_all_in_graph_and_eager_modes
+@test_util.run_all_without_tensor_float_32(
+    "Uses an LSTMCell, which calls matmul")
 class DropoutWrapperTest(test.TestCase, parameterized.TestCase):
 
   def _testDropoutWrapper(self,
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index 5be7cb4dd3a..40f8b31b7c2 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -38,6 +38,7 @@ def _AddTest(test_class, op_name, testcase_name, fn):
   setattr(test_class, test_name, fn)
 
 
+@test_util.run_all_without_tensor_float_32
 class SelfAdjointEigTest(test.TestCase):
 
   @test_util.run_deprecated_v1
@@ -160,8 +161,8 @@ def _GetSelfAdjointEigTest(dtype_, shape_, compute_v_):
         tf_e, tf_v = linalg_ops.self_adjoint_eig(constant_op.constant(a))
 
         # Check that V*diag(E)*V^T is close to A.
-        a_ev = math_ops.matmul(
-            math_ops.matmul(tf_v, array_ops.matrix_diag(tf_e)),
+        a_ev = test_util.matmul_without_tf32(
+            test_util.matmul_without_tf32(tf_v, array_ops.matrix_diag(tf_e)),
             tf_v,
             adjoint_b=True)
         self.assertAllClose(self.evaluate(a_ev), a, atol=atol)
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index a031f9bca07..368a7f18f8b 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -165,6 +165,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
     return a, b, a_dims, b_dims
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul")
   def test_tensordot(self):
     if dynamic_shape_ and context.executing_eagerly():
       self.skipTest("Placeholders not support in eager mode")
@@ -196,6 +197,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
       self.assertAllEqual(tf_ans.shape, np_ans.shape)
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul")
   def test_tensordot_scalar_axes(self):
     if dynamic_shape_ and context.executing_eagerly():
       self.skipTest("Placeholders not support in eager mode")
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 1adece3474b..c9fe9a52d20 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -97,6 +97,8 @@ class RGBToHSVTest(test_util.TensorFlowTestCase):
 
 class RGBToYIQTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_without_tensor_float_32(
+      "Calls rgb_to_yiq and yiq_to_rgb, which use matmul")
   def testBatch(self):
     # Build an arbitrary RGB image
     np.random.seed(7)
@@ -127,6 +129,8 @@ class RGBToYIQTest(test_util.TensorFlowTestCase):
 
 class RGBToYUVTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_without_tensor_float_32(
+      "Calls rgb_to_yuv and yuv_to_rgb, which use matmul")
   def testBatch(self):
     # Build an arbitrary RGB image
     np.random.seed(7)
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 3802f92b384..61641abf7e4 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -540,6 +540,8 @@ class DropoutTest(test_lib.TestCase):
       _ = nn_ops.dropout(x, 0.5)
 
 
+@test_util.run_all_without_tensor_float_32(
+    'Tests _compute_sampled_logits and related functions, which call matmul')
 class ComputeSampledLogitsTest(test_lib.TestCase):
 
   def setUp(self):
diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py
index 85b58055d8f..30e724413f4 100644
--- a/tensorflow/python/ops/parallel_for/math_test.py
+++ b/tensorflow/python/ops/parallel_for/math_test.py
@@ -261,6 +261,9 @@ class MathTest(PForTestCase, parameterized.TestCase):
 
     self._test_loop_fn(loop_fn, 4)
 
+  @test_util.run_without_tensor_float_32(
+      "Calls matmul in parallel for-loop and compares result to calling matmul "
+      "in sequential for-loop")
   def test_matmul(self):
     for tr_a in (True, False):
       for tr_b in (True, False):
@@ -745,6 +748,9 @@ class LinalgTest(PForTestCase):
 
     self._test_loop_fn(loop_fn, 2)
 
+  @test_util.run_without_tensor_float_32(
+      "Calls einsum in parallel for-loop and compares result to calling einsum "
+      "in sequential for-loop")
   def test_einsum(self):
     b = 10
     x_series = random_ops.random_uniform([b, 9, 9])
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index 623f5063c7d..ba184b222ca 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -635,6 +635,8 @@ class BesselTest(test.TestCase, parameterized.TestCase):
 
 
 @test_util.run_all_in_graph_and_eager_modes
+@test_util.run_all_without_tensor_float_32(
+    'Tests einsum, which sometimes does a matmul with cuBLAS')
 class EinsumTest(test.TestCase):
 
   def _check(self, s, *input_shapes, **kwargs):

From 37ae20031b5a7bf073076bf60b8ace5c0140290b Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Mon, 17 Aug 2020 14:11:16 -0700
Subject: [PATCH 2/6] Address comments

---
 tensorflow/compiler/tests/xla_test.py          | 1 -
 tensorflow/compiler/xla/client/lib/qr_test.cc  | 6 +++---
 tensorflow/compiler/xla/tests/cholesky_test.cc | 7 +------
 3 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 792667b6074..2a531a369b9 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -254,7 +254,6 @@ class XLATestCase(test.TestCase):
     return self.device_scope()
 
 
-
 def Benchmark(tf_bench,
               builder_fn,
               use_xla_jit,
diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc
index 1e004a59961..a612665240e 100644
--- a/tensorflow/compiler/xla/client/lib/qr_test.cc
+++ b/tensorflow/compiler/xla/client/lib/qr_test.cc
@@ -34,7 +34,7 @@ namespace {
 using QrTest = xla::ClientLibraryTestBase;
 
 XLA_TEST_F(QrTest, Simple) {
-  tensorflow::allow_tf32_execution(false);
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   xla::XlaBuilder builder(TestName());
 
   xla::Array2D<float> a_vals({
@@ -63,7 +63,7 @@ XLA_TEST_F(QrTest, Simple) {
 }
 
 XLA_TEST_F(QrTest, ZeroDiagonal) {
-  tensorflow::allow_tf32_execution(false);
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   xla::XlaBuilder builder(TestName());
 
   xla::Array2D<float> a_vals({
@@ -91,7 +91,7 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
 }
 
 XLA_TEST_F(QrTest, SimpleBatched) {
-  tensorflow::allow_tf32_execution(false);
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   xla::XlaBuilder builder(TestName());
 
   xla::Array3D<float> a_vals({
diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc
index e64c69bdbc8..39c4cc07860 100644
--- a/tensorflow/compiler/xla/tests/cholesky_test.cc
+++ b/tensorflow/compiler/xla/tests/cholesky_test.cc
@@ -38,7 +38,6 @@ namespace {
 using CholeskyTest = ClientLibraryTestBase;
 
 XLA_TEST_F(CholeskyTest, NonPSDInput) {
-  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   Array2D<float> a_vals({
@@ -63,7 +62,6 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) {
 }
 
 XLA_TEST_F(CholeskyTest, Lower) {
-  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   float nan = std::numeric_limits<float>::quiet_NaN();
@@ -90,7 +88,6 @@ XLA_TEST_F(CholeskyTest, Lower) {
 }
 
 XLA_TEST_F(CholeskyTest, Upper) {
-  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   float nan = std::numeric_limits<float>::quiet_NaN();
@@ -117,7 +114,6 @@ XLA_TEST_F(CholeskyTest, Upper) {
 }
 
 XLA_TEST_F(CholeskyTest, Simple2) {
-  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   Array2D<float> a_vals({
@@ -141,7 +137,6 @@ XLA_TEST_F(CholeskyTest, Simple2) {
 }
 
 XLA_TEST_F(CholeskyTest, SimpleBatched) {
-  tensorflow::allow_tf32_execution(false);
   XlaBuilder builder(TestName());
 
   Array3D<float> a_vals({
@@ -187,7 +182,7 @@ class RandomCholeskyTest
       public ::testing::WithParamInterface<CholeskyTestCase> {};
 
 XLA_TEST_P(RandomCholeskyTest, Random) {
-  tensorflow::allow_tf32_execution(false);
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   XlaBuilder builder(TestName());
 
   auto test_params = GetParam();

From fa83fab56b3bc7f82fad6ca145874c22babecaa3 Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Mon, 17 Aug 2020 14:59:36 -0700
Subject: [PATCH 3/6] fix lint issues

---
 tensorflow/compiler/xla/client/lib/BUILD | 2 +-
 tensorflow/compiler/xla/tests/BUILD      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 367743f2b6e..d662d3102c0 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -292,7 +292,6 @@ xla_test(
     deps = [
         ":matrix",
         ":qr",
-        "//tensorflow/core/platform:tf32_utils",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
         "//tensorflow/compiler/xla:literal",
@@ -305,6 +304,7 @@ xla_test(
         "//tensorflow/compiler/xla/tests:test_macros_header",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
+        "//tensorflow/core/platform:tf32_utils",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 7253ee413c5..734d2ed443c 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2687,7 +2687,6 @@ xla_test(
     ],
     deps = [
         ":test_macros_header",
-        "//tensorflow/core/platform:tf32_utils",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:statusor",
@@ -2700,5 +2699,6 @@ xla_test(
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
+        "//tensorflow/core/platform:tf32_utils",
     ],
 )

From 92dc3cd54c0e6ecd934f2e09955cd9b3f315bc33 Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Tue, 18 Aug 2020 20:32:08 -0700
Subject: [PATCH 4/6] Improve comment in QrOpTest

---
 tensorflow/compiler/tests/qr_op_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index e16f0be7c64..df50bc6ad65 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -32,9 +32,9 @@ from tensorflow.python.platform import test
 
 
 @test_util.run_all_without_tensor_float_32(
-    "It's unknown why this test requires TF32 to be disabled")
-# TODO(reedwm): Determine why this test requires TF32 disabled. Debugging is
-# difficult due to this test's flakiness
+    'XLA QR op calls matmul. Also, matmul used for verification. Also with '
+    'TF32, mysterious "Unable to launch cuBLAS gemm" error occasionally occurs')
+# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error
 class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def AdjustedNorm(self, x):

From 6638b67423737ff696dfbdcd5f40ec235afda73a Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Tue, 18 Aug 2020 20:34:13 -0700
Subject: [PATCH 5/6] Fix lint errors

---
 tensorflow/compiler/tests/cholesky_op_test.py                | 1 -
 tensorflow/compiler/tests/matrix_triangular_solve_op_test.py | 1 -
 tensorflow/python/kernel_tests/matrix_inverse_op_test.py     | 1 -
 3 files changed, 3 deletions(-)

diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index 8a2966a0466..41877d39381 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 8f123dfb809..192526c2216 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -29,7 +29,6 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index 5ce6b52afba..9a5a467a5a1 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import benchmark

From b067698e08ba91e4ef8d32b4cc21155ba0c99593 Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Tue, 25 Aug 2020 10:28:02 -0700
Subject: [PATCH 6/6] Fix internal build error in conv_ops_test.cc

---
 tensorflow/core/kernels/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ccb12d9b09d..3653e59ff08 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1743,6 +1743,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/platform:tf32_utils",
         "@com_google_absl//absl/algorithm:container",
     ],
 )