Remove skip condition in ctl_correctness_test for MWMS + XLA. It is already disabled at the strategy combination level.
PiperOrigin-RevId: 343536440 Change-Id: If4ace40041481135b69c3242c18270c65044fcb3
This commit is contained in:
parent
9548966668
commit
6271251b7f
@ -33,7 +33,6 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_combinations as combinations
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.distribute import optimizer_combinations
|
||||
from tensorflow.python.keras.distribute import strategy_combinations
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -251,9 +250,6 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
|
||||
# TODO(anjs): Identify why this particular V1 optimizer needs a higher tol.
|
||||
if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__:
|
||||
self.skipTest('Reduced tolerance of the order of 1e-1 required.')
|
||||
if ('CollectiveAllReduce' in type(distribution).__name__ and
|
||||
test_util.is_xla_enabled()):
|
||||
self.skipTest('XLA tests fail with MWMS.')
|
||||
self.dnn_correctness(distribution, optimizer_fn, iteration_type,
|
||||
inside_func, sync_batchnorm)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user