Refactor SyncOnReadVariable update methods to be consistent with MirroredVariable.

This makes the behavior of SyncOnReadVariable consistent with MirroredVariable.
E.g., previously it always returned a tf.Operation in cross-replica context.

After the last refactoring, SyncOnReadVariable only needs to override _update_replica to get the desired behavior.

This change also moves the assign* and scatter* overrides to the
DistributedVariable level.

This is part of the effort to make DistributedVariable.assign* return a variable.
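
As an illustration, here is a minimal sketch (not part of this change; the
strategy setup and values are assumptions) of the user-visible difference:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # A tf.Variable created under a strategy scope with ON_READ
  # synchronization becomes a SyncOnReadVariable.
  v = tf.Variable(
      0.,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.MEAN)

# Cross-replica context: previously this always returned a grouped
# tf.Operation; after this change it goes through
# DistributedVariable._update and, like MirroredVariable, can return the
# updated value when read_value=True.
result = v.assign(3.)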

PiperOrigin-RevId: 308894562
Change-Id: I1a58352e2af2ff57402d8fc744fcfc9610a48d8b
Author: Ran Chen
Date: 2020-04-28 14:11:46 -07:00
Committed by: TensorFlower Gardener
Commit: f7f727388c (parent 47766b3087)
2 changed files with 142 additions and 125 deletions


@@ -393,16 +393,6 @@ def _assign_on_device(device, variable, tensor):
     return variable.assign(tensor)
-def _assign_add_on_device(device, variable, tensor):
-  with ops.device(device):
-    return variable.assign_add(tensor)
-def _assign_sub_on_device(device, variable, tensor):
-  with ops.device(device):
-    return variable.assign_sub(tensor)
 class DistributedVarOp(object):
   """A class that looks like `tf.Operation`."""
@@ -587,6 +577,89 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
   def value(self):
     return self._get_closest().value()
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+    return self._update(
+        update_fn=assign_sub_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
+    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+    return self._update(
+        update_fn=assign_add_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
+  def assign(self, value, use_locking=False, name=None, read_value=True):
+    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+    return self._update(
+        update_fn=assign_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
+  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
+    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
+    return self._update(
+        update_fn=scatter_sub_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+  def scatter_add(self, sparse_delta, use_locking=False, name=None):
+    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
+    return self._update(
+        update_fn=scatter_add_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
+    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
+    return self._update(
+        update_fn=scatter_mul_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+  def scatter_div(self, sparse_delta, use_locking=False, name=None):
+    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
+    return self._update(
+        update_fn=scatter_div_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+  def scatter_min(self, sparse_delta, use_locking=False, name=None):
+    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
+    return self._update(
+        update_fn=scatter_min_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+  def scatter_max(self, sparse_delta, use_locking=False, name=None):
+    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
+    return self._update(
+        update_fn=scatter_max_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+  def scatter_update(self, sparse_delta, use_locking=False, name=None):
+    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
+    return self._update(
+        update_fn=scatter_update_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
   def _update_cross_replica(self, update_fn, value, **kwargs):
     """Applies updates across replicas.
@@ -836,66 +909,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
     return ds_context.get_replica_context().merge_call(
         merge_fn, args=(value,), kwargs=kwargs)
-  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
-    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
-    return self._update(
-        update_fn=assign_sub_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
-  def assign_add(self, value, use_locking=False, name=None, read_value=True):
-    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
-    return self._update(
-        update_fn=assign_add_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
-  def assign(self, value, use_locking=False, name=None, read_value=True):
-    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
-    return self._update(
-        update_fn=assign_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
-  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
-    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
-    return self._update(
-        update_fn=scatter_sub_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-  def scatter_add(self, sparse_delta, use_locking=False, name=None):
-    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
-    return self._update(
-        update_fn=scatter_add_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
-    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
-    return self._update(
-        update_fn=scatter_mul_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-  def scatter_div(self, sparse_delta, use_locking=False, name=None):
-    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
-    return self._update(
-        update_fn=scatter_div_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-  def scatter_min(self, sparse_delta, use_locking=False, name=None):
+  def scatter_min(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
       raise NotImplementedError("scatter_min is only supported for mirrored "
@@ -903,14 +917,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                 "`tf.distribute.Strategy` scope) with NONE or "
                                 "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                 self._aggregation)
-    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
-    return self._update(
-        update_fn=scatter_min_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    return super(MirroredVariable, self).scatter_min(*args, **kwargs)
-  def scatter_max(self, sparse_delta, use_locking=False, name=None):
+  def scatter_max(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
       raise NotImplementedError("scatter_max is only supported for mirrored "
@@ -918,14 +927,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                 "`tf.distribute.Strategy` scope) with NONE or "
                                 "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                 self._aggregation)
-    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
-    return self._update(
-        update_fn=scatter_max_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    return super(MirroredVariable, self).scatter_max(*args, **kwargs)
-  def scatter_update(self, sparse_delta, use_locking=False, name=None):
+  def scatter_update(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
       raise NotImplementedError("scatter_update is only supported for mirrored "
@@ -933,12 +937,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                 "`tf.distribute.Strategy` scope) with NONE or "
                                 "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                 self._aggregation)
-    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
-    return self._update(
-        update_fn=scatter_update_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    return super(MirroredVariable, self).scatter_update(*args, **kwargs)
   def _get_cross_replica(self):
     # Return identity, to avoid directly exposing the variable to the user and
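
MirroredVariable thus keeps only the aggregation guard and defers the actual
scatter to the DistributedVariable implementations above. A hedged usage
sketch of the guard (assumes the imports and strategy from the earlier
sketch):

with strategy.scope():
  m = tf.Variable([1., 2.], aggregation=tf.VariableAggregation.MEAN)

try:
  m.scatter_update(
      tf.IndexedSlices(values=tf.constant([5.]), indices=tf.constant([0])))
except NotImplementedError as e:
  # Rejected: scatter_update on a mirrored variable requires NONE or
  # ONLY_FIRST_REPLICA aggregation.
  print(e)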
@@ -1051,49 +1050,67 @@ def _assert_replica_context(strategy):
 class SyncOnReadVariable(DistributedVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
-  def assign_sub(self, *args, **kwargs):
+  def _update_cross_replica(self, update_fn, value, **kwargs):
+    # TODO(b/155139086): respect name argument.
+    kwargs["name"] = None
+    return super(SyncOnReadVariable,
+                 self)._update_cross_replica(update_fn, value, **kwargs)
+  def _update_replica(self, update_fn, value, **kwargs):
+    return update_fn(self._get_closest(), value, **kwargs)
+  def _aggregation_sum_assign_not_supported_cross_replica(self, method):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if ds_context.in_cross_replica_context():
-        if self._aggregation == vs.VariableAggregation.SUM:
+        if self.aggregation == vs.VariableAggregation.SUM:
           raise ValueError(
-              "SyncOnReadVariable does not support `assign_sub` in "
-              "cross-replica context when aggregation is set to "
-              "`tf.VariableAggregation.SUM`.")
-        return control_flow_ops.group(
-            tuple(
-                _assign_sub_on_device(v.device, v, args[0])
-                for v in self._values))
-      else:
-        return self._get().assign_sub(*args, **kwargs)
+              "Variable with `synchronization=ON_READ` does not support `%s` "
+              "in cross-replica context when aggregation is set to "
+              "`tf.VariableAggregation.SUM`." % method)
+  def assign_sub(self, *args, **kwargs):
+    self._aggregation_sum_assign_not_supported_cross_replica("assign_sub")
+    return super(SyncOnReadVariable, self).assign_sub(*args, **kwargs)
   def assign_add(self, *args, **kwargs):
-    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
-      if ds_context.in_cross_replica_context():
-        if self._aggregation == vs.VariableAggregation.SUM:
-          raise ValueError(
-              "SyncOnReadVariable does not support `assign_add` in "
-              "cross-replica context when aggregation is set to "
-              "`tf.VariableAggregation.SUM`.")
-        return control_flow_ops.group(
-            tuple(
-                _assign_add_on_device(v.device, v, args[0])
-                for v in self._values))
-      else:
-        return self._get().assign_add(*args, **kwargs)
+    self._aggregation_sum_assign_not_supported_cross_replica("assign_add")
+    return super(SyncOnReadVariable, self).assign_add(*args, **kwargs)
-  def assign(self, *args, **kwargs):
+  def assign(self, value, *args, **kwargs):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
-      if ds_context.in_cross_replica_context():
+      if (ds_context.in_cross_replica_context() and
+          self._aggregation == vs.VariableAggregation.SUM):
         # To preserve the sum across save and restore, we have to divide the
         # total across all devices when restoring a variable that was summed
         # when saving.
-        tensor = args[0]
-        if self._aggregation == vs.VariableAggregation.SUM:
-          tensor = math_ops.cast(tensor / len(self._values), self.dtype)
-        return control_flow_ops.group(
-            tuple(_assign_on_device(v.device, v, tensor) for v in self._values))
-      else:
-        return self._get().assign(*args, **kwargs)
+        value = math_ops.cast(value / len(self._values), self.dtype)
+      return super(SyncOnReadVariable, self).assign(value, *args, **kwargs)
+  def _scatter_not_implemented(self, method):
+    raise NotImplementedError(
+        "Variables with `synchronization=ON_READ` doesn't support `%s`" %
+        method)
+  def scatter_sub(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_sub")
+  def scatter_add(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_add")
+  def scatter_mul(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_mul")
+  def scatter_div(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_div")
+  def scatter_min(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_min")
+  def scatter_max(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_max")
+  def scatter_update(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_update")
   def value(self):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
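
After this refactoring, SyncOnReadVariable customizes only the two _update_*
hooks plus its error paths. A sketch of the two error surfaces, mirroring
the updated test expectations below (assumes the imports and strategy from
the earlier sketch):

with strategy.scope():
  s = tf.Variable(
      [0., 0.],
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

try:
  # In cross-replica context, SUM aggregation rejects assign_add/assign_sub.
  s.assign_add([1., 1.])
except ValueError as e:
  print(e)  # Variable with `synchronization=ON_READ` does not support ...

try:
  # scatter_* is not implemented for ON_READ variables at all.
  s.scatter_add(
      tf.IndexedSlices(values=tf.constant([1.]), indices=tf.constant([0])))
except NotImplementedError as e:
  print(e)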


@@ -1684,10 +1684,10 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         aggregation=variables_lib.VariableAggregation.SUM)
     self.evaluate(variables_lib.global_variables_initializer())
     with self.assertRaisesRegex(
-        ValueError, "SyncOnReadVariable does not support "):
+        ValueError, "Variable with `synchronization=ON_READ` does not support"):
       self.evaluate(v.assign_add(1.))
     with self.assertRaisesRegex(
-        ValueError, "SyncOnReadVariable does not support "):
+        ValueError, "Variable with `synchronization=ON_READ` does not support"):
       self.evaluate(v.assign_sub(1.))
   @combinations.generate(strategy_and_run_tf_function_combinations())