DistributedVariable update methods always pass keyword arguments to
_mirrored_update In this way it's easier to modify the arugments, which is needed to make the return type another DistributedVariable. PiperOrigin-RevId: 306493210 Change-Id: I330a19ebb3045a2fc5e47b79e49c3ab45d1e8887
This commit is contained in:
parent
60140d9369
commit
85ea23af35
@ -745,7 +745,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
|
|||||||
class MirroredVariable(DistributedVariable, Mirrored):
|
class MirroredVariable(DistributedVariable, Mirrored):
|
||||||
"""Holds a map from replica to variables whose values are kept in sync."""
|
"""Holds a map from replica to variables whose values are kept in sync."""
|
||||||
|
|
||||||
def _mirrored_update(self, update_fn, *args, **kwargs):
|
def _mirrored_update(self, update_fn, value, **kwargs):
|
||||||
"""Apply identical updates using `update_fn` to variables on each replica."""
|
"""Apply identical updates using `update_fn` to variables on each replica."""
|
||||||
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
||||||
if ds_context.in_cross_replica_context():
|
if ds_context.in_cross_replica_context():
|
||||||
@ -760,12 +760,12 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
# wrapped MirroredVariables through object members, captured arguments
|
# wrapped MirroredVariables through object members, captured arguments
|
||||||
# , etc. This is more likely in an update_non_slot() function
|
# , etc. This is more likely in an update_non_slot() function
|
||||||
# , which can update several non-slot variables in one call.
|
# , which can update several non-slot variables in one call.
|
||||||
return update_fn(self._values[update_replica_id], *args, **kwargs)
|
return update_fn(self._values[update_replica_id], value, **kwargs)
|
||||||
|
|
||||||
# We are calling update on the mirrored variable in cross replica
|
# We are calling update on the mirrored variable in cross replica
|
||||||
# context, use `strategy.extended.update()` to update the variable.
|
# context, use `strategy.extended.update()` to update the variable.
|
||||||
return self._distribute_strategy.extended.update(
|
return self._distribute_strategy.extended.update(
|
||||||
self, update_fn, args=args, kwargs=kwargs)
|
self, update_fn, args=(value,), kwargs=kwargs)
|
||||||
else:
|
else:
|
||||||
_assert_replica_context(self._distribute_strategy)
|
_assert_replica_context(self._distribute_strategy)
|
||||||
# We are calling an update function on the mirrored variable in replica
|
# We are calling an update function on the mirrored variable in replica
|
||||||
@ -778,7 +778,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
_aggregation_error_msg.format(variable_type="MirroredVariable"))
|
_aggregation_error_msg.format(variable_type="MirroredVariable"))
|
||||||
|
|
||||||
def merge_fn(strategy, value, *other_args, **other_kwargs):
|
def merge_fn(strategy, value, **other_kwargs):
|
||||||
"""Aggregate across replicas and update MV with aggregated value."""
|
"""Aggregate across replicas and update MV with aggregated value."""
|
||||||
# Don't allow MEAN with non float dtype, since it may cause unexpected
|
# Don't allow MEAN with non float dtype, since it may cause unexpected
|
||||||
# precision loss. Python3 and NumPy automatically upcast integers to
|
# precision loss. Python3 and NumPy automatically upcast integers to
|
||||||
@ -797,40 +797,71 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
|
|
||||||
v = _apply_aggregation(strategy, value, self._aggregation, self)
|
v = _apply_aggregation(strategy, value, self._aggregation, self)
|
||||||
return strategy.extended.update(
|
return strategy.extended.update(
|
||||||
self, update_fn, args=(v,) + other_args, kwargs=other_kwargs)
|
self, update_fn, args=(v,), kwargs=other_kwargs)
|
||||||
|
|
||||||
return ds_context.get_replica_context().merge_call(
|
return ds_context.get_replica_context().merge_call(
|
||||||
merge_fn, args=args, kwargs=kwargs)
|
merge_fn, args=(value,), kwargs=kwargs)
|
||||||
|
|
||||||
def assign_sub(self, *args, **kwargs):
|
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
|
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
|
||||||
return self._mirrored_update(assign_sub_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=assign_sub_fn,
|
||||||
|
value=value,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name,
|
||||||
|
read_value=read_value)
|
||||||
|
|
||||||
def assign_add(self, *args, **kwargs):
|
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
|
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
|
||||||
return self._mirrored_update(assign_add_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=assign_add_fn,
|
||||||
|
value=value,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name,
|
||||||
|
read_value=read_value)
|
||||||
|
|
||||||
def assign(self, *args, **kwargs):
|
def assign(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
|
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
|
||||||
return self._mirrored_update(assign_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=assign_fn,
|
||||||
|
value=value,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name,
|
||||||
|
read_value=read_value)
|
||||||
|
|
||||||
def scatter_sub(self, *args, **kwargs):
|
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
|
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
|
||||||
return self._mirrored_update(scatter_sub_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=scatter_sub_fn,
|
||||||
|
value=sparse_delta,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
def scatter_add(self, *args, **kwargs):
|
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
|
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
|
||||||
return self._mirrored_update(scatter_add_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=scatter_add_fn,
|
||||||
|
value=sparse_delta,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
def scatter_mul(self, *args, **kwargs):
|
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
|
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
|
||||||
return self._mirrored_update(scatter_mul_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=scatter_mul_fn,
|
||||||
|
value=sparse_delta,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
def scatter_div(self, *args, **kwargs):
|
def scatter_div(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
|
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
|
||||||
return self._mirrored_update(scatter_div_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=scatter_div_fn,
|
||||||
|
value=sparse_delta,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
def scatter_min(self, *args, **kwargs):
|
def scatter_min(self, sparse_delta, use_locking=False, name=None):
|
||||||
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
||||||
self._aggregation != vs.VariableAggregation.NONE):
|
self._aggregation != vs.VariableAggregation.NONE):
|
||||||
raise NotImplementedError("scatter_min is only supported for mirrored "
|
raise NotImplementedError("scatter_min is only supported for mirrored "
|
||||||
@ -839,9 +870,13 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||||
self._aggregation)
|
self._aggregation)
|
||||||
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
|
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
|
||||||
return self._mirrored_update(scatter_min_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=scatter_min_fn,
|
||||||
|
value=sparse_delta,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
def scatter_max(self, *args, **kwargs):
|
def scatter_max(self, sparse_delta, use_locking=False, name=None):
|
||||||
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
||||||
self._aggregation != vs.VariableAggregation.NONE):
|
self._aggregation != vs.VariableAggregation.NONE):
|
||||||
raise NotImplementedError("scatter_max is only supported for mirrored "
|
raise NotImplementedError("scatter_max is only supported for mirrored "
|
||||||
@ -850,9 +885,13 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||||
self._aggregation)
|
self._aggregation)
|
||||||
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
|
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
|
||||||
return self._mirrored_update(scatter_max_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=scatter_max_fn,
|
||||||
|
value=sparse_delta,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
def scatter_update(self, *args, **kwargs):
|
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||||
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
|
||||||
self._aggregation != vs.VariableAggregation.NONE):
|
self._aggregation != vs.VariableAggregation.NONE):
|
||||||
raise NotImplementedError("scatter_update is only supported for mirrored "
|
raise NotImplementedError("scatter_update is only supported for mirrored "
|
||||||
@ -861,7 +900,11 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||||
self._aggregation)
|
self._aggregation)
|
||||||
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
|
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
|
||||||
return self._mirrored_update(scatter_update_fn, *args, **kwargs)
|
return self._mirrored_update(
|
||||||
|
update_fn=scatter_update_fn,
|
||||||
|
value=sparse_delta,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
def _get_cross_replica(self):
|
def _get_cross_replica(self):
|
||||||
# Return identity, to avoid directly exposing the variable to the user and
|
# Return identity, to avoid directly exposing the variable to the user and
|
||||||
|
Loading…
Reference in New Issue
Block a user