Explicitly disable XLA for AMP test
This test checks for certain graph nodes to verify AMP correctness, but XLA changes the graph in ways that make these checks fail.
This commit is contained in:
parent
2c171cdb26
commit
e4715eb757
@ -339,6 +339,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_conv_bn(self):
|
||||
"""Test graph with convolution followed by batch norm."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
@ -361,6 +362,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_conv_bn_dropout(self):
|
||||
"""Test dropout precision of convolution batch norm graph."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
@ -387,6 +389,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_conv_pool(self):
|
||||
"""Test graph with convolution followed by pooling."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -407,6 +410,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_simple_loop(self):
|
||||
"""Test graph with while loop."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -425,6 +429,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_loop_with_vars_intertwined(self):
|
||||
"""Test graph with intertwined while loops."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -446,6 +451,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_multi_paths(self):
|
||||
"""Test graph with multiple paths."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -473,6 +479,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_multi_paths_2(self):
|
||||
"""Test graph with multiple paths."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -495,6 +502,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_recurrent_lstm(self):
|
||||
"""Test graph with recurrent lstm."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -520,34 +528,42 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_1(self):
|
||||
self._run_simple_loop_test('W', 'C', 'C')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_2(self):
|
||||
self._run_simple_loop_test('C', 'C', 'W')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_3(self):
|
||||
self._run_simple_loop_test('W', 'G', 'W')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_4(self):
|
||||
self._run_simple_loop_test('W', 'gbg', 'W')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_5(self):
|
||||
self._run_simple_loop_test('b', 'gWC', 'c')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_6(self):
|
||||
self._run_simple_loop_test('b', 'CWCG', 'C')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_7(self):
|
||||
self._run_simple_loop_test('C', 'GWCG', 'C')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("This test does not pass with XLA")
|
||||
def test_propagation_through_simple_loop_8(self):
|
||||
self._run_simple_loop_test('C', 'CgbgWC', 'g')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user