Don't disable all TPU tests just the ones that fail.

* Skips a test that segfaults sometimes when run on TPUs.
* Skips a test on TPU that fails with a different error message.

PiperOrigin-RevId: 312757787
Change-Id: I662c28c55a9f3f907c7f6a8f217506bb17c3a8c7
This commit is contained in:
Ken Franko 2020-05-21 15:56:56 -07:00 committed by TensorFlower Gardener
parent 7221ad6eda
commit ed39014cf6
2 changed files with 6 additions and 3 deletions

View File

@ -128,7 +128,6 @@ distribute_py_test(
"multi_and_single_gpu",
"no_rocm", # times out on ROCm
"no_windows_gpu",
"notpu", # TODO(b/155867206) flaky segfault
"notsan",
],
tpu_tags = [

View File

@ -575,8 +575,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
@combinations.generate(
combinations.combine(
distribution=[strategy_combinations.one_device_strategy] +
tpu_strategies,
distribution=[strategy_combinations.one_device_strategy],
mode=['graph', 'eager']))
def test_optimizer_in_cross_replica_context_raises_error(self, distribution):
@ -1070,6 +1069,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
@combinations.generate(all_strategy_combinations())
def test_on_dataset_with_unknown_cardinality_without_steps(
self, distribution, mode):
# TODO(b/155867206): Investigate why this test occasionally segfaults on TPU
# in eager mode.
if mode == 'eager' and isinstance(
distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
self.skipTest('caused segfault with TPU in eager mode.')
if mode == 'graph' and isinstance(
distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):