Enable TPU Strategy combination for ctl_correctness_test.

PiperOrigin-RevId: 292556586
Change-Id: Ie421a5ae2869197d91557988c95afe587c1249de
Author: Ken Franko
Date:   2020-01-31 09:12:02 -08:00
Committed by: TensorFlower Gardener
Parent: 9a565656e2
Commit: cbf2288d95

3 changed files with 8 additions and 22 deletions


@@ -219,7 +219,7 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu,
+          distribution=strategy_combinations.all_strategies,
           optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
           mode=['eager'],
           iteration_type=['iterator', 'dataset'],
@@ -230,21 +230,6 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
     self.dnn_correctness(distribution, optimizer_fn, iteration_type,
                          inside_func)
 
-  # TODO(b/133325470): Enable this test for all optimizers once we understand
-  # the root cause of flakiness.
-  @combinations.generate(
-      combinations.combine(
-          distribution=[strategy_combinations.tpu_strategy_one_step],
-          optimizer_fn=[strategy_combinations.adagrad_optimizer_keras_v2_fn],
-          mode=['eager'],
-          iteration_type=['iterator', 'dataset'],
-          inside_func=[False, True]
-      ))
-  def test_dnn_correctness_tpus(self, distribution, optimizer_fn,
-                                iteration_type, inside_func):
-    self.dnn_correctness(distribution, optimizer_fn, iteration_type,
-                         inside_func)
-
   def dnn_correctness(self, distribution, optimizer_fn, iteration_type,
                       inside_func):
     model = get_model()
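
With the TPU-only test deleted, TPU coverage now comes from the single
all_strategies combination above. As a rough illustration of how
combinations.generate expands such a matrix into one generated case per
named strategy, here is a minimal sketch (not part of this commit; the
test class and assertion are hypothetical, and it assumes the TF-internal
combinations API used in the diff):

    from absl.testing import parameterized

    from tensorflow.python.distribute import combinations
    from tensorflow.python.distribute import strategy_combinations
    from tensorflow.python.platform import test


    class AllStrategiesSmokeTest(test.TestCase, parameterized.TestCase):
      """Hypothetical smoke test: one generated case per named strategy."""

      @combinations.generate(
          combinations.combine(
              distribution=strategy_combinations.all_strategies,
              mode=['eager']))
      def test_strategy_scope(self, distribution):
        # combinations.generate resolves each NamedDistribution and passes
        # the underlying tf.distribute.Strategy, TPU strategies included.
        with distribution.scope():
          self.assertGreaterEqual(distribution.num_replicas_in_sync, 1)


    if __name__ == '__main__':
      test.main()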


@@ -396,17 +396,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       # predict = [4, 14]
       # predict - y = [-2, -7]
       # dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106
-      # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2
-      # with sum loss reduction, or 10.6 with mean.
+      # So unreplicated the update to w with lr=0.001 is -0.001 * -106 = 0.106
+      # with sum loss reduction, or 0.053 with mean.
       if loss_reduction == losses_impl.Reduction.SUM:
         # Note that the "distribution.num_replicas_in_sync" factor will go away
         # once we split the input across replicas, instead of pulling a complete
         # batch of input per replica.
-        self.assertNear(weight, 2 + 21.2 * distribution.num_replicas_in_sync,
+        self.assertNear(weight, 2 + 0.106 * distribution.num_replicas_in_sync,
                         0.0001)
       else:
         # One of the mean loss reductions.
-        self.assertNear(weight, 2 + 10.6, 0.0001)
+        self.assertNear(weight, 2 + 0.053, 0.0001)
 
   @combinations.generate(
       combinations.times(
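
The updated expectations follow directly from the arithmetic in the comment;
as a standalone check, this plain-Python sketch (hypothetical, merely
re-deriving the numbers documented in the test) reproduces both values:

    # One SGD step on w, using the x, predict, and y from the comments above.
    x = [2.0, 7.0]
    error = [-2.0, -7.0]  # predict - y = [4, 14] - [6, 21]
    dloss_dw = 2 * sum(xi * ei for xi, ei in zip(x, error))  # 2 * -53 = -106
    lr = 0.001
    print(-lr * dloss_dw)            # 0.106: per-replica step, SUM reduction
    print(-lr * dloss_dw / len(x))   # 0.053: step under MEAN reduction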


@@ -150,7 +150,8 @@ central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
     required_gpus=1)
 
 gradient_descent_optimizer_v1_fn = combinations.NamedObject(
-    "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
+    "GradientDescentV1",
+    lambda: gradient_descent.GradientDescentOptimizer(0.001))
 adagrad_optimizer_v1_fn = combinations.NamedObject(
     "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
 adam_optimizer_v1_fn = combinations.NamedObject(
@@ -179,7 +180,7 @@ nadam_optimizer_keras_v2_fn = combinations.NamedObject(
 ftrl_optimizer_keras_v2_fn = combinations.NamedObject(
     "FtrlKerasV2", lambda: ftrl_keras_v2.Ftrl(0.001))
 gradient_descent_optimizer_keras_v2_fn = combinations.NamedObject(
-    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.2))
+    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.001))
 rmsprop_optimizer_keras_v2_fn = combinations.NamedObject(
     "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001))