Small test fix: avoid creating optimizer objects inside the fn passed to strategy.run, since doing so ends up creating a separate optimizer object for each replica in MirroredStrategy.

PiperOrigin-RevId: 344857919
Change-Id: Iabf8d42143da6683612c8d26f172493c9e24f72c
Authored by Chenkai Kuang on 2020-11-30 11:51:15 -08:00; committed by TensorFlower Gardener
parent 049dfcdfa4
commit ff705f007e


@@ -253,6 +253,10 @@ class KerasLayerTest(keras_parameterized.TestCase):
     with strategy_fn().scope() as strategy:
       with policy.policy_scope('mixed_float16'):
         layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
+        # Learning rate is small enough that if applied to a float16 variable,
+        # the variable will not change. So this tests the learning rate is not
+        # applied to a float16 value, but instead the float32 variable.
+        opt = gradient_descent.SGD(2**-14)
         def run_fn():
           with backprop.GradientTape() as tape:
@@ -261,10 +265,6 @@ class KerasLayerTest(keras_parameterized.TestCase):
             # sum of each of the replica's losses.
             y /= strategy.num_replicas_in_sync
-          # Learning rate is small enough that if applied to a float16 variable,
-          # the variable will not change. So this tests the learning rate is not
-          # applied to a float16 value, but instead the float32 variable.
-          opt = gradient_descent.SGD(2**-14)
           grad = tape.gradient(y, layer.v)
           return opt.apply_gradients([(grad, layer.v)])
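
For reference, here is a minimal, self-contained sketch of the pattern the test now follows: the optimizer is built once under the strategy scope, outside the per-replica function, so all MirroredStrategy replicas share a single optimizer object. It uses the public tf.* API rather than the Keras-internal modules shown in the diff, and the variable, input, and run_fn names are illustrative, not taken from the test.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
x = tf.constant([1.])

with strategy.scope():
  v = tf.Variable(1.)
  # Built once, outside run_fn, so every replica updates this single optimizer.
  opt = tf.keras.optimizers.SGD(2**-14)

def run_fn():
  # Creating the optimizer here instead would construct a separate SGD object
  # per replica, which is what the commit above avoids.
  with tf.GradientTape() as tape:
    y = v * tf.reduce_sum(x)
    # Divide by the replica count, as the per-replica gradients are summed.
    y /= strategy.num_replicas_in_sync
  grad = tape.gradient(y, v)
  return opt.apply_gradients([(grad, v)])

strategy.run(run_fn)

The real test applies the same idea with the Keras-internal gradient_descent and backprop modules shown in the hunks above.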