diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc
index 7a438085fb5..393ad2ceb98 100644
--- a/tensorflow/c/eager/gradient_checker_test.cc
+++ b/tensorflow/c/eager/gradient_checker_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
 }
 
 TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
+  // Computing numerical gradients with TensorFloat-32 is numerically unstable.
+  enable_tensor_float_32_execution(false);
+
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   AbstractContextPtr ctx;
diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc
index 4114f50a798..86b1cbeceaa 100644
--- a/tensorflow/c/eager/mnist_gradients_test.cc
+++ b/tensorflow/c/eager/mnist_gradients_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -43,6 +44,10 @@ class CppGradients
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
     Status s = StatusFromTF_Status(status.get());
     CHECK_EQ(errors::OK, s.code()) << s.error_message();
+
+    // Computing numerical gradients with TensorFloat-32 is numerically unstable.
+    // Some forward pass tests also fail with TensorFloat-32 due to low tolerances.
+    enable_tensor_float_32_execution(false);
   }
 };
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 90cd0f62986..ea376077ab7 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1487,6 +1487,8 @@ class FunctionCaptureByValueTest(test.TestCase):
     self.assertAllEqual(y, [[12.0]])
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Calls matmul in custom LSTM function")
 class UnrollLSTMTest(test.TestCase):
   BATCH_SIZE = 16
   LSTM_DIMS = 32
@@ -1593,7 +1595,6 @@ class UnrollLSTMTest(test.TestCase):
     self.assertAllClose(mv0, mv2, rtol=1e-4)
     self.assertAllClose(mv0, mv3, rtol=1e-4)
 
-  @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function")
   def testUnrollLSTMGrad(self):
     # Run one step of the unrolled lstm graph.
     def RunForwardBackward(mode, cfg=None):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 329de12ac41..f55cf51062d 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1977,6 +1977,9 @@ def matmul_without_tf32(a, b, *args, **kwargs):
   If a matmul itself is being tested, or some other op which uses matmul, use
   `run_without_tensor_float_32` instead.
 
+  This also casts complex64 inputs to complex128, since TensorFloat-32 can
+  also be used with complex64.
+
   Args:
     a: First input to tf.linalg.matmul
     b: Second input to tf.linalg.matmul
@@ -1991,6 +1994,11 @@ def matmul_without_tf32(a, b, *args, **kwargs):
     b = math_ops.cast(b, "float64")
     ret = math_ops.matmul(a, b, *args, **kwargs)
     return math_ops.cast(ret, a.dtype)
+  elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64":
+    a = math_ops.cast(a, "complex128")
+    b = math_ops.cast(b, "complex128")
+    ret = math_ops.matmul(a, b, *args, **kwargs)
+    return math_ops.cast(ret, a.dtype)
   else:
     return math_ops.matmul(a, b, *args, **kwargs)
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index 0a618b7f555..6d131e8e35f 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -112,12 +112,12 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
     else:
       tol = 1e-14
     # Tests that a ~= q*r.
-    a_recon = math_ops.matmul(q, r)
+    a_recon = test_util.matmul_without_tf32(q, r)
     self.assertAllClose(a_recon, a, rtol=tol, atol=tol)
 
   def CheckUnitary(self, x):
     # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
-    xx = math_ops.matmul(x, x, adjoint_a=True)
+    xx = test_util.matmul_without_tf32(x, x, adjoint_a=True)
     identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
     if is_single:
       tol = 1e-5
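For context, not part of the patch: a minimal standalone sketch of the dispatch that `matmul_without_tf32` now performs, written against the public `tf.config` API rather than the internal `config` and `math_ops` modules used in `test_util.py`. The wrapper name `matmul_in_tf32_safe_dtype` is illustrative only, not a TensorFlow API.

import tensorflow as tf

def matmul_in_tf32_safe_dtype(a, b, *args, **kwargs):
  # TensorFloat-32 rounds float32 matmul operands (and the real and imaginary
  # parts of complex64 operands) to 10 mantissa bits. When it is enabled,
  # upcast to the 64-bit counterpart and cast the product back so the result
  # has full float32/complex64 precision.
  if tf.config.experimental.tensor_float_32_execution_enabled():
    wide = {tf.float32: tf.float64, tf.complex64: tf.complex128}.get(a.dtype)
    if wide is not None:
      ret = tf.linalg.matmul(tf.cast(a, wide), tf.cast(b, wide), *args,
                             **kwargs)
      return tf.cast(ret, a.dtype)
  return tf.linalg.matmul(a, b, *args, **kwargs)

Tests that cannot route every matmul through such a wrapper can instead disable TensorFloat-32 globally with tf.config.experimental.enable_tensor_float_32_execution(False), the Python counterpart of the enable_tensor_float_32_execution(false) calls the patch adds to the C++ tests via tensor_float_32_utils.h.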