Refactor DistributedVariable update logic
The change separates the update logic into _update_cross_replica and _update_replica. Subclasses (SyncOnReadVariable and MirroredVariable) can then override each method as needed. _update() serves as the entry point and dispatches to these two methods depending on the current context. Common logic can be added to _update(). This is part of the work to make assign() return a variable. PiperOrigin-RevId: 307676699 Change-Id: If23372f784dd874d951a4fe5cbb1bfedb057d69d
This commit is contained in:
parent
04bee75e68
commit
2a6eb169c8
@ -24,7 +24,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
|
||||||
from tensorflow.python.distribute import values
|
from tensorflow.python.distribute import values
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
@ -186,20 +185,10 @@ def enclosing_tpu_context():
|
|||||||
class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
||||||
"""Holds a map from replica to TPU variables whose values are kept in sync."""
|
"""Holds a map from replica to TPU variables whose values are kept in sync."""
|
||||||
|
|
||||||
def _mirrored_update(self, update_fn, **kwargs):
|
|
||||||
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
|
||||||
if (ds_context.in_cross_replica_context() and
|
|
||||||
(enclosing_tpu_context() is not None)):
|
|
||||||
return self._distribute_strategy.extended.update(
|
|
||||||
self, update_fn, kwargs=kwargs)
|
|
||||||
else:
|
|
||||||
return values.MirroredVariable._mirrored_update(self, update_fn,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_sub_fn = _make_raw_assign_fn(
|
assign_sub_fn = _make_raw_assign_fn(
|
||||||
gen_resource_variable_ops.assign_sub_variable_op)
|
gen_resource_variable_ops.assign_sub_variable_op)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=assign_sub_fn,
|
update_fn=assign_sub_fn,
|
||||||
value=value,
|
value=value,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -209,7 +198,7 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
|||||||
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_add_fn = _make_raw_assign_fn(
|
assign_add_fn = _make_raw_assign_fn(
|
||||||
gen_resource_variable_ops.assign_add_variable_op)
|
gen_resource_variable_ops.assign_add_variable_op)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=assign_add_fn,
|
update_fn=assign_add_fn,
|
||||||
value=value,
|
value=value,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -219,7 +208,7 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
|||||||
def assign(self, value, use_locking=False, name=None, read_value=True):
|
def assign(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_fn = _make_raw_assign_fn(
|
assign_fn = _make_raw_assign_fn(
|
||||||
gen_resource_variable_ops.assign_variable_op)
|
gen_resource_variable_ops.assign_variable_op)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=assign_fn,
|
update_fn=assign_fn,
|
||||||
value=value,
|
value=value,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
|
@ -591,6 +591,65 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
|||||||
def value(self):
|
def value(self):
|
||||||
return self._get_closest().value()
|
return self._get_closest().value()
|
||||||
|
|
||||||
|
def _update_cross_replica(self, update_fn, value, **kwargs):
|
||||||
|
"""Applies updates across replicas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_fn: A callable to pass to `strategy.extended.update` to update the
|
||||||
|
variable. It should have the same signature as `Variable.assign()`.
|
||||||
|
value: value to be passed to `update_fn`.
|
||||||
|
**kwargs: remaining arguments to `update_fn`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated variable or `tf.Operation`.
|
||||||
|
"""
|
||||||
|
return self.distribute_strategy.extended.update(
|
||||||
|
self, update_fn, args=(value,), kwargs=kwargs, group=True)
|
||||||
|
|
||||||
|
def _update_replica(self, update_fn, value, **kwargs):
|
||||||
|
"""Applies updates in one replica.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_fn: A callable to update the variable. It should have the same
|
||||||
|
signature as `Variable.assign()`.
|
||||||
|
value: value to be passed to `update_fn`.
|
||||||
|
**kwargs: remaining arguments to `update_fn`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated variable or `tf.Operation`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("should be implemented by subclass.")
|
||||||
|
|
||||||
|
def _update(self, update_fn, value, **kwargs):
|
||||||
|
"""Applies updates depending on the context.
|
||||||
|
|
||||||
|
The method calls `_update_replica` in replica context,
|
||||||
|
`_update_cross_replica` in cross replica context, and `update_fn` in update
|
||||||
|
context.
|
||||||
|
|
||||||
|
If `read_value` is True, the method returns the updated Variable. If
|
||||||
|
`read_value` is False, the method returns the update `tf.Operation`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_fn: A callable to pass to `strategy.extended.update` to update the
|
||||||
|
variable. It should have the same signature as `Variable.assign()`.
|
||||||
|
value: value to be passed to `update_fn`.
|
||||||
|
**kwargs: keyword arguments to `update_fn`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated variable or `tf.Operation`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
with ds_context.enter_or_assert_strategy(self.distribute_strategy):
|
||||||
|
if ds_context.in_cross_replica_context():
|
||||||
|
update_replica_id = distribute_lib.get_update_replica_id()
|
||||||
|
if update_replica_id is not None:
|
||||||
|
return update_fn(self._values[update_replica_id], value, **kwargs)
|
||||||
|
return self._update_cross_replica(update_fn, value, **kwargs)
|
||||||
|
else:
|
||||||
|
_assert_replica_context(self.distribute_strategy)
|
||||||
|
return self._update_replica(update_fn, value, **kwargs)
|
||||||
|
|
||||||
def _should_act_as_resource_variable(self):
|
def _should_act_as_resource_variable(self):
|
||||||
"""Pass resource_variable_ops.is_resource_variable check."""
|
"""Pass resource_variable_ops.is_resource_variable check."""
|
||||||
pass
|
pass
|
||||||
@ -745,66 +804,38 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
|
|||||||
class MirroredVariable(DistributedVariable, Mirrored):
|
class MirroredVariable(DistributedVariable, Mirrored):
|
||||||
"""Holds a map from replica to variables whose values are kept in sync."""
|
"""Holds a map from replica to variables whose values are kept in sync."""
|
||||||
|
|
||||||
def _mirrored_update(self, update_fn, value, **kwargs):
|
def _update_replica(self, update_fn, value, **kwargs):
|
||||||
"""Apply identical updates using `update_fn` to variables on each replica."""
|
if self.aggregation == vs.VariableAggregation.NONE:
|
||||||
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
raise ValueError(
|
||||||
if ds_context.in_cross_replica_context():
|
_aggregation_error_msg.format(variable_type="MirroredVariable"))
|
||||||
update_replica_id = distribute_lib.get_update_replica_id()
|
|
||||||
if update_replica_id is not None:
|
|
||||||
# We are calling an update function on the mirrored variable in an
|
|
||||||
# update context.
|
|
||||||
#
|
|
||||||
# The arguments to update() are automatically unwrapped so the
|
|
||||||
# update() function would normally see regular variables, not
|
|
||||||
# MirroredVariables. However, the update function can still operate on
|
|
||||||
# wrapped MirroredVariables through object members, captured arguments
|
|
||||||
# , etc. This is more likely in an update_non_slot() function
|
|
||||||
# , which can update several non-slot variables in one call.
|
|
||||||
return update_fn(self._values[update_replica_id], value, **kwargs)
|
|
||||||
|
|
||||||
# We are calling update on the mirrored variable in cross replica
|
def merge_fn(strategy, value, **kwargs):
|
||||||
# context, use `strategy.extended.update()` to update the variable.
|
"""Aggregate values and update all variables in cross replica context."""
|
||||||
return self._distribute_strategy.extended.update(
|
# Don't allow MEAN with non float dtype, since it may cause unexpected
|
||||||
self, update_fn, args=(value,), kwargs=kwargs)
|
# precision loss. Python3 and NumPy automatically upcast integers to
|
||||||
else:
|
# float in division, but we should always preserve the type.
|
||||||
_assert_replica_context(self._distribute_strategy)
|
#
|
||||||
# We are calling an update function on the mirrored variable in replica
|
# Note that to be backward compatible we allow the case when the value
|
||||||
# context.
|
# is *always* the same on each replica. I.E. value is not a
|
||||||
# We reduce the value we want to update. More details about how
|
# PerReplica. Refer to regroup() to see how values are grouped.
|
||||||
# we handle the different use cases can be found in the _reduce method.
|
if self._aggregation == vs.VariableAggregation.MEAN and (
|
||||||
# We call the function on each of the mirrored variables with the
|
not self.dtype.is_floating) and isinstance(value, PerReplica):
|
||||||
# reduced value.
|
raise ValueError(
|
||||||
if self._aggregation == vs.VariableAggregation.NONE:
|
"Cannot update non-float variables with "
|
||||||
raise ValueError(
|
"tf.VariableAggregation.MEAN aggregation in replica context. "
|
||||||
_aggregation_error_msg.format(variable_type="MirroredVariable"))
|
"Either change the variable dtype to float or update it in "
|
||||||
|
"cross-replica context.")
|
||||||
|
|
||||||
def merge_fn(strategy, value, **other_kwargs):
|
assert strategy == self.distribute_strategy
|
||||||
"""Aggregate across replicas and update MV with aggregated value."""
|
v = _apply_aggregation(strategy, value, self.aggregation, self)
|
||||||
# Don't allow MEAN with non float dtype, since it may cause unexpected
|
return self._update_cross_replica(update_fn, v, **kwargs)
|
||||||
# precision loss. Python3 and NumPy automatically upcast integers to
|
|
||||||
# float in division, but we should always preserve the type.
|
|
||||||
#
|
|
||||||
# Note that to be backward compatible we allow the case when the value
|
|
||||||
# is *always* the same on each replica. I.E. value is not a
|
|
||||||
# PerReplica. Refer to regroup() to see how values are grouped.
|
|
||||||
if self._aggregation == vs.VariableAggregation.MEAN and (
|
|
||||||
not self.dtype.is_floating) and isinstance(value, PerReplica):
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot update non-float variables with "
|
|
||||||
"tf.VariableAggregation.MEAN aggregation in replica context. "
|
|
||||||
"Either change the variable dtype to float or update it in "
|
|
||||||
"cross-replica context.")
|
|
||||||
|
|
||||||
v = _apply_aggregation(strategy, value, self._aggregation, self)
|
return ds_context.get_replica_context().merge_call(
|
||||||
return strategy.extended.update(
|
merge_fn, args=(value,), kwargs=kwargs)
|
||||||
self, update_fn, args=(v,), kwargs=other_kwargs)
|
|
||||||
|
|
||||||
return ds_context.get_replica_context().merge_call(
|
|
||||||
merge_fn, args=(value,), kwargs=kwargs)
|
|
||||||
|
|
||||||
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
|
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=assign_sub_fn,
|
update_fn=assign_sub_fn,
|
||||||
value=value,
|
value=value,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -813,7 +844,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
|
|
||||||
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
|
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=assign_add_fn,
|
update_fn=assign_add_fn,
|
||||||
value=value,
|
value=value,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -822,7 +853,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
|
|
||||||
def assign(self, value, use_locking=False, name=None, read_value=True):
|
def assign(self, value, use_locking=False, name=None, read_value=True):
|
||||||
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
|
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=assign_fn,
|
update_fn=assign_fn,
|
||||||
value=value,
|
value=value,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -831,7 +862,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
|
|
||||||
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
|
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=scatter_sub_fn,
|
update_fn=scatter_sub_fn,
|
||||||
value=sparse_delta,
|
value=sparse_delta,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -839,7 +870,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
|
|
||||||
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
|
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=scatter_add_fn,
|
update_fn=scatter_add_fn,
|
||||||
value=sparse_delta,
|
value=sparse_delta,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -847,7 +878,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
|
|
||||||
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
|
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
|
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=scatter_mul_fn,
|
update_fn=scatter_mul_fn,
|
||||||
value=sparse_delta,
|
value=sparse_delta,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -855,7 +886,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
|
|
||||||
def scatter_div(self, sparse_delta, use_locking=False, name=None):
|
def scatter_div(self, sparse_delta, use_locking=False, name=None):
|
||||||
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
|
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=scatter_div_fn,
|
update_fn=scatter_div_fn,
|
||||||
value=sparse_delta,
|
value=sparse_delta,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -870,7 +901,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||||
self._aggregation)
|
self._aggregation)
|
||||||
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
|
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=scatter_min_fn,
|
update_fn=scatter_min_fn,
|
||||||
value=sparse_delta,
|
value=sparse_delta,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -885,7 +916,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||||
self._aggregation)
|
self._aggregation)
|
||||||
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
|
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=scatter_max_fn,
|
update_fn=scatter_max_fn,
|
||||||
value=sparse_delta,
|
value=sparse_delta,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
@ -900,7 +931,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
|||||||
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
|
||||||
self._aggregation)
|
self._aggregation)
|
||||||
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
|
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
|
||||||
return self._mirrored_update(
|
return self._update(
|
||||||
update_fn=scatter_update_fn,
|
update_fn=scatter_update_fn,
|
||||||
value=sparse_delta,
|
value=sparse_delta,
|
||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user