Enable TPU Strategy combination for ctl_correctness_test.

PiperOrigin-RevId: 292556586
Change-Id: Ie421a5ae2869197d91557988c95afe587c1249de
commit cbf2288d95
parent 9a565656e2
@@ -219,7 +219,7 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu,
+          distribution=strategy_combinations.all_strategies,
           optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
           mode=['eager'],
           iteration_type=['iterator', 'dataset'],
@@ -230,21 +230,6 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
     self.dnn_correctness(distribution, optimizer_fn, iteration_type,
                          inside_func)
 
-  # TODO(b/133325470): Enable this test for all optimizers once we understand
-  # the root cause of flakiness.
-  @combinations.generate(
-      combinations.combine(
-          distribution=[strategy_combinations.tpu_strategy_one_step],
-          optimizer_fn=[strategy_combinations.adagrad_optimizer_keras_v2_fn],
-          mode=['eager'],
-          iteration_type=['iterator', 'dataset'],
-          inside_func=[False, True]
-      ))
-  def test_dnn_correctness_tpus(self, distribution, optimizer_fn,
-                                iteration_type, inside_func):
-    self.dnn_correctness(distribution, optimizer_fn, iteration_type,
-                         inside_func)
-
   def dnn_correctness(self, distribution, optimizer_fn, iteration_type,
                       inside_func):
     model = get_model()
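For context on the hunk above: `combinations.combine` builds the Cartesian product of its keyword arguments, and `@combinations.generate` emits one test case per element of that product, so switching from `strategies_minus_tpu` to `all_strategies` widens the product to include the TPU strategies and makes the separate TPU-only test below redundant. A standalone sketch of the expansion, using plain itertools with illustrative stand-in values; this mirrors the idea only, not the TensorFlow implementation:

# Sketch of what combinations.combine produces: the Cartesian product of
# every keyword argument, one parameter dict per generated test case.
import itertools


def combine(**kwargs):
  """Yield one parameter dict per element of the Cartesian product."""
  keys = sorted(kwargs)
  for values in itertools.product(*(kwargs[k] for k in keys)):
    yield dict(zip(keys, values))


# Illustrative stand-ins for the named strategies/optimizers in the hunk.
for case in combine(
    distribution=['tpu_strategy', 'mirrored_strategy'],
    optimizer_fn=['AdagradKerasV2', 'GradientDescentKerasV2'],
    mode=['eager'],
    iteration_type=['iterator', 'dataset']):
  print(case)  # e.g. {'distribution': 'tpu_strategy', 'iteration_type': ...}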
@@ -396,17 +396,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       # predict = [4, 14]
       # predict - y = [-2, -7]
       # dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106
-      # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2
-      # with sum loss reduction, or 10.6 with mean.
+      # So unreplicated the update to w with lr=0.001 is -0.001 * -106 = 0.106
+      # with sum loss reduction, or 0.053 with mean.
       if loss_reduction == losses_impl.Reduction.SUM:
         # Note that the "distribution.num_replicas_in_sync" factor will go away
         # once we split the input across replicas, instead of pulling a complete
         # batch of input per replica.
-        self.assertNear(weight, 2 + 21.2 * distribution.num_replicas_in_sync,
+        self.assertNear(weight, 2 + 0.106 * distribution.num_replicas_in_sync,
                         0.0001)
       else:
         # One of the mean loss reductions.
-        self.assertNear(weight, 2 + 10.6, 0.0001)
+        self.assertNear(weight, 2 + 0.053, 0.0001)
 
   @combinations.generate(
       combinations.times(
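The expected weights in this hunk follow directly from the arithmetic spelled out in the comments. A standalone check of those numbers, assuming the linear model predict = w * x with x = [2, 7], y = [6, 21], initial w = 2, squared-error loss, and lr = 0.001; the concrete inputs are inferred from the comments, not copied from the test:

# Recompute the expected weight updates from the comments above.
x = [2.0, 7.0]
y = [6.0, 21.0]
w = 2.0
lr = 0.001

predict = [w * xi for xi in x]               # [4, 14]
err = [p - yi for p, yi in zip(predict, y)]  # [-2, -7]

# d/dw sum((w*x - y)^2) = sum(2 * err * x) = 2 * (-4 - 49) = -106
grad_sum = sum(2.0 * e * xi for e, xi in zip(err, x))
assert grad_sum == -106.0

# SUM reduction: w - lr * grad = 2 + 0.106, matching the first assertNear
# (times num_replicas_in_sync while each replica pulls a full batch).
assert abs((w - lr * grad_sum) - 2.106) < 1e-9
# MEAN reduction over the batch of 2 halves the gradient: 2 + 0.053.
assert abs((w - lr * grad_sum / len(x)) - 2.053) < 1e-9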
@@ -150,7 +150,8 @@ central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
     required_gpus=1)
 
 gradient_descent_optimizer_v1_fn = combinations.NamedObject(
-    "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
+    "GradientDescentV1",
+    lambda: gradient_descent.GradientDescentOptimizer(0.001))
 adagrad_optimizer_v1_fn = combinations.NamedObject(
     "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
 adam_optimizer_v1_fn = combinations.NamedObject(
@@ -179,7 +180,7 @@ nadam_optimizer_keras_v2_fn = combinations.NamedObject(
 ftrl_optimizer_keras_v2_fn = combinations.NamedObject(
     "FtrlKerasV2", lambda: ftrl_keras_v2.Ftrl(0.001))
 gradient_descent_optimizer_keras_v2_fn = combinations.NamedObject(
-    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.2))
+    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.001))
 rmsprop_optimizer_keras_v2_fn = combinations.NamedObject(
     "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001))
 
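These two hunks lower the SGD learning rate from 0.2 to 0.001, bringing it in line with the other optimizers in the list, presumably so the correctness tolerances above still hold once TPU strategies join the shared combinations. `combinations.NamedObject` pairs a factory with a human-readable name so the generated test IDs stay legible. A rough, hypothetical simplification of that wrapper's behavior, illustrative only and not the TensorFlow class:

# Hypothetical simplification of a NamedObject-style wrapper: it carries a
# display name for test IDs and defers construction until the test runs.
class NamedObject(object):

  def __init__(self, name, obj_factory):
    self._name = name
    self._obj_factory = obj_factory

  def __call__(self):
    # Build a fresh object per test case so no state is shared across tests.
    return self._obj_factory()

  def __repr__(self):
    return self._name


# Stand-in factory; the real entries construct Keras/V1 optimizers.
gradient_descent_fn = NamedObject(
    "GradientDescentKerasV2", lambda: ("SGD", 0.001))
print(gradient_descent_fn, gradient_descent_fn())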