diff --git a/tensorflow/python/distribute/moving_averages_test.py b/tensorflow/python/distribute/moving_averages_test.py index 6066e3e234f..83c1be3e3f5 100644 --- a/tensorflow/python/distribute/moving_averages_test.py +++ b/tensorflow/python/distribute/moving_averages_test.py @@ -149,6 +149,24 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], var.eval()) + @combinations.generate(all_combinations_eager) + def testUpdateContext(self, distribution, use_function): + with distribution.scope(): + var1 = variables.Variable([0.0, 0.0]) + var2 = variables.Variable([0.0, 0.0]) + var3 = variables.Variable([0.0, 0.0]) + + def update_fn(v, value): + v.assign_add(value) + moving_averages.assign_moving_average(var2, [2.0, 4.0], decay=0.25) + moving_averages.assign_moving_average( + var3, [2.0, 4.0], decay=0.25, zero_debias=False) + + distribution.extended.update(var1, update_fn, ([1.0, 1.0],)) + + self.assertAllClose([2.0, 4.0], var2.read_value()) + self.assertAllClose([1.5, 3.0], var3.read_value()) + @combinations.generate(all_combinations) def testAssignVariable(self, distribution): diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 612215328c6..b95e366aa38 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.framework import dtypes @@ -83,7 +84,6 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): [Kingma et al., 2015](https://arxiv.org/abs/1412.6980) ([pdf](https://arxiv.org/pdf/1412.6980.pdf)) """ - with ops.name_scope(name, "AssignMovingAvg", [variable, value, decay]) as scope: decay = ops.convert_to_tensor(1.0 - decay, name="decay") @@ -97,7 +97,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): if zero_debias: return _zero_debias(strategy, v, value, decay) else: - return strategy.extended.update(v, update_fn, args=(value,)) + return _update(strategy, v, update_fn, args=(value,)) replica_context = distribution_strategy_context.get_replica_context() if replica_context: @@ -178,6 +178,20 @@ def weighted_moving_average(value, return math_ops.divide(numerator, denominator, name=scope.name) +def _update(strategy, var, update_fn, args): + """Applies updates depending on the context.""" + assert distribution_strategy_context.in_cross_replica_context(), ( + "_update can only be called in cross-replica context") + if distribute_lib.get_update_replica_id() is not None: + # Call update_fn on var to delegate the implementation. We expect `var` will + # do the right thing in update context, e.g, if `var` is a MirroredVariable, + # it should pick its component variable based on `update_replica_id` and + # only update that. + return update_fn(var, *args) + else: + return strategy.extended.update(var, update_fn, args) + + def _zero_debias(strategy, unbiased_var, value, decay): """Compute the delta required for a debiased Variable. @@ -263,8 +277,8 @@ def _zero_debias(strategy, unbiased_var, value, decay): return state_ops.assign( v, update_biased / bias_factor, name=ops.get_name_scope() + "/") - return strategy.extended.update( - unbiased_var, update_fn, args=(value, biased_var, local_step)) + return _update( + strategy, unbiased_var, update_fn, args=(value, biased_var, local_step)) @tf_export("train.ExponentialMovingAverage")