From ed39014cf6c7e0fcd7a08ce445a52ec27949c251 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Thu, 21 May 2020 15:56:56 -0700 Subject: [PATCH] Don't disable all TPU tests just the ones that fail. * Skips a test that segfaults sometimes when run on TPUs. * Skips a test on TPU that fails with a different error message. PiperOrigin-RevId: 312757787 Change-Id: I662c28c55a9f3f907c7f6a8f217506bb17c3a8c7 --- tensorflow/python/keras/distribute/BUILD | 1 - .../python/keras/distribute/distribute_strategy_test.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 50ed6086195..b7fe3b5bda6 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -128,7 +128,6 @@ distribute_py_test( "multi_and_single_gpu", "no_rocm", # times out on ROCm "no_windows_gpu", - "notpu", # TODO(b/155867206) flaky segfault "notsan", ], tpu_tags = [ diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index f6a83c499fe..eac1e2feb8b 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -575,8 +575,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, @combinations.generate( combinations.combine( - distribution=[strategy_combinations.one_device_strategy] + - tpu_strategies, + distribution=[strategy_combinations.one_device_strategy], mode=['graph', 'eager'])) def test_optimizer_in_cross_replica_context_raises_error(self, distribution): @@ -1070,6 +1069,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_on_dataset_with_unknown_cardinality_without_steps( self, distribution, mode): + # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU + # in eager mode. + if mode == 'eager' and isinstance( + distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + self.skipTest('caused segfault with TPU in eager mode.') if mode == 'graph' and isinstance( distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):