Internal change

PiperOrigin-RevId: 308926386
Change-Id: I3185dfa09e948883d21ba7e3e16020e30639728c
A. Unique TensorFlower 2020-04-28 17:02:12 -07:00 committed by TensorFlower Gardener
parent 8fe9e086d7
commit 62856c9366
2 changed files with 126 additions and 143 deletions


@@ -393,6 +393,16 @@ def _assign_on_device(device, variable, tensor):
    return variable.assign(tensor)

def _assign_add_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_add(tensor)

def _assign_sub_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_sub(tensor)

class DistributedVarOp(object):
  """A class that looks like `tf.Operation`."""
@@ -577,89 +587,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
  def value(self):
    return self._get_closest().value()

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _update_cross_replica(self, update_fn, value, **kwargs):
    """Applies updates across replicas.
@@ -909,7 +836,66 @@ class MirroredVariable(DistributedVariable, Mirrored):
    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)

  def scatter_min(self, *args, **kwargs):
  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError("scatter_min is only supported for mirrored "
@@ -917,9 +903,14 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                "`tf.distribute.Strategy` scope) with NONE or "
                                "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                self._aggregation)
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, *args, **kwargs):
  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError("scatter_max is only supported for mirrored "
@@ -927,9 +918,14 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                "`tf.distribute.Strategy` scope) with NONE or "
                                "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                self._aggregation)
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, *args, **kwargs):
  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError("scatter_update is only supported for mirrored "
@@ -937,7 +933,12 @@ class MirroredVariable(DistributedVariable, Mirrored):
                                "`tf.distribute.Strategy` scope) with NONE or "
                                "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
                                self._aggregation)
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
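
The three hunks above give `scatter_min`/`scatter_max`/`scatter_update` explicit signatures and keep the aggregation guard: the op proceeds only when the mirrored variable's aggregation is NONE or ONLY_FIRST_REPLICA. A hedged usage sketch, assuming a single-machine `tf.distribute.MirroredStrategy`:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable([1.0, 2.0, 3.0])  # default aggregation is NONE
delta = tf.IndexedSlices(
    values=tf.constant([9.0]), indices=tf.constant([1]))
# Permitted under the guard above; with MEAN or SUM aggregation this call
# is expected to raise NotImplementedError instead.
v.scatter_update(delta)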
@@ -1050,67 +1051,49 @@ def _assert_replica_context(strategy):
class SyncOnReadVariable(DistributedVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def _update_cross_replica(self, update_fn, value, **kwargs):
    # TODO(b/155139086): respect name argument.
    kwargs["name"] = None
    return super(SyncOnReadVariable,
                 self)._update_cross_replica(update_fn, value, **kwargs)

  def _update_replica(self, update_fn, value, **kwargs):
    return update_fn(self._get_closest(), value, **kwargs)

  def _aggregation_sum_assign_not_supported_cross_replica(self, method):
  def assign_sub(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if ds_context.in_cross_replica_context():
        if self.aggregation == vs.VariableAggregation.SUM:
        if self._aggregation == vs.VariableAggregation.SUM:
          raise ValueError(
              "Variable with `synchronization=ON_READ` does not support `%s` "
              "in cross-replica context when aggregation is set to "
              "`tf.VariableAggregation.SUM`." % method)

  def assign_sub(self, *args, **kwargs):
    self._aggregation_sum_assign_not_supported_cross_replica("assign_sub")
    return super(SyncOnReadVariable, self).assign_sub(*args, **kwargs)
              "SyncOnReadVariable does not support `assign_sub` in "
              "cross-replica context when aggregation is set to "
              "`tf.VariableAggregation.SUM`.")
        return control_flow_ops.group(
            tuple(
                _assign_sub_on_device(v.device, v, args[0])
                for v in self._values))
      else:
        return self._get().assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    self._aggregation_sum_assign_not_supported_cross_replica("assign_add")
    return super(SyncOnReadVariable, self).assign_add(*args, **kwargs)

  def assign(self, value, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          self._aggregation == vs.VariableAggregation.SUM):
      if ds_context.in_cross_replica_context():
        if self._aggregation == vs.VariableAggregation.SUM:
          raise ValueError(
              "SyncOnReadVariable does not support `assign_add` in "
              "cross-replica context when aggregation is set to "
              "`tf.VariableAggregation.SUM`.")
        return control_flow_ops.group(
            tuple(
                _assign_add_on_device(v.device, v, args[0])
                for v in self._values))
      else:
        return self._get().assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if ds_context.in_cross_replica_context():
        # To preserve the sum across save and restore, we have to divide the
        # total across all devices when restoring a variable that was summed
        # when saving.
        value = math_ops.cast(value / len(self._values), self.dtype)
      return super(SyncOnReadVariable, self).assign(value, *args, **kwargs)

  def _scatter_not_implemented(self, method):
    raise NotImplementedError(
        "Variables with `synchronization=ON_READ` doesn't support `%s`" %
        method)

  def scatter_sub(self, *args, **kwargs):
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    self._scatter_not_implemented("scatter_update")
        tensor = args[0]
        if self._aggregation == vs.VariableAggregation.SUM:
          tensor = math_ops.cast(tensor / len(self._values), self.dtype)
        return control_flow_ops.group(
            tuple(_assign_on_device(v.device, v, tensor) for v in self._values))
      else:
        return self._get().assign(*args, **kwargs)

  def value(self):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
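
The rewritten `assign*` methods above act directly on every per-replica copy in cross-replica context. For SUM aggregation, plain `assign` splits the incoming value so the copies still sum back to it (the restore case), while `assign_add`/`assign_sub` raise. A hedged sketch of the split-on-assign arithmetic; the helper below is illustrative, not TF API:

import tensorflow as tf

def assign_summed(value, replica_copies):
  # Mimic `math_ops.cast(tensor / len(self._values), self.dtype)`: give
  # each copy an equal share so the copies sum back to `value`.
  share = tf.cast(value / len(replica_copies), replica_copies[0].dtype)
  return [v.assign(share) for v in replica_copies]

copies = [tf.Variable(0.0), tf.Variable(0.0)]
assign_summed(6.0, copies)  # each copy holds 3.0; a SUM read gives 6.0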


@@ -1684,10 +1684,10 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
        aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(variables_lib.global_variables_initializer())
    with self.assertRaisesRegex(
        ValueError, "Variable with `synchronization=ON_READ` does not support"):
        ValueError, "SyncOnReadVariable does not support "):
      self.evaluate(v.assign_add(1.))
    with self.assertRaisesRegex(
        ValueError, "Variable with `synchronization=ON_READ` does not support"):
        ValueError, "SyncOnReadVariable does not support "):
      self.evaluate(v.assign_sub(1.))

  @combinations.generate(strategy_and_run_tf_function_combinations())
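
The regex updates above track the new error text. A hedged sketch of the behavior this test pins down, assuming a single-machine `tf.distribute.MirroredStrategy`:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)
try:
  v.assign_add(1.)  # cross-replica context + SUM aggregation
except ValueError as e:
  print(e)  # message starts with "SyncOnReadVariable does not support "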