Enable TPU Strategy combination for ctl_correctness_test.
PiperOrigin-RevId: 292556586
Change-Id: Ie421a5ae2869197d91557988c95afe587c1249de
parent 9a565656e2
commit cbf2288d95
@@ -219,7 +219,7 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu,
+          distribution=strategy_combinations.all_strategies,
           optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
           mode=['eager'],
           iteration_type=['iterator', 'dataset'],
@@ -230,21 +230,6 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
     self.dnn_correctness(distribution, optimizer_fn, iteration_type,
                          inside_func)
 
-  # TODO(b/133325470): Enable this test for all optimizers once we understand
-  # the root cause of flakiness.
-  @combinations.generate(
-      combinations.combine(
-          distribution=[strategy_combinations.tpu_strategy_one_step],
-          optimizer_fn=[strategy_combinations.adagrad_optimizer_keras_v2_fn],
-          mode=['eager'],
-          iteration_type=['iterator', 'dataset'],
-          inside_func=[False, True]
-      ))
-  def test_dnn_correctness_tpus(self, distribution, optimizer_fn,
-                                iteration_type, inside_func):
-    self.dnn_correctness(distribution, optimizer_fn, iteration_type,
-                         inside_func)
-
   def dnn_correctness(self, distribution, optimizer_fn, iteration_type,
                       inside_func):
     model = get_model()
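For readers unfamiliar with the combinations machinery: combinations.combine takes keyword lists and yields one parametrization per entry of their Cartesian product, which @combinations.generate then turns into individual test cases. Switching the distribution axis from strategies_minus_tpu to all_strategies folds the TPU strategies into that same product, which is why the separate test_dnn_correctness_tpus case above can be deleted. A minimal self-contained sketch of the expansion (my_combine is a hypothetical stand-in, not the TensorFlow API):

import itertools

def my_combine(**kwargs):
    # Hypothetical stand-in for combinations.combine: expand keyword lists
    # into one dict per entry of their Cartesian product.
    keys = sorted(kwargs)
    return [dict(zip(keys, values))
            for values in itertools.product(*(kwargs[k] for k in keys))]

# Mirrors two axes from the diff above: 2 iteration types x 2 inside_func
# values yield 4 parametrizations per (distribution, optimizer_fn) pair.
for params in my_combine(iteration_type=['iterator', 'dataset'],
                         inside_func=[False, True]):
    print(params)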
@@ -396,17 +396,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       # predict = [4, 14]
       # predict - y = [-2, -7]
       # dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106
-      # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2
-      # with sum loss reduction, or 10.6 with mean.
+      # So unreplicated the update to w with lr=0.001 is -0.001 * -106 = 0.106
+      # with sum loss reduction, or 0.053 with mean.
       if loss_reduction == losses_impl.Reduction.SUM:
         # Note that the "distribution.num_replicas_in_sync" factor will go away
         # once we split the input across replicas, instead of pulling a complete
         # batch of input per replica.
-        self.assertNear(weight, 2 + 21.2 * distribution.num_replicas_in_sync,
+        self.assertNear(weight, 2 + 0.106 * distribution.num_replicas_in_sync,
                         0.0001)
       else:
         # One of the mean loss reductions.
-        self.assertNear(weight, 2 + 10.6, 0.0001)
+        self.assertNear(weight, 2 + 0.053, 0.0001)
 
   @combinations.generate(
       combinations.times(
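The updated comment arithmetic checks out: with w = 2, inputs [2, 7], and residuals [-2, -7], the summed-loss gradient is -106, so an lr=0.001 step moves w by +0.106 (or +0.053 under mean reduction). A short NumPy sketch verifying the numbers; the x and y values are reconstructed from the comments, not taken from the test source:

import numpy as np

# Assumed setup recovered from the comments: w = 2, inputs x = [2, 7],
# targets y chosen so that predict - y = [-2, -7].
w = 2.0
x = np.array([2.0, 7.0])
y = np.array([6.0, 21.0])

predict = w * x                      # [4, 14]
err = predict - y                    # [-2, -7]
grad_sum = 2.0 * np.dot(x, err)      # 2 * (2*-2 + 7*-7) = -106
grad_mean = grad_sum / len(x)        # -53 with mean reduction

lr = 0.001
print(w - lr * grad_sum)             # 2.106 -> update of 0.106 (sum)
print(w - lr * grad_mean)            # 2.053 -> update of 0.053 (mean)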
@@ -150,7 +150,8 @@ central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
     required_gpus=1)
 
 gradient_descent_optimizer_v1_fn = combinations.NamedObject(
-    "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
+    "GradientDescentV1",
+    lambda: gradient_descent.GradientDescentOptimizer(0.001))
 adagrad_optimizer_v1_fn = combinations.NamedObject(
     "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
 adam_optimizer_v1_fn = combinations.NamedObject(

@@ -179,7 +180,7 @@ nadam_optimizer_keras_v2_fn = combinations.NamedObject(
 ftrl_optimizer_keras_v2_fn = combinations.NamedObject(
     "FtrlKerasV2", lambda: ftrl_keras_v2.Ftrl(0.001))
 gradient_descent_optimizer_keras_v2_fn = combinations.NamedObject(
-    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.2))
+    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.001))
 rmsprop_optimizer_keras_v2_fn = combinations.NamedObject(
     "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001))
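The learning-rate change from 0.2 to 0.001 applies to both the V1 and Keras V2 gradient-descent entries, which keeps them in line with the other optimizers here and matches the 0.106/0.053 expectations above. Each optimizer is wrapped in a lambda so every generated test case builds a fresh instance instead of sharing slot state. A minimal sketch of that pattern; NamedCallable is illustrative, not the real combinations.NamedObject:

class NamedCallable(object):
    # Illustrative stand-in for a NamedObject-style wrapper: pairs a readable
    # test-case name with a factory that builds a fresh object on each call.
    def __init__(self, name, factory):
        self.name = name
        self._factory = factory

    def __call__(self):
        return self._factory()

# Each invocation returns a new optimizer-like object, so per-test state
# (e.g. accumulated slots) never leaks between parametrized test cases.
sgd_fn = NamedCallable("GradientDescentV1", lambda: {"lr": 0.001, "slots": {}})
assert sgd_fn() is not sgd_fn()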