Add VariablePolicy field to the DistributedVariable class as part of an internal refactor that will allow us to attach different policies to a DistributedVariable.
PiperOrigin-RevId: 316629999
Change-Id: I20160480b0678657198112adaa61ad7a47823cbd

parent b03a4dbbd6 · commit 5107743c47
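The change is a delegation pattern: a DistributedVariable now carries an optional policy object, and each read or update operation first consults the policy, falling back to the pre-existing on-write behavior when no policy is set. A minimal sketch of that pattern, using hypothetical stand-in classes (ToyPolicy, ToyVariable) rather than the real TensorFlow types:

    class ToyPolicy(object):
      """Stand-in for VariablePolicy: owns the update behavior."""

      def assign(self, var, value):
        var.raw_value = value  # policy-mediated update
        return var.raw_value


    class ToyVariable(object):
      """Stand-in for DistributedVariable: dispatches to its policy if set."""

      def __init__(self, value, var_policy=None):
        self.raw_value = value
        self._var_policy = var_policy  # None preserves the legacy code path.

      def assign(self, value):
        if self._var_policy:
          return self._var_policy.assign(self, value)
        # Fallback mirrors values_util.on_write_assign in the real code.
        self.raw_value = value
        return self.raw_value


    v = ToyVariable(0)                           # legacy path
    w = ToyVariable(0, var_policy=ToyPolicy())   # policy path
    v.assign(1)
    w.assign(2)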
tensorflow/python/distribute/values.py

@@ -41,6 +41,36 @@ from tensorflow.python.types import core
 from tensorflow.python.util.tf_export import tf_export
 
 
+def _on_write_update_replica(var, update_fn, value, **kwargs):
+  """Updates variables with ON_WRITE synchronization in replica context."""
+  if var.aggregation == vs.VariableAggregation.NONE:
+    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
+
+  def merge_fn(strategy, value, **kwargs):
+    """Aggregate values and update all variables in cross replica context."""
+    # Don't allow MEAN with non float dtype, since it may cause unexpected
+    # precision loss. Python3 and NumPy automatically upcast integers to
+    # float in division, but we should always preserve the type.
+    #
+    # Note that to be backward compatible we allow the case when the value
+    # is *always* the same on each replica. I.E. value is not a
+    # PerReplica. Refer to regroup() to see how values are grouped.
+    if var.aggregation == vs.VariableAggregation.MEAN and (
+        not var.dtype.is_floating) and isinstance(value, PerReplica):
+      raise ValueError(
+          "Cannot update non-float variables with "
+          "tf.VariableAggregation.MEAN aggregation in replica context. "
+          "Either change the variable dtype to float or update it in "
+          "cross-replica context.")
+
+    assert strategy == var.distribute_strategy
+    v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
+    return var._update_cross_replica(update_fn, v, **kwargs)  # pylint: disable=protected-access
+
+  return ds_context.get_replica_context().merge_call(
+      merge_fn, args=(value,), kwargs=kwargs)
+
+
 @tf_export("distribute.DistributedValues", v1=[])
 class DistributedValues(object):
   """Base class for representing distributed values.
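The precision-loss comment in _on_write_update_replica can be checked in plain Python: averaging integer updates either changes the dtype (true division upcasts to float) or silently truncates (floor division), which is why MEAN aggregation on a non-float variable is rejected in replica context. A two-line illustration:

    print((3 + 4) / 2)   # 3.5 -- true division upcasts int to float
    print((3 + 4) // 2)  # 3   -- floor division keeps int but loses precision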
@@ -409,10 +439,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                           core.Tensor):
   """Holds a map from replica to variables."""
 
-  # TODO(josh11b): Support changing the set of variables if e.g. if new
-  # devices are joining or a device is to leave.
-
-  def __init__(self, strategy, values, aggregation):
+  def __init__(self, strategy, values, aggregation, var_policy=None):
     self._distribute_strategy = strategy
     self._aggregation = aggregation
     super(DistributedVariable, self).__init__(values)
@@ -439,6 +466,9 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     # when restoring from a checkpoint, we may set the _initializer_op
     # property on the entire `DistributedVariable`.
     self._initializer_op = None
+    # Set a VariablePolicy which decides how we replicate/aggregate the given
+    # variable.
+    self._var_policy = var_policy
 
   def is_initialized(self, name=None):
     """Identifies if all the component variables are initialized.
@@ -580,6 +610,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
       return array_ops.identity(self._get())
 
   def value(self):
+    if self._var_policy:
+      return self._var_policy.value(self)
     return self._get_on_device_or_primary().value()
 
   def numpy(self):
@@ -590,87 +622,104 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
           "numpy() is only available when eager execution is enabled.")
 
   def assign_sub(self, value, use_locking=False, name=None, read_value=True):
-    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
-    return self._update(
-        update_fn=assign_sub_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
+    if self._var_policy:
+      return self._var_policy.assign_sub(self, value, use_locking=use_locking,
+                                         name=name, read_value=read_value)
+    return values_util.on_write_assign_sub(self, value, use_locking=use_locking,
+                                           name=name, read_value=read_value)
 
   def assign_add(self, value, use_locking=False, name=None, read_value=True):
-    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
-    return self._update(
-        update_fn=assign_add_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
+    if self._var_policy:
+      return self._var_policy.assign_add(self, value, use_locking=use_locking,
+                                         name=name, read_value=read_value)
+    return values_util.on_write_assign_add(self, value, use_locking=use_locking,
+                                           name=name, read_value=read_value)
 
   def assign(self, value, use_locking=False, name=None, read_value=True):
-    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
-    return self._update(
-        update_fn=assign_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
+    if self._var_policy:
+      return self._var_policy.assign(self, value, use_locking=use_locking,
+                                     name=name, read_value=read_value)
+    return values_util.on_write_assign(self, value, use_locking=use_locking,
+                                       name=name, read_value=read_value)
 
   def scatter_sub(self, sparse_delta, use_locking=False, name=None):
-    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
-    return self._update(
-        update_fn=scatter_sub_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_sub(self, sparse_delta,
+                                          use_locking=use_locking, name=name)
+    return values_util.scatter_sub(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_add(self, sparse_delta, use_locking=False, name=None):
-    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
-    return self._update(
-        update_fn=scatter_add_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_add(self, sparse_delta,
+                                          use_locking=use_locking, name=name)
+    return values_util.scatter_add(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_mul(self, sparse_delta, use_locking=False, name=None):
-    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
-    return self._update(
-        update_fn=scatter_mul_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_mul(self, sparse_delta,
+                                          use_locking=use_locking, name=name)
+    return values_util.scatter_mul(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_div(self, sparse_delta, use_locking=False, name=None):
-    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
-    return self._update(
-        update_fn=scatter_div_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_div(self, sparse_delta,
+                                          use_locking=use_locking, name=name)
+    return values_util.scatter_div(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_min(self, sparse_delta, use_locking=False, name=None):
-    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
-    return self._update(
-        update_fn=scatter_min_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_min(self, sparse_delta,
+                                          use_locking=use_locking, name=name)
+    return values_util.scatter_min(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_max(self, sparse_delta, use_locking=False, name=None):
-    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
-    return self._update(
-        update_fn=scatter_max_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_max(self, sparse_delta,
+                                          use_locking=use_locking, name=name)
+    return values_util.scatter_max(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_update(self, sparse_delta, use_locking=False, name=None):
-    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
-    return self._update(
-        update_fn=scatter_update_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_update(self, sparse_delta,
+                                             use_locking=use_locking,
+                                             name=name)
+    return values_util.scatter_update(self, sparse_delta,
+                                      use_locking=use_locking,
+                                      name=name)
 
+  def _gather_saveables_for_checkpoint(self):
+    """Overrides Trackable method.
+
+    This allows both name-based and object-based save and restore of
+    DistributedVariables.
+
+    Returns:
+      A dictionary mapping attribute names to `SaveableObject` factories.
+    """
+
+    def _saveable_factory(name=self._common_name):
+      return _DistributedVariableSaveable(self, self._primary, name)
+
+    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+  def _as_graph_element(self):
+    if self._var_policy:
+      return self._var_policy._as_graph_element(self)  # pylint: disable=protected-access
+
+    raise NotImplementedError("No policy set for calling _as_graph_element.")
+
+  def _get_cross_replica(self):
+    if self._var_policy:
+      return self._var_policy._get_cross_replica(self)  # pylint: disable=protected-access
+
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes which support cross-"
+        "replica accesses.")
+
   def _update_cross_replica(self, update_fn, value, **kwargs):
     """Applies updates across replicas.
@@ -699,6 +748,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     Returns:
       Updated variable or `tf.Operation`.
     """
+    if self._var_policy:
+      return self._var_policy._update_replica(self, update_fn, value, **kwargs)  # pylint: disable=protected-access
     raise NotImplementedError("should be implemented by subclass.")
 
   def _update(self, update_fn, value, **kwargs):
@@ -735,6 +786,31 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
 
+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      return ops.convert_to_tensor(
+          self._get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+class _DistributedVariableSaveable(saveable_object.SaveableObject):
+  """Class for defining how to restore a DistributedVariable."""
+
+  def __init__(self, distributed_variable, primary_variable, name):
+    self._distributed_variable = distributed_variable
+    if not self._distributed_variable._var_policy:
+      raise ValueError("VariablePolicy has not been set for the distributed "
+                       "variable.")
+    tensor, spec = distributed_variable._var_policy.get_saveable(
+        distributed_variable, primary_variable, name)
+    super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Restore the same value into all variables."""
+    tensor, = restored_tensors
+    return self._distributed_variable._var_policy.get_restore_ops(  # pylint: disable=protected-access
+        self._distributed_variable, tensor)
+
+
 class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
   """Class for defining how to restore a MirroredVariable."""
@@ -756,61 +832,27 @@ class MirroredVariable(DistributedVariable, Mirrored):
   """Holds a map from replica to variables whose values are kept in sync."""
 
   def _update_replica(self, update_fn, value, **kwargs):
-    if self.aggregation == vs.VariableAggregation.NONE:
-      return update_fn(self._get_on_device_or_primary(), value, **kwargs)
-
-    def merge_fn(strategy, value, **kwargs):
-      """Aggregate values and update all variables in cross replica context."""
-      # Don't allow MEAN with non float dtype, since it may cause unexpected
-      # precision loss. Python3 and NumPy automatically upcast integers to
-      # float in division, but we should always preserve the type.
-      #
-      # Note that to be backward compatible we allow the case when the value
-      # is *always* the same on each replica. I.E. value is not a
-      # PerReplica. Refer to regroup() to see how values are grouped.
-      if self._aggregation == vs.VariableAggregation.MEAN and (
-          not self.dtype.is_floating) and isinstance(value, PerReplica):
-        raise ValueError(
-            "Cannot update non-float variables with "
-            "tf.VariableAggregation.MEAN aggregation in replica context. "
-            "Either change the variable dtype to float or update it in "
-            "cross-replica context.")
-
-      assert strategy == self.distribute_strategy
-      v = values_util.apply_aggregation(strategy, value, self.aggregation, self)
-      return self._update_cross_replica(update_fn, v, **kwargs)
-
-    return ds_context.get_replica_context().merge_call(
-        merge_fn, args=(value,), kwargs=kwargs)
+    return _on_write_update_replica(self, update_fn, value, **kwargs)
 
   def scatter_min(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
-      raise NotImplementedError("scatter_min is only supported for mirrored "
-                                "variable (variable created within certain "
-                                "`tf.distribute.Strategy` scope) with NONE or "
-                                "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
-                                self._aggregation)
+      raise NotImplementedError(values_util.scatter_error_msg.format(
+          op_name="scatter_min", aggregation=self._aggregation))
     return super(MirroredVariable, self).scatter_min(*args, **kwargs)
 
   def scatter_max(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
-      raise NotImplementedError("scatter_max is only supported for mirrored "
-                                "variable (variable created within certain "
-                                "`tf.distribute.Strategy` scope) with NONE or "
-                                "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
-                                self._aggregation)
+      raise NotImplementedError(values_util.scatter_error_msg.format(
+          op_name="scatter_max", aggregation=self._aggregation))
     return super(MirroredVariable, self).scatter_max(*args, **kwargs)
 
   def scatter_update(self, *args, **kwargs):
     if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
         self._aggregation != vs.VariableAggregation.NONE):
-      raise NotImplementedError("scatter_update is only supported for mirrored "
-                                "variable (variable created within certain "
-                                "`tf.distribute.Strategy` scope) with NONE or "
-                                "`ONLY_FIRST_REPLICA` aggregation, got: %s" %
-                                self._aggregation)
+      raise NotImplementedError(values_util.scatter_error_msg.format(
+          op_name="scatter_update", aggregation=self._aggregation))
     return super(MirroredVariable, self).scatter_update(*args, **kwargs)
 
   def _get_cross_replica(self):
@@ -893,28 +935,13 @@ class SyncOnReadVariable(DistributedVariable):
   def _update_replica(self, update_fn, value, **kwargs):
     return update_fn(self._get_on_device_or_primary(), value, **kwargs)
 
-  def _assign_on_each_device(self, assign_func, value, read_value):
-    update = control_flow_ops.group(
-        tuple(
-            assign_func(v.device, v, value)
-            for v in self._values))
-    if not read_value:
-      return update
-    with ops.control_dependencies([update] if update else []):
-      return self.read_value()
-
   # TODO(b/154017756): Make assign behavior in cross replica context consistent
   # with MirroredVariable.
   def assign_sub(self, value, use_locking=False, name=None, read_value=True):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if ds_context.in_cross_replica_context():
-        if self._aggregation == vs.VariableAggregation.SUM:
-          raise ValueError(
-              "SyncOnReadVariable does not support `assign_sub` in "
-              "cross-replica context when aggregation is set to "
-              "`tf.VariableAggregation.SUM`.")
-        return self._assign_on_each_device(values_util.assign_sub_on_device,
-                                           value, read_value)
+        return values_util.on_read_assign_sub_cross_replica(
+            self, value, read_value=read_value)
       else:
         return super(SyncOnReadVariable,
                      self).assign_sub(value, use_locking, name, read_value)
@@ -922,13 +949,8 @@ class SyncOnReadVariable(DistributedVariable):
   def assign_add(self, value, use_locking=False, name=None, read_value=True):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if ds_context.in_cross_replica_context():
-        if self._aggregation == vs.VariableAggregation.SUM:
-          raise ValueError(
-              "SyncOnReadVariable does not support `assign_add` in "
-              "cross-replica context when aggregation is set to "
-              "`tf.VariableAggregation.SUM`.")
-        return self._assign_on_each_device(values_util.assign_add_on_device,
-                                           value, read_value)
+        return values_util.on_read_assign_add_cross_replica(
+            self, value, read_value=read_value)
       else:
         return super(SyncOnReadVariable,
                      self).assign_add(value, use_locking, name, read_value)
@@ -936,13 +958,8 @@ class SyncOnReadVariable(DistributedVariable):
   def assign(self, value, use_locking=False, name=None, read_value=True):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if ds_context.in_cross_replica_context():
-        # To preserve the sum across save and restore, we have to divide the
-        # total across all devices when restoring a variable that was summed
-        # when saving.
-        if self._aggregation == vs.VariableAggregation.SUM:
-          value = math_ops.cast(value / len(self._values), self.dtype)
-        return self._assign_on_each_device(values_util.assign_on_device, value,
-                                           read_value)
+        return values_util.on_read_assign_cross_replica(
+            self, value, read_value=read_value)
       else:
         return super(SyncOnReadVariable,
                      self).assign(value, use_locking, name, read_value)
@@ -987,7 +1004,7 @@ class SyncOnReadVariable(DistributedVariable):
 
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return self._distribute_strategy.reduce(
-          reduce_util.ReduceOp.from_variable_aggregation(self.aggregation),
+          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
           self,
           axis=None)
 
@@ -1022,6 +1039,16 @@ class SyncOnReadVariable(DistributedVariable):
 
 # Register conversion functions which read the value of the variable,
 # allowing instances of the class to be used as tensors.
+# DistributedVariable
+def _tensor_conversion_distributed_var(var, dtype=None, name=None,
+                                       as_ref=False):
+  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
+
+
+ops.register_tensor_conversion_function(DistributedVariable,
+                                        _tensor_conversion_distributed_var)
+
+
 # MirroredVariables
 def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
   return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
@@ -1048,3 +1075,299 @@ def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
 
 ops.register_tensor_conversion_function(SyncOnReadVariable,
                                         _tensor_conversion_sync_on_read)
+
+
+class VariablePolicy(object):
+  """Policy defining synchronization and aggregation of a distributed variable.
+
+  Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
+  during variable creation within `tf.distribute` scope, `tf.distribute`
+  creates an appropriate policy object and assigns it to the distributed
+  variable. All variable operations are delegated to the respective policy
+  object.
+  """
+
+  def __init__(self, aggregation):
+    self._aggregation = aggregation
+
+  def value(self):
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes.")
+
+  def _is_mirrored(self):
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes.")
+
+  def _as_graph_element(self, _):
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes.")
+
+  def _get_cross_replica(self, var):
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes.")
+
+  def _update_replica(self, var, update_fn, value, **kwargs):
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes.")
+
+
+class OnReadPolicy(VariablePolicy):
+  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
+
+  This policy is created when `synchronization` is set to
+  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
+  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
+  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in
+  `tf.distribute` scope.
+  """
+
+  def _is_mirrored(self):
+    return False
+
+  def value(self, var):
+    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+      if ds_context.in_cross_replica_context():
+        return var._get_cross_replica()  # pylint: disable=protected-access
+      else:
+        return var._get_on_device_or_primary().value()  # pylint: disable=protected-access
+
+  def _as_graph_element(self, var):
+    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+      if ds_context.in_cross_replica_context():
+        return ops.convert_to_tensor(var._get_cross_replica())  # pylint: disable=protected-access
+    return var._get()._as_graph_element()  # pylint: disable=protected-access
+
+  def _get_cross_replica(self, var):
+    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
+      return var._primary  # pylint: disable=protected-access
+
+    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+      return var.distribute_strategy.reduce(
+          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
+          var,
+          axis=None)
+
+  def _update_replica(self, var, update_fn, value, **kwargs):
+    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
+
+  def _scatter_not_implemented(self, method):
+    raise NotImplementedError(
+        "ON_READ variables don't support `%s` in cross replica context" %
+        method)
+
+  def assign_sub(self, var, value, use_locking=False, name=None,
+                 read_value=True):
+    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+      if ds_context.in_cross_replica_context():
+        return values_util.on_read_assign_sub_cross_replica(
+            var, value, read_value=read_value)
+      else:
+        return values_util.on_write_assign_sub(
+            var, value, use_locking=use_locking, name=name,
+            read_value=read_value)
+
+  def assign_add(self, var, value, use_locking=False, name=None,
+                 read_value=True):
+    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+      if ds_context.in_cross_replica_context():
+        return values_util.on_read_assign_add_cross_replica(
+            var, value, read_value=read_value)
+      else:
+        return values_util.on_write_assign_add(
+            var, value, use_locking=use_locking, name=name,
+            read_value=read_value)
+
+  def assign(self, var, value, use_locking=False, name=None, read_value=True):
+    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+      if ds_context.in_cross_replica_context():
+        return values_util.on_read_assign_cross_replica(var, value,
+                                                        read_value=read_value)
+      else:
+        return values_util.on_write_assign(var, value,
+                                           use_locking=use_locking,
+                                           name=name,
+                                           read_value=read_value)
+
+  def scatter_sub(self, *args, **kwargs):
+    del args, kwargs
+    self._scatter_not_implemented("scatter_sub")
+
+  def scatter_add(self, *args, **kwargs):
+    del args, kwargs
+    self._scatter_not_implemented("scatter_add")
+
+  def scatter_mul(self, *args, **kwargs):
+    del args, kwargs
+    self._scatter_not_implemented("scatter_mul")
+
+  def scatter_div(self, *args, **kwargs):
+    del args, kwargs
+    self._scatter_not_implemented("scatter_div")
+
+  def scatter_min(self, *args, **kwargs):
+    del args, kwargs
+    self._scatter_not_implemented("scatter_min")
+
+  def scatter_max(self, *args, **kwargs):
+    del args, kwargs
+    self._scatter_not_implemented("scatter_max")
+
+  def scatter_update(self, *args, **kwargs):
+    del args, kwargs
+    self._scatter_not_implemented("scatter_update")
+
+  def get_saveable(self, var, primary_var, name):
+    """Create a saveable object for the given variable."""
+    # We use a callable so that we don't have to evaluate this expression
+    # in the case where we are trying to restore instead of save.
+    def tensor():
+      strategy = var.distribute_strategy
+      return strategy.extended.read_var(var)
+
+    spec = saveable_object.SaveSpec(
+        tensor=tensor,
+        slice_spec="",
+        name=name,
+        dtype=var.dtype,
+        device=primary_var.device)
+
+    return tensor, [spec]
+
+  def get_restore_ops(self, var, tensor):
+    """Restore the same value into all variables."""
+    # To preserve the sum across save and restore, we have to divide the
+    # total across all devices when restoring a variable that was summed
+    # when saving.
+    if self._aggregation == vs.VariableAggregation.SUM:
+      tensor = math_ops.cast(tensor / len(var._devices),  # pylint: disable=protected-access
+                             var.dtype)
+    return control_flow_ops.group(
+        tuple(
+            values_util.assign_on_device(v.device, v, tensor)
+            for v in var.values))
+
+
+class AutoPolicy(VariablePolicy):
+  """Policy defined for `tf.VariableSynchronization.AUTO` synchronization.
+
+  This policy is created when `synchronization` is set to
+  `tf.VariableSynchronization.AUTO` and `aggregation` is set to
+  `tf.VariableAggregation.NONE` when creating a `tf.Variable` in
+  `tf.distribute` scope.
+  """
+
+  def _is_mirrored(self):
+    return True
+
+  def value(self, var):
+    return var._get_on_device_or_primary().value()  # pylint: disable=protected-access
+
+  def _as_graph_element(self, var):
+    return var._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access
+
+  def _get_cross_replica(self, var):
+    # Return identity, to avoid directly exposing the variable to the user and
+    # allowing it to be modified by mistake.
+    return array_ops.identity(Mirrored._get_cross_replica(var))  # pylint: disable=protected-access
+
+  def _update_replica(self, var, update_fn, value, **kwargs):
+    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
+
+  def assign(self, var, value, use_locking=False, name=None, read_value=True):
+    return values_util.on_write_assign(var, value, use_locking=use_locking,
+                                       name=name, read_value=read_value)
+
+  def assign_add(self, var, value, use_locking=False, name=None,
+                 read_value=True):
+    return values_util.on_write_assign_add(var, value, use_locking=use_locking,
+                                           name=name, read_value=read_value)
+
+  def assign_sub(self, var, value, use_locking=False, name=None,
+                 read_value=True):
+    return values_util.on_write_assign_sub(var, value, use_locking=use_locking,
+                                           name=name, read_value=read_value)
+
+  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
+    return values_util.scatter_sub(var, sparse_delta, use_locking=use_locking,
+                                   name=name)
+
+  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
+    return values_util.scatter_add(var, sparse_delta, use_locking=use_locking,
+                                   name=name)
+
+  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
+    return values_util.scatter_mul(var, sparse_delta, use_locking=use_locking,
+                                   name=name)
+
+  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
+    return values_util.scatter_div(var, sparse_delta, use_locking=use_locking,
+                                   name=name)
+
+  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
+    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
+        self._aggregation != vs.VariableAggregation.NONE):
+      raise NotImplementedError(values_util.scatter_error_msg.format(
+          op_name="scatter_min", aggregation=self._aggregation))
+    return values_util.scatter_min(var, sparse_delta, use_locking=use_locking,
+                                   name=name)
+
+  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
+    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
+        self._aggregation != vs.VariableAggregation.NONE):
+      raise NotImplementedError(values_util.scatter_error_msg.format(
+          op_name="scatter_max", aggregation=self._aggregation))
+    return values_util.scatter_max(var, sparse_delta, use_locking=use_locking,
+                                   name=name)
+
+  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
+    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
+        self._aggregation != vs.VariableAggregation.NONE):
+      raise NotImplementedError(values_util.scatter_error_msg.format(
+          op_name="scatter_update", aggregation=self._aggregation))
+    return values_util.scatter_update(var, sparse_delta,
+                                      use_locking=use_locking,
+                                      name=name)
+
+  def get_saveable(self, var, primary_var, name):
+    del var, name
+    return primary_var, ""
+
+  def get_restore_ops(self, var, tensor):
+    return control_flow_ops.group(
+        tuple(
+            values_util.assign_on_device(v.device, v, tensor)
+            for v in var.values))
+
+
+class OnWritePolicy(AutoPolicy):
+  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
+
+  This policy is created when the following `synchronization` and
+  `aggregation` parameters are specified when creating a `tf.Variable` in
+  `tf.distribute` scope:
+  * `synchronization` is equal to `tf.VariableSynchronization.AUTO` and
+    aggregation can be any of the following `tf.VariableAggregation` enum
+    values such as `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
+  * `synchronization` is equal to `tf.VariableSynchronization.ON_WRITE` and
+    aggregation can be any of the following `tf.VariableAggregation` enum
+    values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
+  """
+
+  def _update_replica(self, var, update_fn, value, **kwargs):
+    return _on_write_update_replica(var, update_fn, value, **kwargs)
+
+
+# Utility functions
+# Return True if the Value is Mirrored or the Variable is replicated and kept
+# in sync.
+def _is_mirrored(val):
+  if isinstance(val, DistributedVariable):
+    if val._var_policy:  # pylint: disable=protected-access
+      return val._var_policy._is_mirrored()  # pylint: disable=protected-access
+  return isinstance(val, Mirrored)
+
+
+def _is_sync_on_read(val):
+  if isinstance(val, DistributedVariable):
+    if val._var_policy:  # pylint: disable=protected-access
+      return not val._var_policy._is_mirrored()  # pylint: disable=protected-access
+  return not isinstance(val, Mirrored)
tensorflow/python/distribute/values_util.py

@@ -23,9 +23,155 @@ from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 
 
+def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
+  assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=assign_fn,
+      value=value,
+      use_locking=use_locking,
+      name=name,
+      read_value=read_value)
+
+
+def on_write_assign_add(var, value, use_locking=False, name=None,
+                        read_value=True):
+  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=assign_add_fn,
+      value=value,
+      use_locking=use_locking,
+      name=name,
+      read_value=read_value)
+
+
+def on_write_assign_sub(var, value, use_locking=False, name=None,
+                        read_value=True):
+  assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=assign_sub_fn,
+      value=value,
+      use_locking=use_locking,
+      name=name,
+      read_value=read_value)
+
+
+def assign_on_each_device(var, assign_func, value, read_value):
+  update = control_flow_ops.group(
+      tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
+  if not read_value:
+    return update
+  with ops.control_dependencies([update] if update else []):
+    return var.read_value()
+
+
+def on_read_assign_sub_cross_replica(var, value, read_value=True):
+  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+    if ds_context.in_cross_replica_context():
+      if var.aggregation == vs.VariableAggregation.SUM:
+        raise ValueError(
+            "SyncOnReadVariable does not support `assign_sub` in "
+            "cross-replica context when aggregation is set to "
+            "`tf.VariableAggregation.SUM`.")
+      return assign_on_each_device(var, assign_sub_on_device,
+                                   value, read_value)
+
+
+def on_read_assign_add_cross_replica(var, value, read_value=True):
+  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+    if ds_context.in_cross_replica_context():
+      if var.aggregation == vs.VariableAggregation.SUM:
+        raise ValueError(
+            "SyncOnReadVariable does not support `assign_add` in "
+            "cross-replica context when aggregation is set to "
+            "`tf.VariableAggregation.SUM`.")
+      return assign_on_each_device(var, assign_add_on_device,
+                                   value, read_value)
+
+
+def on_read_assign_cross_replica(var, value, read_value=True):
+  """Assigns the value to the variable in cross replica context."""
+  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
+    if ds_context.in_cross_replica_context():
+      # To preserve the sum across save and restore, we have to divide the
+      # total across all devices when restoring a variable that was summed
+      # when saving.
+      tensor = value
+      # TODO(anjs): Should this be over all the replicas in sync since we
+      # call `reduce` on the variable during read?
+      if var.aggregation == vs.VariableAggregation.SUM:
+        tensor = math_ops.cast(tensor / len(var._values), var.dtype)  # pylint: disable=protected-access
+      return assign_on_each_device(var, assign_on_device, tensor,
+                                   read_value)
+
+
+def scatter_sub(var, sparse_delta, use_locking=False, name=None):
+  scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=scatter_sub_fn,
+      value=sparse_delta,
+      use_locking=use_locking,
+      name=name)
+
+
+def scatter_add(var, sparse_delta, use_locking=False, name=None):
+  scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=scatter_add_fn,
+      value=sparse_delta,
+      use_locking=use_locking,
+      name=name)
+
+
+def scatter_mul(var, sparse_delta, use_locking=False, name=None):
+  scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=scatter_mul_fn,
+      value=sparse_delta,
+      use_locking=use_locking,
+      name=name)
+
+
+def scatter_div(var, sparse_delta, use_locking=False, name=None):
+  scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=scatter_div_fn,
+      value=sparse_delta,
+      use_locking=use_locking,
+      name=name)
+
+
+def scatter_min(var, sparse_delta, use_locking=False, name=None):
+  scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=scatter_min_fn,
+      value=sparse_delta,
+      use_locking=use_locking,
+      name=name)
+
+
+def scatter_max(var, sparse_delta, use_locking=False, name=None):
+  scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=scatter_max_fn,
+      value=sparse_delta,
+      use_locking=use_locking,
+      name=name)
+
+
+def scatter_update(var, sparse_delta, use_locking=False, name=None):
+  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
+  return var._update(  # pylint: disable=protected-access
+      update_fn=scatter_update_fn,
+      value=sparse_delta,
+      use_locking=use_locking,
+      name=name)
+
+
 def get_current_replica_id_as_int():
   """Returns the current replica ID as an integer, or `None`."""
   replica_context = ds_context.get_replica_context()
@@ -89,3 +235,9 @@ aggregation_error_msg = (
     "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`."
     "Inside `merge_fn`, you can then update the {variable_type} "
     "using `tf.distribute.StrategyExtended.update()`.")
+
+
+scatter_error_msg = ("{op_name} is only supported for mirrored "
+                     "variable (variable created within certain "
+                     "`tf.distribute.Strategy` scope) with NONE or "
+                     "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")
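Both OnReadPolicy.get_restore_ops and on_read_assign_cross_replica divide a SUM-aggregated value by the number of component variables before assigning it. The arithmetic is easy to check outside TensorFlow: a SUM variable is saved as the total of its per-device components, and restoring total / n to each of the n devices makes a later sum-reduce read return the original total. A toy check:

    num_devices = 4
    components = [1.0, 2.0, 3.0, 4.0]         # per-device component values
    saved = sum(components)                   # 10.0, the value checkpointed
    restored = [saved / num_devices] * num_devices
    assert sum(restored) == saved             # reading (sum-reduce) recovers it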