DistributedVariable update methods always pass keyword arguments to
_mirrored_update In this way it's easier to modify the arugments, which is needed to make the return type another DistributedVariable. PiperOrigin-RevId: 306493210 Change-Id: I330a19ebb3045a2fc5e47b79e49c3ab45d1e8887
This commit is contained in:
parent
60140d9369
commit
85ea23af35
@ -745,7 +745,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
|
||||
class MirroredVariable(DistributedVariable, Mirrored):
|
||||
"""Holds a map from replica to variables whose values are kept in sync."""
|
||||
|
||||
def _mirrored_update(self, update_fn, *args, **kwargs):
|
||||
def _mirrored_update(self, update_fn, value, **kwargs):
|
||||
"""Apply identical updates using `update_fn` to variables on each replica."""
|
||||
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
||||
if ds_context.in_cross_replica_context():
|
||||
@ -760,12 +760,12 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
# wrapped MirroredVariables through object members, captured arguments
|
||||
# , etc. This is more likely in an update_non_slot() function
|
||||
# , which can update several non-slot variables in one call.
|
||||
return update_fn(self._values[update_replica_id], *args, **kwargs)
|
||||
return update_fn(self._values[update_replica_id], value, **kwargs)
|
||||
|
||||
# We are calling update on the mirrored variable in cross replica
|
||||
# context, use `strategy.extended.update()` to update the variable.
|
||||
return self._distribute_strategy.extended.update(
|
||||
self, update_fn, args=args, kwargs=kwargs)
|
||||
self, update_fn, args=(value,), kwargs=kwargs)
|
||||
else:
|
||||
_assert_replica_context(self._distribute_strategy)
|
||||
# We are calling an update function on the mirrored variable in replica
|
||||
@ -778,7 +778,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
raise ValueError(
|
||||
_aggregation_error_msg.format(variable_type="MirroredVariable"))
|
||||
|
||||
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
||||
def merge_fn(strategy, value, **other_kwargs):
|
||||
"""Aggregate across replicas and update MV with aggregated value."""
|
||||
# Don't allow MEAN with non float dtype, since it may cause unexpected
|
||||
# precision loss. Python3 and NumPy automatically upcast integers to
|
||||
@ -797,40 +797,71 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
|
||||
v = _apply_aggregation(strategy, value, self._aggregation, self)
|
||||
return strategy.extended.update(
|
||||
self, update_fn, args=(v,) + other_args, kwargs=other_kwargs)
|
||||
self, update_fn, args=(v,), kwargs=other_kwargs)
|
||||
|
||||
return ds_context.get_replica_context().merge_call(
|
||||
merge_fn, args=args, kwargs=kwargs)
|
||||
merge_fn, args=(value,), kwargs=kwargs)
|
||||
|
||||
def assign_sub(self, *args, **kwargs):
|
||||
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
||||
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
|
||||
return self._mirrored_update(assign_sub_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=assign_sub_fn,
|
||||
value=value,
|
||||
use_locking=use_locking,
|
||||
name=name,
|
||||
read_value=read_value)
|
||||
|
||||
def assign_add(self, *args, **kwargs):
|
||||
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
||||
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
|
||||
return self._mirrored_update(assign_add_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=assign_add_fn,
|
||||
value=value,
|
||||
use_locking=use_locking,
|
||||
name=name,
|
||||
read_value=read_value)
|
||||
|
||||
def assign(self, *args, **kwargs):
|
||||
def assign(self, value, use_locking=False, name=None, read_value=True):
|
||||
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
|
||||
return self._mirrored_update(assign_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=assign_fn,
|
||||
value=value,
|
||||
use_locking=use_locking,
|
||||
name=name,
|
||||
read_value=read_value)
|
||||
|
||||
def scatter_sub(self, *args, **kwargs):
|
||||
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
||||
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
|
||||
return self._mirrored_update(scatter_sub_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=scatter_sub_fn,
|
||||
value=sparse_delta,
|
||||
use_locking=use_locking,
|
||||
name=name)
|
||||
|
||||
def scatter_add(self, *args, **kwargs):
|
||||
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
||||
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
|
||||
return self._mirrored_update(scatter_add_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=scatter_add_fn,
|
||||
value=sparse_delta,
|
||||
use_locking=use_locking,
|
||||
name=name)
|
||||
|
||||
def scatter_mul(self, *args, **kwargs):
|
||||
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
|
||||
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
|
||||
return self._mirrored_update(scatter_mul_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=scatter_mul_fn,
|
||||
value=sparse_delta,
|
||||
use_locking=use_locking,
|
||||
name=name)
|
||||
|
||||
def scatter_div(self, *args, **kwargs):
|
||||
def scatter_div(self, sparse_delta, use_locking=False, name=None):
|
||||
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
|
||||
return self._mirrored_update(scatter_div_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=scatter_div_fn,
|
||||
value=sparse_delta,
|
||||
use_locking=use_locking,
|
||||
name=name)
|
||||
|
||||
def scatter_min(self, *args, **kwargs):
|
||||
def scatter_min(self, sparse_delta, use_locking=False, name=None):
|
||||
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
||||
self._aggregation != vs.VariableAggregation.NONE):
|
||||
raise NotImplementedError("scatter_min is only supported for mirrored "
|
||||
@ -839,9 +870,13 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||
self._aggregation)
|
||||
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
|
||||
return self._mirrored_update(scatter_min_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=scatter_min_fn,
|
||||
value=sparse_delta,
|
||||
use_locking=use_locking,
|
||||
name=name)
|
||||
|
||||
def scatter_max(self, *args, **kwargs):
|
||||
def scatter_max(self, sparse_delta, use_locking=False, name=None):
|
||||
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
||||
self._aggregation != vs.VariableAggregation.NONE):
|
||||
raise NotImplementedError("scatter_max is only supported for mirrored "
|
||||
@ -850,9 +885,13 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||
self._aggregation)
|
||||
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
|
||||
return self._mirrored_update(scatter_max_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=scatter_max_fn,
|
||||
value=sparse_delta,
|
||||
use_locking=use_locking,
|
||||
name=name)
|
||||
|
||||
def scatter_update(self, *args, **kwargs):
|
||||
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
||||
self._aggregation != vs.VariableAggregation.NONE):
|
||||
raise NotImplementedError("scatter_update is only supported for mirrored "
|
||||
@ -861,7 +900,11 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||
self._aggregation)
|
||||
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
|
||||
return self._mirrored_update(scatter_update_fn, *args, **kwargs)
|
||||
return self._mirrored_update(
|
||||
update_fn=scatter_update_fn,
|
||||
value=sparse_delta,
|
||||
use_locking=use_locking,
|
||||
name=name)
|
||||
|
||||
def _get_cross_replica(self):
|
||||
# Return identity, to avoid directly exposing the variable to the user and
|
||||
|
Loading…
Reference in New Issue
Block a user