Add a VariablePolicy field to the DistributedVariable class as part of an internal refactor that will allow us to attach different policies to a DistributedVariable.

PiperOrigin-RevId: 316629999
Change-Id: I20160480b0678657198112adaa61ad7a47823cbd
Anjali Sridhar 2020-06-16 00:34:30 -07:00 committed by TensorFlower Gardener
parent b03a4dbbd6
commit 5107743c47
2 changed files with 615 additions and 140 deletions

tensorflow/python/distribute/values.py

@@ -41,6 +41,36 @@ from tensorflow.python.types import core
from tensorflow.python.util.tf_export import tf_export
def _on_write_update_replica(var, update_fn, value, **kwargs):
"""Updates variables with ON_WRITE synchronization in replica context."""
if var.aggregation == vs.VariableAggregation.NONE:
return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
def merge_fn(strategy, value, **kwargs):
"""Aggregate values and update all variables in cross replica context."""
# Don't allow MEAN with non float dtype, since it may cause unexpected
# precision loss. Python3 and NumPy automatically upcast integers to
# float in division, but we should always preserve the type.
#
# Note that to be backward compatible we allow the case when the value
# is *always* the same on each replica. I.E. value is not a
# PerReplica. Refer to regroup() to see how values are grouped.
if var.aggregation == vs.VariableAggregation.MEAN and (
not var.dtype.is_floating) and isinstance(value, PerReplica):
raise ValueError(
"Cannot update non-float variables with "
"tf.VariableAggregation.MEAN aggregation in replica context. "
"Either change the variable dtype to float or update it in "
"cross-replica context.")
assert strategy == var.distribute_strategy
v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
return var._update_cross_replica(update_fn, v, **kwargs) # pylint: disable=protected-access
return ds_context.get_replica_context().merge_call(
merge_fn, args=(value,), kwargs=kwargs)
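
To make the restriction above concrete, here is a minimal sketch (the two-GPU device list is illustrative and assumes those devices exist): distinct per-replica writes to an integer variable with MEAN aggregation arrive at merge_fn as a PerReplica and raise the ValueError.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  counter = tf.Variable(
      0, dtype=tf.int32, aggregation=tf.VariableAggregation.MEAN)

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  # Each replica writes a different value, so the update reaches merge_fn
  # as a PerReplica; MEAN on an integer dtype then raises ValueError.
  counter.assign(tf.cast(ctx.replica_id_in_sync_group, tf.int32))

strategy.run(replica_fn)  # ValueError: Cannot update non-float variables ...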
@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
"""Base class for representing distributed values.
@@ -409,10 +439,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
core.Tensor):
"""Holds a map from replica to variables."""
# TODO(josh11b): Support changing the set of variables, e.g. if new
# devices are joining or a device is to leave.
def __init__(self, strategy, values, aggregation):
def __init__(self, strategy, values, aggregation, var_policy=None):
self._distribute_strategy = strategy
self._aggregation = aggregation
super(DistributedVariable, self).__init__(values)
@@ -439,6 +466,9 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
# when restoring from a checkpoint, we may set the _initializer_op
# property on the entire `DistributedVariable`.
self._initializer_op = None
# Set a VariablePolicy which decides how we replicate/aggregate the given
# variable.
self._var_policy = var_policy
def is_initialized(self, name=None):
"""Identifies if all the component variables are initialized.
@@ -580,6 +610,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
return array_ops.identity(self._get())
def value(self):
if self._var_policy:
return self._var_policy.value(self)
return self._get_on_device_or_primary().value()
def numpy(self):
@@ -590,87 +622,104 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
"numpy() is only available when eager execution is enabled.")
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
return self._update(
update_fn=assign_sub_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
if self._var_policy:
return self._var_policy.assign_sub(self, value, use_locking=use_locking,
name=name, read_value=read_value)
return values_util.on_write_assign_sub(self, value, use_locking=use_locking,
name=name, read_value=read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
return self._update(
update_fn=assign_add_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
if self._var_policy:
return self._var_policy.assign_add(self, value, use_locking=use_locking,
name=name, read_value=read_value)
return values_util.on_write_assign_add(self, value, use_locking=use_locking,
name=name, read_value=read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
return self._update(
update_fn=assign_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
if self._var_policy:
return self._var_policy.assign(self, value, use_locking=use_locking,
name=name, read_value=read_value)
return values_util.on_write_assign(self, value, use_locking=use_locking,
name=name, read_value=read_value)
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
return self._update(
update_fn=scatter_sub_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
if self._var_policy:
return self._var_policy.scatter_sub(self, sparse_delta,
use_locking=use_locking, name=name)
return values_util.scatter_sub(self, sparse_delta, use_locking=use_locking,
name=name)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
return self._update(
update_fn=scatter_add_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
if self._var_policy:
return self._var_policy.scatter_add(self, sparse_delta,
use_locking=use_locking, name=name)
return values_util.scatter_add(self, sparse_delta, use_locking=use_locking,
name=name)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
return self._update(
update_fn=scatter_mul_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
if self._var_policy:
return self._var_policy.scatter_mul(self, sparse_delta,
use_locking=use_locking, name=name)
return values_util.scatter_mul(self, sparse_delta, use_locking=use_locking,
name=name)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
return self._update(
update_fn=scatter_div_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
if self._var_policy:
return self._var_policy.scatter_div(self, sparse_delta,
use_locking=use_locking, name=name)
return values_util.scatter_div(self, sparse_delta, use_locking=use_locking,
name=name)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
return self._update(
update_fn=scatter_min_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
if self._var_policy:
return self._var_policy.scatter_min(self, sparse_delta,
use_locking=use_locking, name=name)
return values_util.scatter_min(self, sparse_delta, use_locking=use_locking,
name=name)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
return self._update(
update_fn=scatter_max_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
if self._var_policy:
return self._var_policy.scatter_max(self, sparse_delta,
use_locking=use_locking, name=name)
return values_util.scatter_max(self, sparse_delta, use_locking=use_locking,
name=name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
return self._update(
update_fn=scatter_update_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
if self._var_policy:
return self._var_policy.scatter_update(self, sparse_delta,
use_locking=use_locking, name=name)
return values_util.scatter_update(self, sparse_delta,
use_locking=use_locking,
name=name)
def _gather_saveables_for_checkpoint(self):
"""Overrides Trackable method.
This allows both name-based and object-based save and restore of
DistributedVariables.
Returns:
A dictionary mapping attribute names to `SaveableObject` factories.
"""
def _saveable_factory(name=self._common_name):
return _DistributedVariableSaveable(self, self._primary, name)
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
def _as_graph_element(self):
if self._var_policy:
return self._var_policy._as_graph_element(self) # pylint: disable=protected-access
raise NotImplementedError("No policy set for calling _as_graph_element.")
def _get_cross_replica(self):
if self._var_policy:
return self._var_policy._get_cross_replica(self) # pylint: disable=protected-access
raise NotImplementedError(
"This method should be overridden by sub-classes which support cross-"
"replica accesses.")
def _update_cross_replica(self, update_fn, value, **kwargs):
"""Applies updates across replicas.
@@ -699,6 +748,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
Returns:
Updated variable or `tf.Operation`.
"""
if self._var_policy:
return self._var_policy._update_replica(self, update_fn, value, **kwargs) # pylint: disable=protected-access
raise NotImplementedError("should be implemented by subclass.")
def _update(self, update_fn, value, **kwargs):
@@ -735,6 +786,31 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
"""Pass resource_variable_ops.is_resource_variable check."""
pass
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
return ops.convert_to_tensor(
self._get(), dtype=dtype, name=name, as_ref=as_ref)
class _DistributedVariableSaveable(saveable_object.SaveableObject):
"""Class for defining how to restore a DistributedVariable."""
def __init__(self, distributed_variable, primary_variable, name):
self._distributed_variable = distributed_variable
if not self._distributed_variable._var_policy:
raise ValueError("VariablePolicy has not been set for the distributed "
"variable.")
tensor, spec = distributed_variable._var_policy.get_saveable(
distributed_variable, primary_variable, name)
super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables."""
tensor, = restored_tensors
return self._distributed_variable._var_policy.get_restore_ops( # pylint: disable=protected-access
self._distributed_variable, tensor)
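
Taken together, _gather_saveables_for_checkpoint and the policy-backed saveable keep checkpointing transparent: saving writes a single logical value, and restoring fans it back out to every replica. A minimal usage sketch (path and values are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0)

ckpt = tf.train.Checkpoint(v=v)
path = ckpt.save("/tmp/demo_ckpt")  # one saved tensor for all replicas

v.assign(5.0)
ckpt.restore(path)  # the saved value is assigned to every replica's copy
assert v.numpy() == 1.0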
class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
"""Class for defining how to restore a MirroredVariable."""
@@ -756,61 +832,27 @@ class MirroredVariable(DistributedVariable, Mirrored):
"""Holds a map from replica to variables whose values are kept in sync."""
def _update_replica(self, update_fn, value, **kwargs):
if self.aggregation == vs.VariableAggregation.NONE:
return update_fn(self._get_on_device_or_primary(), value, **kwargs)
def merge_fn(strategy, value, **kwargs):
"""Aggregate values and update all variables in cross replica context."""
# Don't allow MEAN with non float dtype, since it may cause unexpected
# precision loss. Python3 and NumPy automatically upcast integers to
# float in division, but we should always preserve the type.
#
# Note that to be backward compatible we allow the case when the value
# is *always* the same on each replica. I.E. value is not a
# PerReplica. Refer to regroup() to see how values are grouped.
if self._aggregation == vs.VariableAggregation.MEAN and (
not self.dtype.is_floating) and isinstance(value, PerReplica):
raise ValueError(
"Cannot update non-float variables with "
"tf.VariableAggregation.MEAN aggregation in replica context. "
"Either change the variable dtype to float or update it in "
"cross-replica context.")
assert strategy == self.distribute_strategy
v = values_util.apply_aggregation(strategy, value, self.aggregation, self)
return self._update_cross_replica(update_fn, v, **kwargs)
return ds_context.get_replica_context().merge_call(
merge_fn, args=(value,), kwargs=kwargs)
return _on_write_update_replica(self, update_fn, value, **kwargs)
def scatter_min(self, *args, **kwargs):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError("scatter_min is only supported for mirrored "
"variable (variable created within certain "
"`tf.distribute.Strategy` scope) with NONE or "
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
self._aggregation)
raise NotImplementedError(values_util.scatter_error_msg.format(
op_name="scatter_min", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_min(*args, **kwargs)
def scatter_max(self, *args, **kwargs):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError("scatter_max is only supported for mirrored "
"variable (variable created within certain "
"`tf.distribute.Strategy` scope) with NONE or "
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
self._aggregation)
raise NotImplementedError(values_util.scatter_error_msg.format(
op_name="scatter_min", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_max(*args, **kwargs)
def scatter_update(self, *args, **kwargs):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError("scatter_update is only supported for mirrored "
"variable (variable created within certain "
"`tf.distribute.Strategy` scope) with NONE or "
"`ONLY_FIRST_REPLICA` aggregation, got: %s" %
self._aggregation)
raise NotImplementedError(values_util.scatter_error_msg.format(
op_name="scatter_min", aggregation=self._aggregation))
return super(MirroredVariable, self).scatter_update(*args, **kwargs)
def _get_cross_replica(self):
@@ -893,28 +935,13 @@ class SyncOnReadVariable(DistributedVariable):
def _update_replica(self, update_fn, value, **kwargs):
return update_fn(self._get_on_device_or_primary(), value, **kwargs)
def _assign_on_each_device(self, assign_func, value, read_value):
update = control_flow_ops.group(
tuple(
assign_func(v.device, v, value)
for v in self._values))
if not read_value:
return update
with ops.control_dependencies([update] if update else []):
return self.read_value()
# TODO(b/154017756): Make assign behavior in cross-replica context consistent
# with MirroredVariable.
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
if self._aggregation == vs.VariableAggregation.SUM:
raise ValueError(
"SyncOnReadVariable does not support `assign_sub` in "
"cross-replica context when aggregation is set to "
"`tf.VariableAggregation.SUM`.")
return self._assign_on_each_device(values_util.assign_sub_on_device,
value, read_value)
return values_util.on_read_assign_sub_cross_replica(
self, value, read_value=read_value)
else:
return super(SyncOnReadVariable,
self).assign_sub(value, use_locking, name, read_value)
@@ -922,13 +949,8 @@ class SyncOnReadVariable(DistributedVariable):
def assign_add(self, value, use_locking=False, name=None, read_value=True):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
if self._aggregation == vs.VariableAggregation.SUM:
raise ValueError(
"SyncOnReadVariable does not support `assign_add` in "
"cross-replica context when aggregation is set to "
"`tf.VariableAggregation.SUM`.")
return self._assign_on_each_device(values_util.assign_add_on_device,
value, read_value)
return values_util.on_read_assign_add_cross_replica(
self, value, read_value=read_value)
else:
return super(SyncOnReadVariable,
self).assign_add(value, use_locking, name, read_value)
@@ -936,13 +958,8 @@ class SyncOnReadVariable(DistributedVariable):
def assign(self, value, use_locking=False, name=None, read_value=True):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
if self._aggregation == vs.VariableAggregation.SUM:
value = math_ops.cast(value / len(self._values), self.dtype)
return self._assign_on_each_device(values_util.assign_on_device, value,
read_value)
return values_util.on_read_assign_cross_replica(
self, value, read_value=read_value)
else:
return super(SyncOnReadVariable,
self).assign(value, use_locking, name, read_value)
@@ -987,7 +1004,7 @@ class SyncOnReadVariable(DistributedVariable):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
return self._distribute_strategy.reduce(
reduce_util.ReduceOp.from_variable_aggregation(self.aggregation),
reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
self,
axis=None)
@@ -1022,6 +1039,16 @@ class SyncOnReadVariable(DistributedVariable):
# Register conversion functions which read the value of the variable,
# allowing instances of the class to be used as tensors.
# DistributedVariable
def _tensor_conversion_distributed_var(var, dtype=None, name=None,
as_ref=False):
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
ops.register_tensor_conversion_function(DistributedVariable,
_tensor_conversion_distributed_var)
# MirroredVariables
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
@@ -1048,3 +1075,299 @@ def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
ops.register_tensor_conversion_function(SyncOnReadVariable,
_tensor_conversion_sync_on_read)
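
These registrations are what let distributed variables be passed directly to ordinary ops: ops.convert_to_tensor routes through _dense_var_to_tensor, which reads the local (on-device or primary) copy. A small sketch:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  w = tf.Variable(2.0)

# The conversion function registered above is invoked implicitly here.
y = tf.add(w, 1.0)
print(y.numpy())  # 3.0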
class VariablePolicy(object):
"""Policy defining synchronization and aggregation of a distributed variable.
Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
during variable creation within `tf.distribute` scope, `tf.distribute` creates
an appropriate policy object and assigns it to the distributed variable. All
variable operations are delegated to the respective policy object.
"""
def __init__(self, aggregation):
self._aggregation = aggregation
def value(self, var):
raise NotImplementedError(
"This method should be overridden by sub-classes.")
def _is_mirrored(self):
raise NotImplementedError(
"This method should be overridden by sub-classes.")
def _as_graph_element(self, _):
raise NotImplementedError(
"This method should be overridden by sub-classes.")
def _get_cross_replica(self, var):
raise NotImplementedError(
"This method should be overridden by sub-classes.")
def _update_replica(self, var, update_fn, value, **kwargs):
raise NotImplementedError(
"This method should be overridden by sub-classes.")
class OnReadPolicy(VariablePolicy):
"""Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
This policy is created when `synchronization` is set to
`tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
`MEAN` or `ONLY_FIRST_REPLICA`when creating a `tf.Variable` in `tf.distribute`
scope.
"""
def _is_mirrored(self):
return False
def value(self, var):
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
return var._get_cross_replica() # pylint: disable=protected-access
else:
return var._get_on_device_or_primary().value() # pylint: disable=protected-access
def _as_graph_element(self, var):
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
return ops.convert_to_tensor(var._get_cross_replica()) # pylint: disable=protected-access
return var._get()._as_graph_element() # pylint: disable=protected-access
def _get_cross_replica(self, var):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return var._primary # pylint: disable=protected-access
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
return var.distribute_strategy.reduce(
reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
var,
axis=None)
def _update_replica(self, var, update_fn, value, **kwargs):
return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
def _scatter_not_implemented(self, method):
raise NotImplementedError(
"ON_READ variables doesn't support `%s` in cross replica context" %
method)
def assign_sub(self, var, value, use_locking=False, name=None,
read_value=True):
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
return values_util.on_read_assign_sub_cross_replica(
var, value, read_value=read_value)
else:
return values_util.on_write_assign_sub(
var, value, use_locking=use_locking, name=name,
read_value=read_value)
def assign_add(self, var, value, use_locking=False, name=None,
read_value=True):
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
return values_util.on_read_assign_add_cross_replica(
var, value, read_value=read_value)
else:
return values_util.on_write_assign_add(
var, value, use_locking=use_locking, name=name,
read_value=read_value)
def assign(self, var, value, use_locking=False, name=None, read_value=True):
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
return values_util.on_read_assign_cross_replica(var, value,
read_value=read_value)
else:
return values_util.on_write_assign(var, value,
use_locking=use_locking,
name=name,
read_value=read_value)
def scatter_sub(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_sub")
def scatter_add(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_add")
def scatter_mul(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_mul")
def scatter_div(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_div")
def scatter_min(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_min")
def scatter_max(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_max")
def scatter_update(self, *args, **kwargs):
del args, kwargs
self._scatter_not_implemented("scatter_update")
def get_saveable(self, var, primary_var, name):
"""Create a saveable object for the given variable."""
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
strategy = var.distribute_strategy
return strategy.extended.read_var(var)
spec = saveable_object.SaveSpec(
tensor=tensor,
slice_spec="",
name=name,
dtype=var.dtype,
device=primary_var.device)
return tensor, [spec]
def get_restore_ops(self, var, tensor):
"""Restore the same value into all variables."""
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
if self._aggregation == vs.VariableAggregation.SUM:
tensor = math_ops.cast(tensor / len(var._devices), # pylint: disable=protected-access
var.dtype)
return control_flow_ops.group(
tuple(
values_util.assign_on_device(v.device, v, tensor)
for v in var.values))
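
A usage sketch for the ON_READ path (again assuming two illustrative GPUs): each replica accumulates into its local copy, and a cross-replica read reduces with the variable's aggregation. The division by len(var._devices) in get_restore_ops above is what preserves a SUM across save/restore: with two replicas each holding 5.0, the saved tensor is 10.0, and restore assigns 10.0 / 2 = 5.0 back to each copy.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  total = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

def replica_fn():
  total.assign_add(1.0)  # updates only this replica's local copy

strategy.run(replica_fn)
# Reading in cross-replica context reduces across replicas: 1.0 + 1.0.
print(total.value().numpy())  # 2.0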
class AutoPolicy(VariablePolicy):
"""Policy defined for `tf.VariableSynchronization.AUTO` synchronization.
This policy is created when `synchronization` is set to
`tf.VariableSynchronization.AUTO` and `aggregation` is set to
`tf.VariableAggregation.NONE` when creating a `tf.Variable` in `tf.distribute`
scope.
"""
def _is_mirrored(self):
return True
def value(self, var):
return var._get_on_device_or_primary().value() # pylint: disable=protected-access
def _as_graph_element(self, var):
return var._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
def _get_cross_replica(self, var):
# Return identity, to avoid directly exposing the variable to the user and
# allowing it to be modified by mistake.
return array_ops.identity(Mirrored._get_cross_replica(var)) # pylint: disable=protected-access
def _update_replica(self, var, update_fn, value, **kwargs):
return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
def assign(self, var, value, use_locking=False, name=None, read_value=True):
return values_util.on_write_assign(var, value, use_locking=use_locking,
name=name, read_value=read_value)
def assign_add(self, var, value, use_locking=False, name=None,
read_value=True):
return values_util.on_write_assign_add(var, value, use_locking=use_locking,
name=name, read_value=read_value)
def assign_sub(self, var, value, use_locking=False, name=None,
read_value=True):
return values_util.on_write_assign_sub(var, value, use_locking=use_locking,
name=name, read_value=read_value)
def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_sub(var, sparse_delta, use_locking=use_locking,
name=name)
def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_add(var, sparse_delta, use_locking=use_locking,
name=name)
def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_mul(var, sparse_delta, use_locking=use_locking,
name=name)
def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
return values_util.scatter_div(var, sparse_delta, use_locking=use_locking,
name=name)
def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(values_util.scatter_error_msg.format(
op_name="scatter_min", aggregation=self._aggregation))
return values_util.scatter_min(var, sparse_delta, use_locking=use_locking,
name=name)
def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(values_util.scatter_error_msg.format(
op_name="scatter_max", aggregation=self._aggregation))
return values_util.scatter_max(var, sparse_delta, use_locking=use_locking,
name=name)
def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
self._aggregation != vs.VariableAggregation.NONE):
raise NotImplementedError(values_util.scatter_error_msg.format(
op_name="scatter_update", aggregation=self._aggregation))
return values_util.scatter_update(var, sparse_delta,
use_locking=use_locking,
name=name)
def get_saveable(self, var, primary_var, name):
del var, name
return primary_var, ""
def get_restore_ops(self, var, tensor):
return control_flow_ops.group(
tuple(
values_util.assign_on_device(v.device, v, tensor)
for v in var.values))
class OnWritePolicy(AutoPolicy):
"""Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
This policy is created when the following `synchronization` and
`aggregation` parameters are specified when creating a `tf.Variable` in
`tf.distribute` scope:
* `synchronization` is equal to `tf.VariableSynchronization.AUTO` and
`aggregation` is one of the `tf.VariableAggregation` enum values `SUM`,
`MEAN` or `ONLY_FIRST_REPLICA`.
* `synchronization` is equal to `tf.VariableSynchronization.ON_WRITE` and
`aggregation` is one of the `tf.VariableAggregation` enum values `NONE`,
`SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
"""
def _update_replica(self, var, update_fn, value, **kwargs):
return _on_write_update_replica(var, update_fn, value, **kwargs)
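
The three policies map onto the (`synchronization`, `aggregation`) pairs described in their docstrings. A hedged sketch of the selection a strategy could perform at variable-creation time (the helper name `_select_policy` is hypothetical, not part of this change):

def _select_policy(synchronization, aggregation):
  """Illustrative mapping from creation arguments to a VariablePolicy."""
  if synchronization == vs.VariableSynchronization.ON_READ:
    return OnReadPolicy(aggregation)
  if (synchronization == vs.VariableSynchronization.AUTO and
      aggregation == vs.VariableAggregation.NONE):
    return AutoPolicy(aggregation)
  # AUTO with an explicit aggregation, or explicit ON_WRITE.
  return OnWritePolicy(aggregation)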
# Utility functions
# Return True if the value is Mirrored or the variable is replicated and kept
# in sync.
def _is_mirrored(val):
if isinstance(val, DistributedVariable):
if val._var_policy: # pylint: disable=protected-access
return val._var_policy._is_mirrored() # pylint: disable=protected-access
return isinstance(val, Mirrored)
def _is_sync_on_read(val):
if isinstance(val, DistributedVariable):
if val._var_policy: # pylint: disable=protected-access
return not val._var_policy._is_mirrored() # pylint: disable=protected-access
return not isinstance(val, Mirrored)
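
For contrast with the ON_READ example above, a sketch of the ON_WRITE path (same illustrative two-GPU setup): divergent per-replica writes are aggregated in merge_fn and the aggregated result is written back to every copy.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  w = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_WRITE,
      aggregation=tf.VariableAggregation.MEAN)

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  # Replica 0 writes 0.0 and replica 1 writes 2.0; MEAN aggregation
  # leaves every copy holding 1.0.
  w.assign(2.0 * tf.cast(ctx.replica_id_in_sync_group, tf.float32))

strategy.run(replica_fn)
print(w.value().numpy())  # 1.0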

tensorflow/python/distribute/values_util.py

@@ -23,9 +23,155 @@ from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=assign_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def on_write_assign_add(var, value, use_locking=False, name=None,
read_value=True):
assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=assign_add_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def on_write_assign_sub(var, value, use_locking=False, name=None,
read_value=True):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=assign_sub_fn,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
def assign_on_each_device(var, assign_func, value, read_value):
update = control_flow_ops.group(
tuple(assign_func(v.device, v, value) for v in var._values)) # pylint: disable=protected-access
if not read_value:
return update
with ops.control_dependencies([update] if update else []):
return var.read_value()
def on_read_assign_sub_cross_replica(var, value, read_value=True):
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
if var.aggregation == vs.VariableAggregation.SUM:
raise ValueError(
"SyncOnReadVariable does not support `assign_sub` in "
"cross-replica context when aggregation is set to "
"`tf.VariableAggregation.SUM`.")
return assign_on_each_device(var, assign_sub_on_device,
value, read_value)
def on_read_assign_add_cross_replica(var, value, read_value=True):
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
if var.aggregation == vs.VariableAggregation.SUM:
raise ValueError(
"SyncOnReadVariable does not support `assign_add` in "
"cross-replica context when aggregation is set to "
"`tf.VariableAggregation.SUM`.")
return assign_on_each_device(var, assign_add_on_device,
value, read_value)
def on_read_assign_cross_replica(var, value, read_value=True):
"""Return the value of the variable in cross replica context."""
with ds_context.enter_or_assert_strategy(var.distribute_strategy):
if ds_context.in_cross_replica_context():
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
tensor = value
# TODO(anjs): Should this be over all the replicas in sync since we
# call `reduce` on the variable during read?
if var.aggregation == vs.VariableAggregation.SUM:
tensor = math_ops.cast(tensor / len(var._values), var.dtype) # pylint: disable=protected-access
return assign_on_each_device(var, assign_on_device, tensor,
read_value)
def scatter_sub(var, sparse_delta, use_locking=False, name=None):
scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=scatter_sub_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
def scatter_add(var, sparse_delta, use_locking=False, name=None):
scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=scatter_add_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
def scatter_mul(var, sparse_delta, use_locking=False, name=None):
scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=scatter_mul_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
def scatter_div(var, sparse_delta, use_locking=False, name=None):
scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=scatter_div_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
def scatter_min(var, sparse_delta, use_locking=False, name=None):
scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=scatter_min_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
def scatter_max(var, sparse_delta, use_locking=False, name=None):
scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=scatter_max_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
def scatter_update(var, sparse_delta, use_locking=False, name=None):
scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
return var._update( # pylint: disable=protected-access
update_fn=scatter_update_fn,
value=sparse_delta,
use_locking=use_locking,
name=name)
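
These _update-based scatter helpers take a tf.IndexedSlices, just like the single-variable scatter ops they wrap. A minimal usage sketch against a mirrored variable (default NONE aggregation, under which scatter_update is permitted):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  table = tf.Variable([1.0, 2.0, 3.0])  # AUTO synchronization, NONE aggregation

delta = tf.IndexedSlices(values=tf.constant([10.0]),
                         indices=tf.constant([1]))
table.scatter_update(delta)  # row 1 becomes 10.0 on every replica
print(table.numpy())  # [ 1. 10.  3.]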
def get_current_replica_id_as_int():
"""Returns the current replica ID as an integer, or `None`."""
replica_context = ds_context.get_replica_context()
@@ -89,3 +235,9 @@ aggregation_error_msg = (
"`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`."
"Inside `merge_fn`, you can then update the {variable_type} "
"using `tf.distribute.StrategyExtended.update()`.")
scatter_error_msg = ("{op_name} is only supported for mirrored "
"variables (variables created within certain "
"`tf.distribute.Strategy` scopes) with `NONE` or "
"`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")