Refactor DistributedVariable update logic

The change separates _update_cross_replica and _update_replica.
Subclasses (SyncOnReadVariable and MirroredVariable) can then override
them accordingly. _update() serves as the entry point and dispatches to
the two methods depending on the context. Common logic can be added to
_update().

This is part of the work to make assign() return a variable.

PiperOrigin-RevId: 307676699
Change-Id: If23372f784dd874d951a4fe5cbb1bfedb057d69d
Ran Chen authored on 2020-04-21 14:06:10 -07:00; committed by TensorFlower Gardener
parent 04bee75e68
commit 2a6eb169c8
2 changed files with 97 additions and 77 deletions
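
To make the two code paths concrete, here is a minimal usage sketch, not part of the commit, of how a MirroredVariable update reaches them from user code. It assumes the public tf.distribute API; `strategy.run` is the current name of the replica entry point (older releases used `experimental_run_v2`), and the printed value depends on the number of local replicas.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0, aggregation=tf.VariableAggregation.SUM)

# Cross-replica context: assign() dispatches to _update_cross_replica,
# which updates every component variable via strategy.extended.update().
v.assign(2.0)

# Replica context: assign_add() dispatches to _update_replica, which
# merge_calls to aggregate the per-replica values (here SUM) before updating.
def replica_fn():
  v.assign_add(1.0)

strategy.run(replica_fn)
print(v.numpy())  # 2.0 + 1.0 * num_replicas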


@@ -24,7 +24,6 @@ from __future__ import print_function

 import contextlib

-from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
@@ -186,20 +185,10 @@ def enclosing_tpu_context():

 class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   """Holds a map from replica to TPU variables whose values are kept in sync."""

-  def _mirrored_update(self, update_fn, **kwargs):
-    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
-      if (ds_context.in_cross_replica_context() and
-          (enclosing_tpu_context() is not None)):
-        return self._distribute_strategy.extended.update(
-            self, update_fn, kwargs=kwargs)
-      else:
-        return values.MirroredVariable._mirrored_update(self, update_fn,
-                                                        **kwargs)
-
   def assign_sub(self, value, use_locking=False, name=None, read_value=True):
     assign_sub_fn = _make_raw_assign_fn(
         gen_resource_variable_ops.assign_sub_variable_op)
-    return self._mirrored_update(
+    return self._update(
         update_fn=assign_sub_fn,
         value=value,
         use_locking=use_locking,
@@ -209,7 +198,7 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   def assign_add(self, value, use_locking=False, name=None, read_value=True):
     assign_add_fn = _make_raw_assign_fn(
         gen_resource_variable_ops.assign_add_variable_op)
-    return self._mirrored_update(
+    return self._update(
         update_fn=assign_add_fn,
         value=value,
         use_locking=use_locking,

@@ -219,7 +208,7 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   def assign(self, value, use_locking=False, name=None, read_value=True):
     assign_fn = _make_raw_assign_fn(
         gen_resource_variable_ops.assign_variable_op)
-    return self._mirrored_update(
+    return self._update(
         update_fn=assign_fn,
         value=value,
         use_locking=use_locking,
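
For reference, the `update_fn` passed in these assign methods is a thin wrapper around a raw resource-variable op. The real helper is `_make_raw_assign_fn` in this file; the sketch below is only a plausible reconstruction of its shape, not the actual TF source.

from tensorflow.python.framework import ops

def make_raw_assign_fn(raw_assign_op):  # hypothetical stand-in
  """Adapts a raw op like assign_sub_variable_op to Variable.assign*'s API."""
  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    del use_locking  # Resource variable assignments are already serialized.
    with ops.device(var.device):
      op = raw_assign_op(
          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
      if read_value:
        # Mirrors the read_value semantics described in _update's docstring.
        with ops.control_dependencies([op]):
          return var.read_value()
      return op
  return assign_fn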


@@ -591,6 +591,65 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
   def value(self):
     return self._get_closest().value()

+  def _update_cross_replica(self, update_fn, value, **kwargs):
+    """Applies updates across replicas.
+
+    Args:
+      update_fn: A callable to pass to `strategy.extended.update` to update the
+        variable. It should have the same signature as `Variable.assign()`.
+      value: value to be passed to `update_fn`.
+      **kwargs: remaining arguments to `update_fn`.
+
+    Returns:
+      Updated variable or `tf.Operation`.
+    """
+    return self.distribute_strategy.extended.update(
+        self, update_fn, args=(value,), kwargs=kwargs, group=True)
+
+  def _update_replica(self, update_fn, value, **kwargs):
+    """Applies updates in one replica.
+
+    Args:
+      update_fn: A callable to update the variable. It should have the same
+        signature as `Variable.assign()`.
+      value: value to be passed to `update_fn`.
+      **kwargs: remaining arguments to `update_fn`.
+
+    Returns:
+      Updated variable or `tf.Operation`.
+    """
+    raise NotImplementedError("should be implemented by subclass.")
+
+  def _update(self, update_fn, value, **kwargs):
+    """Applies updates depending on the context.
+
+    The method calls `_update_replica` in replica context,
+    `_update_cross_replica` in cross replica context, and `update_fn` in update
+    context.
+
+    If `read_value` is True, the method returns the updated Variable. If
+    `read_value` is False, the method returns the update `tf.Operation`.
+
+    Args:
+      update_fn: A callable to pass to `strategy.extended.update` to update the
+        variable. It should have the same signature as `Variable.assign()`.
+      value: value to be passed to `update_fn`.
+      **kwargs: keyword arguments to `update_fn`.
+
+    Returns:
+      Updated variable or `tf.Operation`.
+    """
+    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
+      if ds_context.in_cross_replica_context():
+        update_replica_id = distribute_lib.get_update_replica_id()
+        if update_replica_id is not None:
+          # We are inside an update context, e.g. inside
+          # `strategy.extended.update()`; apply `update_fn` directly to the
+          # component variable of the replica being updated.
+          return update_fn(self._values[update_replica_id], value, **kwargs)
+        return self._update_cross_replica(update_fn, value, **kwargs)
+      else:
+        _assert_replica_context(self.distribute_strategy)
+        return self._update_replica(update_fn, value, **kwargs)
+
   def _should_act_as_resource_variable(self):
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
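
The three methods above form a small template-method pattern: `_update` is the entry point and context dispatcher, `_update_cross_replica` is shared, and `_update_replica` is the subclass hook. The framework-free sketch below shows the same structure; all names are illustrative and the replica/cross-replica context is reduced to a boolean flag, so this is not TF API.

class DistributedVar:
  """Toy stand-in for DistributedVariable; context is reduced to a flag."""

  def __init__(self, per_replica_values):
    self._values = list(per_replica_values)  # one component per replica

  def _update_cross_replica(self, update_fn, value, **kwargs):
    # Mimics strategy.extended.update(): apply update_fn to each component.
    self._values = [update_fn(v, value, **kwargs) for v in self._values]
    return self._values

  def _update_replica(self, update_fn, value, **kwargs):
    raise NotImplementedError("should be implemented by subclass.")

  def _update(self, update_fn, value, *, in_cross_replica, **kwargs):
    # Entry point: dispatch on context; common logic could be added here.
    if in_cross_replica:
      return self._update_cross_replica(update_fn, value, **kwargs)
    return self._update_replica(update_fn, value, **kwargs)


class MirroredVar(DistributedVar):

  def _update_replica(self, update_fn, value, **kwargs):
    # Mirrored semantics: a replica-context update is aggregated (skipped
    # in this toy) and then applied identically to every copy.
    return self._update_cross_replica(update_fn, value, **kwargs)


var = MirroredVar([0.0, 0.0])
var._update(lambda old, delta: old + delta, 5.0, in_cross_replica=False)
assert var._values == [5.0, 5.0]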
@@ -745,66 +804,38 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring

 class MirroredVariable(DistributedVariable, Mirrored):
   """Holds a map from replica to variables whose values are kept in sync."""

-  def _mirrored_update(self, update_fn, value, **kwargs):
-    """Apply identical updates using `update_fn` to variables on each replica."""
-    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
-      if ds_context.in_cross_replica_context():
-        update_replica_id = distribute_lib.get_update_replica_id()
-        if update_replica_id is not None:
-          # We are calling an update function on the mirrored variable in an
-          # update context.
-          #
-          # The arguments to update() are automatically unwrapped so the
-          # update() function would normally see regular variables, not
-          # MirroredVariables. However, the update function can still operate
-          # on wrapped MirroredVariables through object members, captured
-          # arguments, etc. This is more likely in an update_non_slot()
-          # function, which can update several non-slot variables in one call.
-          return update_fn(self._values[update_replica_id], value, **kwargs)
-
-        # We are calling update on the mirrored variable in cross replica
-        # context, use `strategy.extended.update()` to update the variable.
-        return self._distribute_strategy.extended.update(
-            self, update_fn, args=(value,), kwargs=kwargs)
-      else:
-        _assert_replica_context(self._distribute_strategy)
-        # We are calling an update function on the mirrored variable in replica
-        # context.
-        # We reduce the value we want to update. More details about how
-        # we handle the different use cases can be found in the _reduce method.
-        # We call the function on each of the mirrored variables with the
-        # reduced value.
-        if self._aggregation == vs.VariableAggregation.NONE:
-          raise ValueError(
-              _aggregation_error_msg.format(variable_type="MirroredVariable"))
-
-        def merge_fn(strategy, value, **other_kwargs):
-          """Aggregate across replicas and update MV with aggregated value."""
-          # Don't allow MEAN with non float dtype, since it may cause unexpected
-          # precision loss. Python3 and NumPy automatically upcast integers to
-          # float in division, but we should always preserve the type.
-          #
-          # Note that to be backward compatible we allow the case when the value
-          # is *always* the same on each replica. I.e. value is not a
-          # PerReplica. Refer to regroup() to see how values are grouped.
-          if self._aggregation == vs.VariableAggregation.MEAN and (
-              not self.dtype.is_floating) and isinstance(value, PerReplica):
-            raise ValueError(
-                "Cannot update non-float variables with "
-                "tf.VariableAggregation.MEAN aggregation in replica context. "
-                "Either change the variable dtype to float or update it in "
-                "cross-replica context.")
-          v = _apply_aggregation(strategy, value, self._aggregation, self)
-          return strategy.extended.update(
-              self, update_fn, args=(v,), kwargs=other_kwargs)
-
-        return ds_context.get_replica_context().merge_call(
-            merge_fn, args=(value,), kwargs=kwargs)
+  def _update_replica(self, update_fn, value, **kwargs):
+    if self.aggregation == vs.VariableAggregation.NONE:
+      raise ValueError(
+          _aggregation_error_msg.format(variable_type="MirroredVariable"))
+
+    def merge_fn(strategy, value, **kwargs):
+      """Aggregate values and update all variables in cross replica context."""
+      # Don't allow MEAN with non float dtype, since it may cause unexpected
+      # precision loss. Python3 and NumPy automatically upcast integers to
+      # float in division, but we should always preserve the type.
+      #
+      # Note that to be backward compatible we allow the case when the value
+      # is *always* the same on each replica. I.e. value is not a
+      # PerReplica. Refer to regroup() to see how values are grouped.
+      if self._aggregation == vs.VariableAggregation.MEAN and (
+          not self.dtype.is_floating) and isinstance(value, PerReplica):
+        raise ValueError(
+            "Cannot update non-float variables with "
+            "tf.VariableAggregation.MEAN aggregation in replica context. "
+            "Either change the variable dtype to float or update it in "
+            "cross-replica context.")
+      assert strategy == self.distribute_strategy
+      v = _apply_aggregation(strategy, value, self.aggregation, self)
+      return self._update_cross_replica(update_fn, v, **kwargs)
+
+    return ds_context.get_replica_context().merge_call(
+        merge_fn, args=(value,), kwargs=kwargs)

   def assign_sub(self, value, use_locking=False, name=None, read_value=True):
     assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=assign_sub_fn,
         value=value,
         use_locking=use_locking,
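
The MEAN-with-integer guard above exists because averaging aggregated integer updates produces a float that cannot be stored back into the variable's dtype without silent truncation. A quick NumPy illustration of the loss the check prevents (this is the arithmetic being guarded against, not the TF code path itself):

import numpy as np

# Two replicas propose integer updates; MEAN aggregation would average them.
updates = np.array([3, 4], dtype=np.int32)
mean = updates.sum() / updates.size  # Python 3 / NumPy upcast: 3.5 (float64)
stored = np.int32(mean)              # writing into an int32 variable: 3
assert stored != mean                # the silent precision loss MEAN would hide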
@@ -813,7 +844,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
   def assign_add(self, value, use_locking=False, name=None, read_value=True):
     assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=assign_add_fn,
         value=value,
         use_locking=use_locking,

@@ -822,7 +853,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
   def assign(self, value, use_locking=False, name=None, read_value=True):
     assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=assign_fn,
         value=value,
         use_locking=use_locking,

@@ -831,7 +862,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
   def scatter_sub(self, sparse_delta, use_locking=False, name=None):
     scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=scatter_sub_fn,
         value=sparse_delta,
         use_locking=use_locking,

@@ -839,7 +870,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
   def scatter_add(self, sparse_delta, use_locking=False, name=None):
     scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=scatter_add_fn,
         value=sparse_delta,
         use_locking=use_locking,

@@ -847,7 +878,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
   def scatter_mul(self, sparse_delta, use_locking=False, name=None):
     scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=scatter_mul_fn,
         value=sparse_delta,
         use_locking=use_locking,

@@ -855,7 +886,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
   def scatter_div(self, sparse_delta, use_locking=False, name=None):
     scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=scatter_div_fn,
         value=sparse_delta,
         use_locking=use_locking,

@@ -870,7 +901,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
           "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
           self._aggregation)
     scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=scatter_min_fn,
         value=sparse_delta,
         use_locking=use_locking,

@@ -885,7 +916,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
           "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
           self._aggregation)
     scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=scatter_max_fn,
         value=sparse_delta,
         use_locking=use_locking,

@@ -900,7 +931,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
           "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
           self._aggregation)
     scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
-    return self._mirrored_update(
+    return self._update(
         update_fn=scatter_update_fn,
         value=sparse_delta,
         use_locking=use_locking,
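
All the scatter methods above funnel through the same `_update` entry point; per the checks in the diff, `scatter_min`, `scatter_max`, and `scatter_update` additionally require `NONE` or `ONLY_FIRST_REPLICA` aggregation. A small usage sketch under stated assumptions (eager TF 2.x, the default `NONE` aggregation, illustrative values):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Default aggregation is NONE, which scatter_update accepts.
  v = tf.Variable([1.0, 2.0, 3.0])

delta = tf.IndexedSlices(values=tf.constant([10.0]),
                         indices=tf.constant([1]))
# Cross-replica context: routed through _update -> _update_cross_replica,
# so every mirrored copy receives the same scatter.
v.scatter_update(delta)
print(v.numpy())  # [ 1. 10.  3.]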