In `assign_moving_average`, call `update_fn` directly instead of `strategy.extended.update(var, update_fn)` when already in update context.

This fixes an issue where a variable could be updated num_replicas times, once per replica, when `assign_moving_average` is called in update context. See the unit test for an example.
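For illustration, a minimal sketch of the pattern that used to double-count; `grad_var`, `ema_var`, and `my_update` are hypothetical names, not part of this change:

    import tensorflow as tf
    from tensorflow.python.training import moving_averages

    strategy = tf.distribute.MirroredStrategy()  # e.g. 2 replicas

    with strategy.scope():
      grad_var = tf.Variable([0.0, 0.0])
      ema_var = tf.Variable([0.0, 0.0])

    def my_update(v, value):
      v.assign_add(value)
      # This runs in update context. Before this fix, the call below
      # re-entered `strategy.extended.update`, so `ema_var` could be
      # updated num_replicas times instead of once.
      moving_averages.assign_moving_average(ema_var, value, decay=0.25)

    strategy.extended.update(grad_var, my_update, args=([1.0, 1.0],))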

PiperOrigin-RevId: 315804133
Change-Id: If4b31ef15fcf5d1bb93ec53fa6ff6c234a13626a
Chenkai Kuang 2020-06-10 17:46:39 -07:00 committed by TensorFlower Gardener
parent a745f0a953
commit be9e308f1f
2 changed files with 36 additions and 4 deletions

tensorflow/python/training/moving_averages_test.py

@@ -149,6 +149,24 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
           (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
          var.eval())
 
+  @combinations.generate(all_combinations_eager)
+  def testUpdateContext(self, distribution, use_function):
+    with distribution.scope():
+      var1 = variables.Variable([0.0, 0.0])
+      var2 = variables.Variable([0.0, 0.0])
+      var3 = variables.Variable([0.0, 0.0])
+
+      def update_fn(v, value):
+        v.assign_add(value)
+        moving_averages.assign_moving_average(var2, [2.0, 4.0], decay=0.25)
+        moving_averages.assign_moving_average(
+            var3, [2.0, 4.0], decay=0.25, zero_debias=False)
+
+      distribution.extended.update(var1, update_fn, ([1.0, 1.0],))
+
+    self.assertAllClose([2.0, 4.0], var2.read_value())
+    self.assertAllClose([1.5, 3.0], var3.read_value())
+
   @combinations.generate(all_combinations)
   def testAssignVariable(self, distribution):
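For reference, the asserted values follow from the update rule `variable -= (1 - decay) * (variable - value)`: with `decay = 0.25`, `var3 = 0 - 0.75 * (0 - [2.0, 4.0]) = [1.5, 3.0]`; for `var2`, zero-debiasing divides the biased average `0.75 * [2.0, 4.0]` by the bias factor `0.75`, giving exactly `[2.0, 4.0]` after a single update. With the fix, each variable is updated once rather than once per replica.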

tensorflow/python/training/moving_averages.py

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.framework import dtypes
@@ -83,7 +84,6 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
     [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
     ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
   """
-
   with ops.name_scope(name, "AssignMovingAvg",
                       [variable, value, decay]) as scope:
     decay = ops.convert_to_tensor(1.0 - decay, name="decay")
@@ -97,7 +97,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
     if zero_debias:
       return _zero_debias(strategy, v, value, decay)
     else:
-      return strategy.extended.update(v, update_fn, args=(value,))
+      return _update(strategy, v, update_fn, args=(value,))
 
   replica_context = distribution_strategy_context.get_replica_context()
   if replica_context:
@@ -178,6 +178,20 @@ def weighted_moving_average(value,
     return math_ops.divide(numerator, denominator, name=scope.name)
 
 
+def _update(strategy, var, update_fn, args):
+  """Applies updates depending on the context."""
+  assert distribution_strategy_context.in_cross_replica_context(), (
+      "_update can only be called in cross-replica context")
+  if distribute_lib.get_update_replica_id() is not None:
+    # Call update_fn on var to delegate the implementation. We expect `var` will
+    # do the right thing in update context, e.g., if `var` is a MirroredVariable,
+    # it should pick its component variable based on `update_replica_id` and
+    # only update that.
+    return update_fn(var, *args)
+  else:
+    return strategy.extended.update(var, update_fn, args)
+
+
 def _zero_debias(strategy, unbiased_var, value, decay):
   """Compute the delta required for a debiased Variable.
@@ -263,8 +277,8 @@ def _zero_debias(strategy, unbiased_var, value, decay):
       return state_ops.assign(
           v, update_biased / bias_factor, name=ops.get_name_scope() + "/")
 
-  return strategy.extended.update(
-      unbiased_var, update_fn, args=(value, biased_var, local_step))
+  return _update(
+      strategy, unbiased_var, update_fn, args=(value, biased_var, local_step))
 
 
 @tf_export("train.ExponentialMovingAverage")