Merge pull request #29619 from bas-aarts:xla-amp-test-bug

PiperOrigin-RevId: 254148232
This commit is contained in:
TensorFlower Gardener 2019-06-20 00:54:17 -07:00
commit 420473f1fb

View File

@ -339,6 +339,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_conv_bn(self):
"""Test graph with convolution followed by batch norm."""
with compat.forward_compatibility_horizon(2019, 6, 7):
@ -361,6 +362,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_conv_bn_dropout(self):
"""Test dropout precision of convolution batch norm graph."""
with compat.forward_compatibility_horizon(2019, 6, 7):
@ -387,6 +389,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_conv_pool(self):
"""Test graph with convolution followed by pooling."""
if test.is_gpu_available(cuda_only=True):
@ -407,6 +410,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_simple_loop(self):
"""Test graph with while loop."""
if test.is_gpu_available(cuda_only=True):
@ -425,6 +429,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_loop_with_vars_intertwined(self):
"""Test graph with intertwined while loops."""
if test.is_gpu_available(cuda_only=True):
@ -446,6 +451,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_multi_paths(self):
"""Test graph with multiple paths."""
if test.is_gpu_available(cuda_only=True):
@ -473,6 +479,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_multi_paths_2(self):
"""Test graph with multiple paths."""
if test.is_gpu_available(cuda_only=True):
@ -495,6 +502,7 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_recurrent_lstm(self):
"""Test graph with recurrent lstm."""
if test.is_gpu_available(cuda_only=True):
@ -520,34 +528,42 @@ class AutoMixedPrecisionTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_1(self):
self._run_simple_loop_test('W', 'C', 'C')
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_2(self):
self._run_simple_loop_test('C', 'C', 'W')
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_3(self):
self._run_simple_loop_test('W', 'G', 'W')
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_4(self):
self._run_simple_loop_test('W', 'gbg', 'W')
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_5(self):
self._run_simple_loop_test('b', 'gWC', 'c')
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_6(self):
self._run_simple_loop_test('b', 'CWCG', 'C')
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_7(self):
self._run_simple_loop_test('C', 'GWCG', 'C')
@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_8(self):
self._run_simple_loop_test('C', 'CgbgWC', 'g')