From e4715eb7575481034e715fa946af2a193816e48f Mon Sep 17 00:00:00 2001 From: Bas Aarts Date: Mon, 10 Jun 2019 15:56:53 -0700 Subject: [PATCH] Explicitly disable XLA for AMP test This test checks for certain graph nodes to verify AMP correctness, but XLA changes the graph in ways that make these checks fail. --- .../python/grappler/auto_mixed_precision_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py index 4df38c6f9c3..9603f54cc7a 100644 --- a/tensorflow/python/grappler/auto_mixed_precision_test.py +++ b/tensorflow/python/grappler/auto_mixed_precision_test.py @@ -339,6 +339,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_conv_bn(self): """Test graph with convolution followed by batch norm.""" with compat.forward_compatibility_horizon(2019, 6, 7): @@ -361,6 +362,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_conv_bn_dropout(self): """Test dropout precision of convolution batch norm graph.""" with compat.forward_compatibility_horizon(2019, 6, 7): @@ -387,6 +389,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_conv_pool(self): """Test graph with convolution followed by pooling.""" if test.is_gpu_available(cuda_only=True): @@ -407,6 +410,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_simple_loop(self): """Test graph with while loop.""" if test.is_gpu_available(cuda_only=True): @@ -425,6 +429,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_loop_with_vars_intertwined(self): """Test graph with intertwined while loops.""" if test.is_gpu_available(cuda_only=True): @@ -446,6 +451,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_multi_paths(self): """Test graph with multiple paths.""" if test.is_gpu_available(cuda_only=True): @@ -473,6 +479,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_multi_paths_2(self): """Test graph with multiple paths.""" if test.is_gpu_available(cuda_only=True): @@ -495,6 +502,7 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_recurrent_lstm(self): """Test graph with recurrent lstm.""" if test.is_gpu_available(cuda_only=True): @@ -520,34 +528,42 @@ class AutoMixedPrecisionTest(test.TestCase): self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_1(self): self._run_simple_loop_test('W', 'C', 'C') @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_2(self): self._run_simple_loop_test('C', 'C', 'W') @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_3(self): self._run_simple_loop_test('W', 'G', 'W') @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_4(self): self._run_simple_loop_test('W', 'gbg', 'W') @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_5(self): self._run_simple_loop_test('b', 'gWC', 'c') @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_6(self): self._run_simple_loop_test('b', 'CWCG', 'C') @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_7(self): self._run_simple_loop_test('C', 'GWCG', 'C') @test_util.run_deprecated_v1 + @test_util.disable_xla("This test does not pass with XLA") def test_propagation_through_simple_loop_8(self): self._run_simple_loop_test('C', 'CgbgWC', 'g')