Refactor SyncOnReadVariable update methods to be consistent with MirroredVariable.
This makes the behavior of SyncOnReadVariable consistent with MirroredVariable; e.g. previously it always returned a tf.Operation in cross-replica context. After the refactoring, SyncOnReadVariable only needs to override _update_replica to get the desired behavior. This change also moves the assign* and scatter* overrides to the DistributedVariable level. It is part of the effort to make DistributedVariable.assign* return a variable.

PiperOrigin-RevId: 308894562
Change-Id: I1a58352e2af2ff57402d8fc744fcfc9610a48d8b
commit f7f727388c
parent 47766b3087
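For context on the shape of this change: DistributedVariable uses a template-method pattern in which every assign*/scatter* wrapper funnels into a single _update entry point, and _update routes to either the replica-local or the cross-replica path. A subclass therefore only has to override _update_replica to change update behavior, which is what SyncOnReadVariable does in the diff below. The following is a minimal, runnable sketch of that dispatch; FakeVariable and DistributedVariableSketch are hypothetical stand-ins, not TensorFlow's real classes:

# Hypothetical stand-ins that mimic the dispatch shape in the diff below;
# this is not TensorFlow's actual code.

class FakeVariable:
  """Plays the role of one replica's tf.Variable."""

  def __init__(self, value):
    self.value = value

  def assign_add(self, delta):
    self.value += delta
    return self.value


class DistributedVariableSketch:
  def __init__(self, per_replica_vars, cross_replica=False):
    self._values = per_replica_vars
    self._cross_replica = cross_replica

  def _update(self, update_fn, value, **kwargs):
    # Single entry point shared by every assign*/scatter* wrapper.
    if self._cross_replica:
      return self._update_cross_replica(update_fn, value, **kwargs)
    return self._update_replica(update_fn, value, **kwargs)

  def _update_cross_replica(self, update_fn, value, **kwargs):
    # Cross-replica: apply the update to each replica's copy.
    return [update_fn(v, value, **kwargs) for v in self._values]

  def _update_replica(self, update_fn, value, **kwargs):
    # Replica-local default; this is the hook subclasses override.
    return update_fn(self._values[0], value, **kwargs)

  def assign_add(self, value, **kwargs):
    # Thin wrapper in the same shape as the ones added to DistributedVariable.
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(update_fn=assign_add_fn, value=value, **kwargs)


v = DistributedVariableSketch([FakeVariable(0.0), FakeVariable(0.0)],
                              cross_replica=True)
print(v.assign_add(1.0))  # [1.0, 1.0] -- every replica copy was updated

In the real diff, the thin assign*/scatter* wrappers added to DistributedVariable all follow this assign_add shape, differing only in the variable method they forward to.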
@@ -393,16 +393,6 @@ def _assign_on_device(device, variable, tensor):
     return variable.assign(tensor)
 
 
-def _assign_add_on_device(device, variable, tensor):
-  with ops.device(device):
-    return variable.assign_add(tensor)
-
-
-def _assign_sub_on_device(device, variable, tensor):
-  with ops.device(device):
-    return variable.assign_sub(tensor)
-
-
 class DistributedVarOp(object):
   """A class that looks like `tf.Operation`."""
 
@@ -587,6 +577,89 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
   def value(self):
     return self._get_closest().value()
 
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+    return self._update(
+        update_fn=assign_sub_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
+
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
+    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+    return self._update(
+        update_fn=assign_add_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
+
+  def assign(self, value, use_locking=False, name=None, read_value=True):
+    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+    return self._update(
+        update_fn=assign_fn,
+        value=value,
+        use_locking=use_locking,
+        name=name,
+        read_value=read_value)
+
+  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
+    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
+    return self._update(
+        update_fn=scatter_sub_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+
+  def scatter_add(self, sparse_delta, use_locking=False, name=None):
+    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
+    return self._update(
+        update_fn=scatter_add_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+
+  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
+    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
+    return self._update(
+        update_fn=scatter_mul_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+
+  def scatter_div(self, sparse_delta, use_locking=False, name=None):
+    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
+    return self._update(
+        update_fn=scatter_div_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+
+  def scatter_min(self, sparse_delta, use_locking=False, name=None):
+    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
+    return self._update(
+        update_fn=scatter_min_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+
+  def scatter_max(self, sparse_delta, use_locking=False, name=None):
+    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
+    return self._update(
+        update_fn=scatter_max_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+
+  def scatter_update(self, sparse_delta, use_locking=False, name=None):
+    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
+    return self._update(
+        update_fn=scatter_update_fn,
+        value=sparse_delta,
+        use_locking=use_locking,
+        name=name)
+
   def _update_cross_replica(self, update_fn, value, **kwargs):
     """Applies updates across replicas.
 
@@ -836,66 +909,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
     return ds_context.get_replica_context().merge_call(
         merge_fn, args=(value,), kwargs=kwargs)
 
-  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
-    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
-    return self._update(
-        update_fn=assign_sub_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
-
-  def assign_add(self, value, use_locking=False, name=None, read_value=True):
-    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
-    return self._update(
-        update_fn=assign_add_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
-
-  def assign(self, value, use_locking=False, name=None, read_value=True):
-    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
-    return self._update(
-        update_fn=assign_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
-
-  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
-    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
-    return self._update(
-        update_fn=scatter_sub_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-
-  def scatter_add(self, sparse_delta, use_locking=False, name=None):
-    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
-    return self._update(
-        update_fn=scatter_add_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-
-  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
-    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
-    return self._update(
-        update_fn=scatter_mul_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-
-  def scatter_div(self, sparse_delta, use_locking=False, name=None):
-    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
-    return self._update(
-        update_fn=scatter_div_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
-
-  def scatter_min(self, sparse_delta, use_locking=False, name=None):
+  def scatter_min(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
       raise NotImplementedError("scatter_min is only supported for mirrored "
@@ -903,14 +917,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                 "`tf.distribute.Strategy` scope) with NONE or "
                                 "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                 self._aggregation)
-    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
-    return self._update(
-        update_fn=scatter_min_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    return super(MirroredVariable, self).scatter_min(*args, **kwargs)
 
-  def scatter_max(self, sparse_delta, use_locking=False, name=None):
+  def scatter_max(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
       raise NotImplementedError("scatter_max is only supported for mirrored "
@@ -918,14 +927,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                 "`tf.distribute.Strategy` scope) with NONE or "
                                 "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                 self._aggregation)
-    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
-    return self._update(
-        update_fn=scatter_max_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    return super(MirroredVariable, self).scatter_max(*args, **kwargs)
 
-  def scatter_update(self, sparse_delta, use_locking=False, name=None):
+  def scatter_update(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
       raise NotImplementedError("scatter_update is only supported for mirrored "
@@ -933,12 +937,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                 "`tf.distribute.Strategy` scope) with NONE or "
                                 "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                 self._aggregation)
-    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
-    return self._update(
-        update_fn=scatter_update_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    return super(MirroredVariable, self).scatter_update(*args, **kwargs)
 
   def _get_cross_replica(self):
     # Return identity, to avoid directly exposing the variable to the user and
@@ -1051,49 +1050,67 @@ def _assert_replica_context(strategy):
 class SyncOnReadVariable(DistributedVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
-  def assign_sub(self, *args, **kwargs):
+  def _update_cross_replica(self, update_fn, value, **kwargs):
+    # TODO(b/155139086): respect name argument.
+    kwargs["name"] = None
+    return super(SyncOnReadVariable,
+                 self)._update_cross_replica(update_fn, value, **kwargs)
+
+  def _update_replica(self, update_fn, value, **kwargs):
+    return update_fn(self._get_closest(), value, **kwargs)
+
+  def _aggregation_sum_assign_not_supported_cross_replica(self, method):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if ds_context.in_cross_replica_context():
-        if self._aggregation == vs.VariableAggregation.SUM:
+        if self.aggregation == vs.VariableAggregation.SUM:
           raise ValueError(
-              "SyncOnReadVariable does not support `assign_sub` in "
-              "cross-replica context when aggregation is set to "
-              "`tf.VariableAggregation.SUM`.")
-        return control_flow_ops.group(
-            tuple(
-                _assign_sub_on_device(v.device, v, args[0])
-                for v in self._values))
-      else:
-        return self._get().assign_sub(*args, **kwargs)
+              "Variable with `synchronization=ON_READ` does not support `%s` "
+              "in cross-replica context when aggregation is set to "
+              "`tf.VariableAggregation.SUM`." % method)
+
+  def assign_sub(self, *args, **kwargs):
+    self._aggregation_sum_assign_not_supported_cross_replica("assign_sub")
+    return super(SyncOnReadVariable, self).assign_sub(*args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
-      if ds_context.in_cross_replica_context():
-        if self._aggregation == vs.VariableAggregation.SUM:
-          raise ValueError(
-              "SyncOnReadVariable does not support `assign_add` in "
-              "cross-replica context when aggregation is set to "
-              "`tf.VariableAggregation.SUM`.")
-        return control_flow_ops.group(
-            tuple(
-                _assign_add_on_device(v.device, v, args[0])
-                for v in self._values))
-      else:
-        return self._get().assign_add(*args, **kwargs)
+    self._aggregation_sum_assign_not_supported_cross_replica("assign_add")
+    return super(SyncOnReadVariable, self).assign_add(*args, **kwargs)
 
-  def assign(self, *args, **kwargs):
+  def assign(self, value, *args, **kwargs):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
-      if ds_context.in_cross_replica_context():
+      if (ds_context.in_cross_replica_context() and
+          self._aggregation == vs.VariableAggregation.SUM):
         # To preserve the sum across save and restore, we have to divide the
        # total across all devices when restoring a variable that was summed
         # when saving.
-        tensor = args[0]
-        if self._aggregation == vs.VariableAggregation.SUM:
-          tensor = math_ops.cast(tensor / len(self._values), self.dtype)
-        return control_flow_ops.group(
-            tuple(_assign_on_device(v.device, v, tensor) for v in self._values))
-      else:
-        return self._get().assign(*args, **kwargs)
+        value = math_ops.cast(value / len(self._values), self.dtype)
+      return super(SyncOnReadVariable, self).assign(value, *args, **kwargs)
+
+  def _scatter_not_implemented(self, method):
+    raise NotImplementedError(
+        "Variables with `synchronization=ON_READ` doesn't support `%s`" %
+        method)
+
+  def scatter_sub(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_sub")
+
+  def scatter_add(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_add")
+
+  def scatter_mul(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_mul")
+
+  def scatter_div(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_div")
+
+  def scatter_min(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_min")
+
+  def scatter_max(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_max")
+
+  def scatter_update(self, *args, **kwargs):
+    self._scatter_not_implemented("scatter_update")
 
   def value(self):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
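A note on the assign path in the hunk above: with aggregation=SUM, the per-replica copies are meant to add up to the logical value, so a cross-replica assign (e.g. restoring a checkpoint) divides the incoming total by the number of replicas before writing. A hypothetical plain-Python sanity check of that arithmetic, with made-up numbers:

# Restoring a summed total T across N replicas stores T / N in each copy,
# so a subsequent summed read reproduces T.
num_replicas = 2
total = 6.0                                 # summed value captured at save time
per_replica = total / num_replicas          # what assign() writes to each copy
restored = [per_replica] * num_replicas
assert sum(restored) == total               # summed read yields 6.0 again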
@@ -1684,10 +1684,10 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         aggregation=variables_lib.VariableAggregation.SUM)
     self.evaluate(variables_lib.global_variables_initializer())
     with self.assertRaisesRegex(
-        ValueError, "SyncOnReadVariable does not support "):
+        ValueError, "Variable with `synchronization=ON_READ` does not support"):
       self.evaluate(v.assign_add(1.))
     with self.assertRaisesRegex(
-        ValueError, "SyncOnReadVariable does not support "):
+        ValueError, "Variable with `synchronization=ON_READ` does not support"):
      self.evaluate(v.assign_sub(1.))
 
   @combinations.generate(strategy_and_run_tf_function_combinations())