Merge pull request #29619 from bas-aarts:xla-amp-test-bug
PiperOrigin-RevId: 254148232
This commit is contained in:
commit
420473f1fb
@ -339,6 +339,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn(self):
|
||||
"""Test graph with convolution followed by batch norm."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
@ -361,6 +362,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn_dropout(self):
|
||||
"""Test dropout precision of convolution batch norm graph."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
@ -387,6 +389,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_pool(self):
|
||||
"""Test graph with convolution followed by pooling."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -407,6 +410,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_simple_loop(self):
|
||||
"""Test graph with while loop."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -425,6 +429,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_loop_with_vars_intertwined(self):
|
||||
"""Test graph with intertwined while loops."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -446,6 +451,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_multi_paths(self):
|
||||
"""Test graph with multiple paths."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -473,6 +479,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_multi_paths_2(self):
|
||||
"""Test graph with multiple paths."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -495,6 +502,7 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_recurrent_lstm(self):
|
||||
"""Test graph with recurrent lstm."""
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -520,34 +528,42 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_1(self):
|
||||
self._run_simple_loop_test('W', 'C', 'C')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_2(self):
|
||||
self._run_simple_loop_test('C', 'C', 'W')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_3(self):
|
||||
self._run_simple_loop_test('W', 'G', 'W')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_4(self):
|
||||
self._run_simple_loop_test('W', 'gbg', 'W')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_5(self):
|
||||
self._run_simple_loop_test('b', 'gWC', 'c')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_6(self):
|
||||
self._run_simple_loop_test('b', 'CWCG', 'C')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_7(self):
|
||||
self._run_simple_loop_test('C', 'GWCG', 'C')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_8(self):
|
||||
self._run_simple_loop_test('C', 'CgbgWC', 'g')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user