From 35312cceb134c28c9fe1a53cb8d5c27f281c4054 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Thu, 28 May 2020 03:52:35 -0700 Subject: [PATCH] Remove workarounds for XLA's previous inf/nan behavior after it's been fixed. PiperOrigin-RevId: 313559788 Change-Id: I3d5fe3d7b7267d073ef45fe042503932d99b03cb --- tensorflow/compiler/tests/binary_ops_test.py | 5 ----- tensorflow/compiler/tests/unary_ops_test.py | 20 +++++++++---------- tensorflow/python/BUILD | 2 -- .../python/kernel_tests/numerics_test.py | 6 ------ tensorflow/python/ops/nn_test.py | 2 ++ 5 files changed, 11 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index c7be2c55de7..422695c374b 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import itertools -import os import numpy as np @@ -1609,8 +1608,4 @@ class BinaryOpsTest(xla_test.XLATestCase): if __name__ == "__main__": - # TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems - os.environ[ - "XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get( - "XLA_FLAGS", "") googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index d0e928a5ce6..85bf89c4f9e 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -347,17 +347,15 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype)) - # TODO(b/130689556): Turn this on for CPU when we start honoring NaNs. - if self.device != "XLA_CPU": - self._assertOpOutputMatchesExpected( - math_ops.tanh, - np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], - [19, -19, 22, -22]], - dtype=dtype), - expected=np.array( - [[0.76159418, 0.96402758, 0.99505478, 0.99932933], - [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], + [19, -19, 22, -22]], + dtype=dtype), + expected=np.array( + [[0.76159418, 0.96402758, 0.99505478, 0.99932933], + [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], + dtype=dtype)) self._assertOpOutputMatchesExpected( nn_ops.log_softmax, diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 13c58c74583..2fb22a89706 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5122,8 +5122,6 @@ cuda_py_test( srcs = ["ops/nn_test.py"], python_version = "PY3", tags = ["no_windows"], - # TODO(b/130689556): Numerical differences due to fast math on CPU. - xla_enable_strict_auto_jit = False, deps = [ ":array_ops", ":client_testlib", diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py index 475badb6efe..eadb8ceff07 100644 --- a/tensorflow/python/kernel_tests/numerics_test.py +++ b/tensorflow/python/kernel_tests/numerics_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - import numpy as np from tensorflow.python.framework import constant_op @@ -133,8 +131,4 @@ class NumericsTest(test.TestCase): if __name__ == "__main__": - # TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems - os.environ[ - "XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get( - "XLA_FLAGS", "") test.main() diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 0088c04f909..477e0528c0d 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -1207,6 +1207,7 @@ class DataFormatVectorPermuteTest(test_lib.TestCase): y_val = self.evaluate(y) self.assertAllEqual(y_val, [4, 9]) + @test_util.disable_xla("unsupported data format") def testNHWCToWHCN(self): x_val = [7, 4, 9, 3] x = constant_op.constant(x_val) @@ -1215,6 +1216,7 @@ class DataFormatVectorPermuteTest(test_lib.TestCase): y_val = self.evaluate(y) self.assertAllEqual(y_val, [9, 4, 3, 7]) + @test_util.disable_xla("unsupported data format") def testNHWCToWHCN_Size2(self): x_val = [4, 9] x = constant_op.constant(x_val)