diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 15a08013d14..894b94031ac 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -745,7 +745,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring class MirroredVariable(DistributedVariable, Mirrored): """Holds a map from replica to variables whose values are kept in sync.""" - def _mirrored_update(self, update_fn, *args, **kwargs): + def _mirrored_update(self, update_fn, value, **kwargs): """Apply identical updates using `update_fn` to variables on each replica.""" with ds_context.enter_or_assert_strategy(self._distribute_strategy): if ds_context.in_cross_replica_context(): @@ -760,12 +760,12 @@ class MirroredVariable(DistributedVariable, Mirrored): # wrapped MirroredVariables through object members, captured arguments # , etc. This is more likely in an update_non_slot() function # , which can update several non-slot variables in one call. - return update_fn(self._values[update_replica_id], *args, **kwargs) + return update_fn(self._values[update_replica_id], value, **kwargs) # We are calling update on the mirrored variable in cross replica # context, use `strategy.extended.update()` to update the variable. return self._distribute_strategy.extended.update( - self, update_fn, args=args, kwargs=kwargs) + self, update_fn, args=(value,), kwargs=kwargs) else: _assert_replica_context(self._distribute_strategy) # We are calling an update function on the mirrored variable in replica @@ -778,7 +778,7 @@ class MirroredVariable(DistributedVariable, Mirrored): raise ValueError( _aggregation_error_msg.format(variable_type="MirroredVariable")) - def merge_fn(strategy, value, *other_args, **other_kwargs): + def merge_fn(strategy, value, **other_kwargs): """Aggregate across replicas and update MV with aggregated value.""" # Don't allow MEAN with non float dtype, since it may cause unexpected # precision loss. Python3 and NumPy automatically upcast integers to @@ -797,40 +797,71 @@ class MirroredVariable(DistributedVariable, Mirrored): v = _apply_aggregation(strategy, value, self._aggregation, self) return strategy.extended.update( - self, update_fn, args=(v,) + other_args, kwargs=other_kwargs) + self, update_fn, args=(v,), kwargs=other_kwargs) return ds_context.get_replica_context().merge_call( - merge_fn, args=args, kwargs=kwargs) + merge_fn, args=(value,), kwargs=kwargs) - def assign_sub(self, *args, **kwargs): + def assign_sub(self, value, use_locking=False, name=None, read_value=True): assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) - return self._mirrored_update(assign_sub_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=assign_sub_fn, + value=value, + use_locking=use_locking, + name=name, + read_value=read_value) - def assign_add(self, *args, **kwargs): + def assign_add(self, value, use_locking=False, name=None, read_value=True): assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) - return self._mirrored_update(assign_add_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=assign_add_fn, + value=value, + use_locking=use_locking, + name=name, + read_value=read_value) - def assign(self, *args, **kwargs): + def assign(self, value, use_locking=False, name=None, read_value=True): assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) - return self._mirrored_update(assign_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=assign_fn, + value=value, + use_locking=use_locking, + name=name, + read_value=read_value) - def scatter_sub(self, *args, **kwargs): + def scatter_sub(self, sparse_delta, use_locking=False, name=None): scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw) - return self._mirrored_update(scatter_sub_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=scatter_sub_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) - def scatter_add(self, *args, **kwargs): + def scatter_add(self, sparse_delta, use_locking=False, name=None): scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw) - return self._mirrored_update(scatter_add_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=scatter_add_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) - def scatter_mul(self, *args, **kwargs): + def scatter_mul(self, sparse_delta, use_locking=False, name=None): scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw) - return self._mirrored_update(scatter_mul_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=scatter_mul_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) - def scatter_div(self, *args, **kwargs): + def scatter_div(self, sparse_delta, use_locking=False, name=None): scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw) - return self._mirrored_update(scatter_div_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=scatter_div_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) - def scatter_min(self, *args, **kwargs): + def scatter_min(self, sparse_delta, use_locking=False, name=None): if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and self._aggregation != vs.VariableAggregation.NONE): raise NotImplementedError("scatter_min is only supported for mirrored " @@ -839,9 +870,13 @@ class MirroredVariable(DistributedVariable, Mirrored): "`ONLY_FIRST_REPLICA` aggregation, got: %s" % self._aggregation) scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw) - return self._mirrored_update(scatter_min_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=scatter_min_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) - def scatter_max(self, *args, **kwargs): + def scatter_max(self, sparse_delta, use_locking=False, name=None): if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and self._aggregation != vs.VariableAggregation.NONE): raise NotImplementedError("scatter_max is only supported for mirrored " @@ -850,9 +885,13 @@ class MirroredVariable(DistributedVariable, Mirrored): "`ONLY_FIRST_REPLICA` aggregation, got: %s" % self._aggregation) scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw) - return self._mirrored_update(scatter_max_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=scatter_max_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) - def scatter_update(self, *args, **kwargs): + def scatter_update(self, sparse_delta, use_locking=False, name=None): if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and self._aggregation != vs.VariableAggregation.NONE): raise NotImplementedError("scatter_update is only supported for mirrored " @@ -861,7 +900,11 @@ class MirroredVariable(DistributedVariable, Mirrored): "`ONLY_FIRST_REPLICA` aggregation, got: %s" % self._aggregation) scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw) - return self._mirrored_update(scatter_update_fn, *args, **kwargs) + return self._mirrored_update( + update_fn=scatter_update_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) def _get_cross_replica(self): # Return identity, to avoid directly exposing the variable to the user and