In `assign_moving_average`, call `update_fn` instead of `strategy.extended.update(var, update_fn)` when in update context.
This fixes an issue where variables may be updated #num_replica times when `assign_moving_average` is called in update context. See the unit test for an example. PiperOrigin-RevId: 315804133 Change-Id: If4b31ef15fcf5d1bb93ec53fa6ff6c234a13626a
This commit is contained in:
parent
a745f0a953
commit
be9e308f1f
|
@ -149,6 +149,24 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
|
|||
(2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
|
||||
var.eval())
|
||||
|
||||
@combinations.generate(all_combinations_eager)
|
||||
def testUpdateContext(self, distribution, use_function):
|
||||
with distribution.scope():
|
||||
var1 = variables.Variable([0.0, 0.0])
|
||||
var2 = variables.Variable([0.0, 0.0])
|
||||
var3 = variables.Variable([0.0, 0.0])
|
||||
|
||||
def update_fn(v, value):
|
||||
v.assign_add(value)
|
||||
moving_averages.assign_moving_average(var2, [2.0, 4.0], decay=0.25)
|
||||
moving_averages.assign_moving_average(
|
||||
var3, [2.0, 4.0], decay=0.25, zero_debias=False)
|
||||
|
||||
distribution.extended.update(var1, update_fn, ([1.0, 1.0],))
|
||||
|
||||
self.assertAllClose([2.0, 4.0], var2.read_value())
|
||||
self.assertAllClose([1.5, 3.0], var3.read_value())
|
||||
|
||||
@combinations.generate(all_combinations)
|
||||
def testAssignVariable(self, distribution):
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import reduce_util as ds_reduce_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -83,7 +84,6 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
|
|||
[Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
|
||||
([pdf](https://arxiv.org/pdf/1412.6980.pdf))
|
||||
"""
|
||||
|
||||
with ops.name_scope(name, "AssignMovingAvg",
|
||||
[variable, value, decay]) as scope:
|
||||
decay = ops.convert_to_tensor(1.0 - decay, name="decay")
|
||||
|
@ -97,7 +97,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
|
|||
if zero_debias:
|
||||
return _zero_debias(strategy, v, value, decay)
|
||||
else:
|
||||
return strategy.extended.update(v, update_fn, args=(value,))
|
||||
return _update(strategy, v, update_fn, args=(value,))
|
||||
|
||||
replica_context = distribution_strategy_context.get_replica_context()
|
||||
if replica_context:
|
||||
|
@ -178,6 +178,20 @@ def weighted_moving_average(value,
|
|||
return math_ops.divide(numerator, denominator, name=scope.name)
|
||||
|
||||
|
||||
def _update(strategy, var, update_fn, args):
|
||||
"""Applies updates depending on the context."""
|
||||
assert distribution_strategy_context.in_cross_replica_context(), (
|
||||
"_update can only be called in cross-replica context")
|
||||
if distribute_lib.get_update_replica_id() is not None:
|
||||
# Call update_fn on var to delegate the implementation. We expect `var` will
|
||||
# do the right thing in update context, e.g, if `var` is a MirroredVariable,
|
||||
# it should pick its component variable based on `update_replica_id` and
|
||||
# only update that.
|
||||
return update_fn(var, *args)
|
||||
else:
|
||||
return strategy.extended.update(var, update_fn, args)
|
||||
|
||||
|
||||
def _zero_debias(strategy, unbiased_var, value, decay):
|
||||
"""Compute the delta required for a debiased Variable.
|
||||
|
||||
|
@ -263,8 +277,8 @@ def _zero_debias(strategy, unbiased_var, value, decay):
|
|||
return state_ops.assign(
|
||||
v, update_biased / bias_factor, name=ops.get_name_scope() + "/")
|
||||
|
||||
return strategy.extended.update(
|
||||
unbiased_var, update_fn, args=(value, biased_var, local_step))
|
||||
return _update(
|
||||
strategy, unbiased_var, update_fn, args=(value, biased_var, local_step))
|
||||
|
||||
|
||||
@tf_export("train.ExponentialMovingAverage")
|
||||
|
|
Loading…
Reference in New Issue