Keras + Distribution Strategy: Use a slightly lower tolerance for correctness test for default distribution strategy as the weights have more variance whenever the training is run on GPUs.
PiperOrigin-RevId: 226117010
This commit is contained in:
parent
9a63f5b843
commit
627fd023a0
@ -1317,17 +1317,14 @@ class TestDistributionStrategyCorrectness(test.TestCase,
|
||||
|
||||
@combinations.generate(strategy_and_input_combinations())
|
||||
def test_correctness(self, distribution, use_numpy, use_validation_data):
|
||||
# TODO(b/121224478): This test is flaky with default strategy. Remove this
|
||||
# once the issue is fixed.
|
||||
if isinstance(distribution, distribute_lib._DefaultDistributionStrategy): # pylint: disable=protected-access
|
||||
self.skipTest('Disable the test for default strategy.')
|
||||
|
||||
with self.cached_session():
|
||||
default_tolerance = 1e-5
|
||||
tol_table = {}
|
||||
|
||||
if isinstance(distribution, (mirrored_strategy.MirroredStrategy,
|
||||
mirrored_strategy.CoreMirroredStrategy)):
|
||||
if isinstance(distribution, (
|
||||
mirrored_strategy.MirroredStrategy,
|
||||
mirrored_strategy.CoreMirroredStrategy,
|
||||
distribute_lib._DefaultDistributionStrategy)): # pylint: disable=protected-access
|
||||
# TODO(b/119257215): Weights are not exactly the same, so use larger
|
||||
# tolerance for now. Predict should be related to weights.
|
||||
tol_table = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user