From 5445a05db78d3cce2fb707f0f9bb0e61b71a11dd Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Wed, 21 Oct 2020 13:22:05 -0700
Subject: [PATCH 1/2] Fix tests that fail with tf32

---
 tensorflow/c/eager/gradient_checker_test.cc  | 4 ++++
 tensorflow/c/eager/mnist_gradients_test.cc   | 5 +++++
 tensorflow/python/framework/function_test.py | 3 ++-
 tensorflow/python/framework/test_util.py     | 8 ++++++++
 tensorflow/python/kernel_tests/qr_op_test.py | 4 ++--
 5 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc
index 7a438085fb5..393ad2ceb98 100644
--- a/tensorflow/c/eager/gradient_checker_test.cc
+++ b/tensorflow/c/eager/gradient_checker_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
 }
 
 TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
+  // Computing numerical gradients with TensorFloat-32 is numerically unstable.
+  enable_tensor_float_32_execution(false);
+
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   AbstractContextPtr ctx;
diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc
index 4114f50a798..86b1cbeceaa 100644
--- a/tensorflow/c/eager/mnist_gradients_test.cc
+++ b/tensorflow/c/eager/mnist_gradients_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -43,6 +44,10 @@ class CppGradients
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
     Status s = StatusFromTF_Status(status.get());
     CHECK_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Computing numerical gradients with TensorFloat-32 is numerically unstable.
+  // Some forward pass tests also fail with TensorFloat-32 due to low tolerances
+  enable_tensor_float_32_execution(false);
   }
 };
 
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 90cd0f62986..ea376077ab7 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1487,6 +1487,8 @@ class FunctionCaptureByValueTest(test.TestCase):
       self.assertAllEqual(y, [[12.0]])
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Calls matmul in custom LSTM function")
 class UnrollLSTMTest(test.TestCase):
   BATCH_SIZE = 16
   LSTM_DIMS = 32
@@ -1593,7 +1595,6 @@ class UnrollLSTMTest(test.TestCase):
       self.assertAllClose(mv0, mv2, rtol=1e-4)
       self.assertAllClose(mv0, mv3, rtol=1e-4)
 
-  @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function")
   def testUnrollLSTMGrad(self):
     # Run one step of the unrolled lstm graph.
     def RunForwardBackward(mode, cfg=None):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 329de12ac41..f55cf51062d 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1977,6 +1977,9 @@ def matmul_without_tf32(a, b, *args, **kwargs):
   If a matmul itself is being tested, or some other op which uses matmul, use
   `run_without_tensor_float_32` instead.
 
+  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
+  be used with complex64.
+
   Args:
     a: First input to tf.linalg.matmul
     b: Second input to tf.linalg.matmul
@@ -1991,6 +1994,11 @@ def matmul_without_tf32(a, b, *args, **kwargs):
     b = math_ops.cast(b, "float64")
     ret = math_ops.matmul(a, b, *args, **kwargs)
     return math_ops.cast(ret, a.dtype)
+  elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64":
+    a = math_ops.cast(a, "complex128")
+    b = math_ops.cast(b, "complex128")
+    ret = math_ops.matmul(a, b, *args, **kwargs)
+    return math_ops.cast(ret, a.dtype)
   else:
     return math_ops.matmul(a, b, *args, **kwargs)
 
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index 0a618b7f555..6d131e8e35f 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -112,12 +112,12 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
     else:
       tol = 1e-14
     # Tests that a ~= q*r.
-    a_recon = math_ops.matmul(q, r)
+    a_recon = test_util.matmul_without_tf32(q, r)
     self.assertAllClose(a_recon, a, rtol=tol, atol=tol)
 
   def CheckUnitary(self, x):
     # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
-    xx = math_ops.matmul(x, x, adjoint_a=True)
+    xx = test_util.matmul_without_tf32(x, x, adjoint_a=True)
     identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
     if is_single:
       tol = 1e-5

From 1a5329c084d3150a02177b0b3fdea09dc05e85d5 Mon Sep 17 00:00:00 2001
From: Reed <reedwm@google.com>
Date: Wed, 21 Oct 2020 13:28:31 -0700
Subject: [PATCH 2/2] Fix indentation

---
 tensorflow/c/eager/mnist_gradients_test.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc
index 86b1cbeceaa..befe7772ce0 100644
--- a/tensorflow/c/eager/mnist_gradients_test.cc
+++ b/tensorflow/c/eager/mnist_gradients_test.cc
@@ -45,9 +45,9 @@ class CppGradients
     Status s = StatusFromTF_Status(status.get());
     CHECK_EQ(errors::OK, s.code()) << s.error_message();
 
-  // Computing numerical gradients with TensorFloat-32 is numerically unstable.
-  // Some forward pass tests also fail with TensorFloat-32 due to low tolerances
-  enable_tensor_float_32_execution(false);
+    // Computing numerical gradients with TensorFloat-32 is numerically unstable.
+    // Some forward pass tests also fail with TensorFloat-32 due to low tolerances
+    enable_tensor_float_32_execution(false);
   }
 };