In `assign_moving_average`, call `update_fn` instead of `strategy.extended.update(var, update_fn)` when in update context.
This fixes an issue where variables may be updated #num_replica times when `assign_moving_average` is called in update context. See the unit test for an example. PiperOrigin-RevId: 315804133 Change-Id: If4b31ef15fcf5d1bb93ec53fa6ff6c234a13626a
This commit is contained in:
parent
a745f0a953
commit
be9e308f1f
|
@ -149,6 +149,24 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
|
||||||
(2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
|
(2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
|
||||||
var.eval())
|
var.eval())
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations_eager)
|
||||||
|
def testUpdateContext(self, distribution, use_function):
|
||||||
|
with distribution.scope():
|
||||||
|
var1 = variables.Variable([0.0, 0.0])
|
||||||
|
var2 = variables.Variable([0.0, 0.0])
|
||||||
|
var3 = variables.Variable([0.0, 0.0])
|
||||||
|
|
||||||
|
def update_fn(v, value):
|
||||||
|
v.assign_add(value)
|
||||||
|
moving_averages.assign_moving_average(var2, [2.0, 4.0], decay=0.25)
|
||||||
|
moving_averages.assign_moving_average(
|
||||||
|
var3, [2.0, 4.0], decay=0.25, zero_debias=False)
|
||||||
|
|
||||||
|
distribution.extended.update(var1, update_fn, ([1.0, 1.0],))
|
||||||
|
|
||||||
|
self.assertAllClose([2.0, 4.0], var2.read_value())
|
||||||
|
self.assertAllClose([1.5, 3.0], var3.read_value())
|
||||||
|
|
||||||
@combinations.generate(all_combinations)
|
@combinations.generate(all_combinations)
|
||||||
def testAssignVariable(self, distribution):
|
def testAssignVariable(self, distribution):
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.distribute import distribute_lib
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import reduce_util as ds_reduce_util
|
from tensorflow.python.distribute import reduce_util as ds_reduce_util
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
@ -83,7 +84,6 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
|
||||||
[Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
|
[Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
|
||||||
([pdf](https://arxiv.org/pdf/1412.6980.pdf))
|
([pdf](https://arxiv.org/pdf/1412.6980.pdf))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with ops.name_scope(name, "AssignMovingAvg",
|
with ops.name_scope(name, "AssignMovingAvg",
|
||||||
[variable, value, decay]) as scope:
|
[variable, value, decay]) as scope:
|
||||||
decay = ops.convert_to_tensor(1.0 - decay, name="decay")
|
decay = ops.convert_to_tensor(1.0 - decay, name="decay")
|
||||||
|
@ -97,7 +97,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
|
||||||
if zero_debias:
|
if zero_debias:
|
||||||
return _zero_debias(strategy, v, value, decay)
|
return _zero_debias(strategy, v, value, decay)
|
||||||
else:
|
else:
|
||||||
return strategy.extended.update(v, update_fn, args=(value,))
|
return _update(strategy, v, update_fn, args=(value,))
|
||||||
|
|
||||||
replica_context = distribution_strategy_context.get_replica_context()
|
replica_context = distribution_strategy_context.get_replica_context()
|
||||||
if replica_context:
|
if replica_context:
|
||||||
|
@ -178,6 +178,20 @@ def weighted_moving_average(value,
|
||||||
return math_ops.divide(numerator, denominator, name=scope.name)
|
return math_ops.divide(numerator, denominator, name=scope.name)
|
||||||
|
|
||||||
|
|
||||||
|
def _update(strategy, var, update_fn, args):
|
||||||
|
"""Applies updates depending on the context."""
|
||||||
|
assert distribution_strategy_context.in_cross_replica_context(), (
|
||||||
|
"_update can only be called in cross-replica context")
|
||||||
|
if distribute_lib.get_update_replica_id() is not None:
|
||||||
|
# Call update_fn on var to delegate the implementation. We expect `var` will
|
||||||
|
# do the right thing in update context, e.g, if `var` is a MirroredVariable,
|
||||||
|
# it should pick its component variable based on `update_replica_id` and
|
||||||
|
# only update that.
|
||||||
|
return update_fn(var, *args)
|
||||||
|
else:
|
||||||
|
return strategy.extended.update(var, update_fn, args)
|
||||||
|
|
||||||
|
|
||||||
def _zero_debias(strategy, unbiased_var, value, decay):
|
def _zero_debias(strategy, unbiased_var, value, decay):
|
||||||
"""Compute the delta required for a debiased Variable.
|
"""Compute the delta required for a debiased Variable.
|
||||||
|
|
||||||
|
@ -263,8 +277,8 @@ def _zero_debias(strategy, unbiased_var, value, decay):
|
||||||
return state_ops.assign(
|
return state_ops.assign(
|
||||||
v, update_biased / bias_factor, name=ops.get_name_scope() + "/")
|
v, update_biased / bias_factor, name=ops.get_name_scope() + "/")
|
||||||
|
|
||||||
return strategy.extended.update(
|
return _update(
|
||||||
unbiased_var, update_fn, args=(value, biased_var, local_step))
|
strategy, unbiased_var, update_fn, args=(value, biased_var, local_step))
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.ExponentialMovingAverage")
|
@tf_export("train.ExponentialMovingAverage")
|
||||||
|
|
Loading…
Reference in New Issue