diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 50ed6086195..b7fe3b5bda6 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -128,7 +128,6 @@ distribute_py_test( "multi_and_single_gpu", "no_rocm", # times out on ROCm "no_windows_gpu", - "notpu", # TODO(b/155867206) flaky segfault "notsan", ], tpu_tags = [ diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index f6a83c499fe..eac1e2feb8b 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -575,8 +575,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, @combinations.generate( combinations.combine( - distribution=[strategy_combinations.one_device_strategy] + - tpu_strategies, + distribution=[strategy_combinations.one_device_strategy], mode=['graph', 'eager'])) def test_optimizer_in_cross_replica_context_raises_error(self, distribution): @@ -1070,6 +1069,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_on_dataset_with_unknown_cardinality_without_steps( self, distribution, mode): + # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU + # in eager mode. + if mode == 'eager' and isinstance( + distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + self.skipTest('caused segfault with TPU in eager mode.') if mode == 'graph' and isinstance( distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):