Remove workarounds for XLA's previous inf/nan behavior after it's been fixed.

PiperOrigin-RevId: 313559788
Change-Id: I3d5fe3d7b7267d073ef45fe042503932d99b03cb
This commit is contained in:
Tres Popp 2020-05-28 03:52:35 -07:00 committed by TensorFlower Gardener
parent 2217251dfa
commit 35312cceb1
5 changed files with 11 additions and 24 deletions

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import itertools
import os
import numpy as np
@ -1609,8 +1608,4 @@ class BinaryOpsTest(xla_test.XLATestCase):
if __name__ == "__main__":
# TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems
os.environ[
"XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get(
"XLA_FLAGS", "")
googletest.main()

View File

@ -347,8 +347,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array(
[1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype))
# TODO(b/130689556): Turn this on for CPU when we start honoring NaNs.
if self.device != "XLA_CPU":
self._assertOpOutputMatchesExpected(
math_ops.tanh,
np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20],

View File

@ -5122,8 +5122,6 @@ cuda_py_test(
srcs = ["ops/nn_test.py"],
python_version = "PY3",
tags = ["no_windows"],
# TODO(b/130689556): Numerical differences due to fast math on CPU.
xla_enable_strict_auto_jit = False,
deps = [
":array_ops",
":client_testlib",

View File

@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python.framework import constant_op
@ -133,8 +131,4 @@ class NumericsTest(test.TestCase):
if __name__ == "__main__":
# TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems
os.environ[
"XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get(
"XLA_FLAGS", "")
test.main()

View File

@ -1207,6 +1207,7 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [4, 9])
@test_util.disable_xla("unsupported data format")
def testNHWCToWHCN(self):
x_val = [7, 4, 9, 3]
x = constant_op.constant(x_val)
@ -1215,6 +1216,7 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [9, 4, 3, 7])
@test_util.disable_xla("unsupported data format")
def testNHWCToWHCN_Size2(self):
x_val = [4, 9]
x = constant_op.constant(x_val)