Remove workarounds for XLA's previous inf/nan behavior after it's been fixed.
PiperOrigin-RevId: 313559788 Change-Id: I3d5fe3d7b7267d073ef45fe042503932d99b03cb
This commit is contained in:
parent
2217251dfa
commit
35312cceb1
|
@ -19,7 +19,6 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -1609,8 +1608,4 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems
|
||||
os.environ[
|
||||
"XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get(
|
||||
"XLA_FLAGS", "")
|
||||
googletest.main()
|
||||
|
|
|
@ -347,8 +347,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
|||
expected=np.array(
|
||||
[1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype))
|
||||
|
||||
# TODO(b/130689556): Turn this on for CPU when we start honoring NaNs.
|
||||
if self.device != "XLA_CPU":
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.tanh,
|
||||
np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20],
|
||||
|
|
|
@ -5122,8 +5122,6 @@ cuda_py_test(
|
|||
srcs = ["ops/nn_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = ["no_windows"],
|
||||
# TODO(b/130689556): Numerical differences due to fast math on CPU.
|
||||
xla_enable_strict_auto_jit = False,
|
||||
deps = [
|
||||
":array_ops",
|
||||
":client_testlib",
|
||||
|
|
|
@ -18,8 +18,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
|
@ -133,8 +131,4 @@ class NumericsTest(test.TestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems
|
||||
os.environ[
|
||||
"XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get(
|
||||
"XLA_FLAGS", "")
|
||||
test.main()
|
||||
|
|
|
@ -1207,6 +1207,7 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
|
|||
y_val = self.evaluate(y)
|
||||
self.assertAllEqual(y_val, [4, 9])
|
||||
|
||||
@test_util.disable_xla("unsupported data format")
|
||||
def testNHWCToWHCN(self):
|
||||
x_val = [7, 4, 9, 3]
|
||||
x = constant_op.constant(x_val)
|
||||
|
@ -1215,6 +1216,7 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
|
|||
y_val = self.evaluate(y)
|
||||
self.assertAllEqual(y_val, [9, 4, 3, 7])
|
||||
|
||||
@test_util.disable_xla("unsupported data format")
|
||||
def testNHWCToWHCN_Size2(self):
|
||||
x_val = [4, 9]
|
||||
x = constant_op.constant(x_val)
|
||||
|
|
Loading…
Reference in New Issue