diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 5c038c01999..c6e0eb34a7b 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -41,6 +41,36 @@ from tensorflow.python.types import core from tensorflow.python.util.tf_export import tf_export +def _on_write_update_replica(var, update_fn, value, **kwargs): + """Updates variables with ON_WRITE synchronization in replica context.""" + if var.aggregation == vs.VariableAggregation.NONE: + return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access + + def merge_fn(strategy, value, **kwargs): + """Aggregate values and update all variables in cross replica context.""" + # Don't allow MEAN with non float dtype, since it may cause unexpected + # precision loss. Python3 and NumPy automatically upcast integers to + # float in division, but we should always preserve the type. + # + # Note that to be backward compatible we allow the case when the value + # is *always* the same on each replica. I.E. value is not a + # PerReplica. Refer to regroup() to see how values are grouped. + if var.aggregation == vs.VariableAggregation.MEAN and ( + not var.dtype.is_floating) and isinstance(value, PerReplica): + raise ValueError( + "Cannot update non-float variables with " + "tf.VariableAggregation.MEAN aggregation in replica context. " + "Either change the variable dtype to float or update it in " + "cross-replica context.") + + assert strategy == var.distribute_strategy + v = values_util.apply_aggregation(strategy, value, var.aggregation, var) + return var._update_cross_replica(update_fn, v, **kwargs) # pylint: disable=protected-access + + return ds_context.get_replica_context().merge_call( + merge_fn, args=(value,), kwargs=kwargs) + + @tf_export("distribute.DistributedValues", v1=[]) class DistributedValues(object): """Base class for representing distributed values. @@ -409,10 +439,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, core.Tensor): """Holds a map from replica to variables.""" - # TODO(josh11b): Support changing the set of variables if e.g. if new - # devices are joining or a device is to leave. - - def __init__(self, strategy, values, aggregation): + def __init__(self, strategy, values, aggregation, var_policy=None): self._distribute_strategy = strategy self._aggregation = aggregation super(DistributedVariable, self).__init__(values) @@ -439,6 +466,9 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, # when restoring from a checkpoint, we may set the _initializer_op # property on the entire `DistributedVariable`. self._initializer_op = None + # Set a VariablePolicy which decides how we replicate/aggregate the given + # variable. + self._var_policy = var_policy def is_initialized(self, name=None): """Identifies if all the component variables are initialized. 
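As a quick illustration of the replica-context path that the new `_on_write_update_replica` helper implements (aggregate the per-replica values via `merge_call`, then write the result back to every replica), here is a minimal sketch against the public API only. It assumes a `tf.distribute.MirroredStrategy` over whatever local devices are available and does not call anything introduced by this patch directly.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # ON_WRITE semantics: a non-NONE aggregation makes replica-context updates
  # go through merge_call and get aggregated before being written everywhere.
  v = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

@tf.function
def step():
  def replica_fn():
    v.assign_add(1.0)  # runs in replica context on every replica
  strategy.run(replica_fn)

step()
print(v.numpy())  # 1.0: the per-replica deltas are averaged, not summed

Note the guard in the helper above: MEAN aggregation combined with a non-float dtype and genuinely per-replica values raises a ValueError instead of silently losing precision in the division.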
@@ -580,6 +610,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     return array_ops.identity(self._get())
 
   def value(self):
+    if self._var_policy:
+      return self._var_policy.value(self)
     return self._get_on_device_or_primary().value()
 
   def numpy(self):
@@ -590,87 +622,104 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
           "numpy() is only available when eager execution is enabled.")
 
   def assign_sub(self, value, use_locking=False, name=None, read_value=True):
-    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
-    return self._update(
-        update_fn=assign_sub_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
+    if self._var_policy:
+      return self._var_policy.assign_sub(self, value, use_locking=use_locking,
+                                         name=name, read_value=read_value)
+    return values_util.on_write_assign_sub(self, value, use_locking=use_locking,
+                                           name=name, read_value=read_value)
 
   def assign_add(self, value, use_locking=False, name=None, read_value=True):
-    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
-    return self._update(
-        update_fn=assign_add_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
+    if self._var_policy:
+      return self._var_policy.assign_add(self, value, use_locking=use_locking,
+                                         name=name, read_value=read_value)
+    return values_util.on_write_assign_add(self, value, use_locking=use_locking,
+                                           name=name, read_value=read_value)
 
   def assign(self, value, use_locking=False, name=None, read_value=True):
-    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
-    return self._update(
-        update_fn=assign_fn,
-        value=value,
-        use_locking=use_locking,
-        name=name,
-        read_value=read_value)
+    if self._var_policy:
+      return self._var_policy.assign(self, value, use_locking=use_locking,
+                                     name=name, read_value=read_value)
+    return values_util.on_write_assign(self, value, use_locking=use_locking,
+                                       name=name, read_value=read_value)
 
   def scatter_sub(self, sparse_delta, use_locking=False, name=None):
-    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
-    return self._update(
-        update_fn=scatter_sub_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_sub(
+          self, sparse_delta, use_locking=use_locking, name=name)
+    return values_util.scatter_sub(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_add(self, sparse_delta, use_locking=False, name=None):
-    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
-    return self._update(
-        update_fn=scatter_add_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_add(
+          self, sparse_delta, use_locking=use_locking, name=name)
+    return values_util.scatter_add(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_mul(self, sparse_delta, use_locking=False, name=None):
-    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
-    return self._update(
-        update_fn=scatter_mul_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_mul(
+          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_mul(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_div(self, sparse_delta, use_locking=False, name=None):
-    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
-    return self._update(
-        update_fn=scatter_div_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_div(
+          self, sparse_delta, use_locking=use_locking, name=name)
+    return values_util.scatter_div(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_min(self, sparse_delta, use_locking=False, name=None):
-    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
-    return self._update(
-        update_fn=scatter_min_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_min(
+          self, sparse_delta, use_locking=use_locking, name=name)
+    return values_util.scatter_min(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_max(self, sparse_delta, use_locking=False, name=None):
-    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
-    return self._update(
-        update_fn=scatter_max_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_max(
+          self, sparse_delta, use_locking=use_locking, name=name)
+    return values_util.scatter_max(self, sparse_delta, use_locking=use_locking,
+                                   name=name)
 
   def scatter_update(self, sparse_delta, use_locking=False, name=None):
-    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
-    return self._update(
-        update_fn=scatter_update_fn,
-        value=sparse_delta,
-        use_locking=use_locking,
-        name=name)
+    if self._var_policy:
+      return self._var_policy.scatter_update(
+          self, sparse_delta, use_locking=use_locking, name=name)
+    return values_util.scatter_update(self, sparse_delta,
+                                      use_locking=use_locking,
+                                      name=name)
+
+  def _gather_saveables_for_checkpoint(self):
+    """Overrides Trackable method.
+
+    This allows both name-based and object-based save and restore of
+    DistributedVariables.
+
+    Returns:
+      A dictionary mapping attribute names to `SaveableObject` factories.
+    """
+
+    def _saveable_factory(name=self._common_name):
+      return _DistributedVariableSaveable(self, self._primary, name)
+
+    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+  def _as_graph_element(self):
+    if self._var_policy:
+      return self._var_policy._as_graph_element(self)  # pylint: disable=protected-access
+
+    raise NotImplementedError("No policy set for calling _as_graph_element.")
+
+  def _get_cross_replica(self):
+    if self._var_policy:
+      return self._var_policy._get_cross_replica(self)  # pylint: disable=protected-access
+
+    raise NotImplementedError(
+        "This method should be overridden by sub-classes which support cross-"
+        "replica accesses.")
 
   def _update_cross_replica(self, update_fn, value, **kwargs):
     """Applies updates across replicas.
@@ -699,6 +748,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     Returns:
      Updated variable or `tf.Operation`.
""" + if self._var_policy: + return self._var_policy._update_replica(self, update_fn, value, **kwargs) # pylint: disable=protected-access raise NotImplementedError("should be implemented by subclass.") def _update(self, update_fn, value, **kwargs): @@ -735,6 +786,31 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, """Pass resource_variable_ops.is_resource_variable check.""" pass + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + """Converts a variable to a tensor.""" + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + return ops.convert_to_tensor( + self._get(), dtype=dtype, name=name, as_ref=as_ref) + + +class _DistributedVariableSaveable(saveable_object.SaveableObject): + """Class for defining how to restore a DistributedVariable.""" + + def __init__(self, distributed_variable, primary_variable, name): + self._distributed_variable = distributed_variable + if not self._distributed_variable._var_policy: + raise ValueError("VariablePolicy has not been set for the distributed " + "variable.") + tensor, spec = distributed_variable._var_policy.get_saveable( + distributed_variable, primary_variable, name) + super(_DistributedVariableSaveable, self).__init__(tensor, spec, name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return self._distributed_variable._var_policy.get_restore_ops( # pylint: disable=protected-access + self._distributed_variable, tensor) + class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable): """Class for defining how to restore a MirroredVariable.""" @@ -756,61 +832,27 @@ class MirroredVariable(DistributedVariable, Mirrored): """Holds a map from replica to variables whose values are kept in sync.""" def _update_replica(self, update_fn, value, **kwargs): - if self.aggregation == vs.VariableAggregation.NONE: - return update_fn(self._get_on_device_or_primary(), value, **kwargs) - - def merge_fn(strategy, value, **kwargs): - """Aggregate values and update all variables in cross replica context.""" - # Don't allow MEAN with non float dtype, since it may cause unexpected - # precision loss. Python3 and NumPy automatically upcast integers to - # float in division, but we should always preserve the type. - # - # Note that to be backward compatible we allow the case when the value - # is *always* the same on each replica. I.E. value is not a - # PerReplica. Refer to regroup() to see how values are grouped. - if self._aggregation == vs.VariableAggregation.MEAN and ( - not self.dtype.is_floating) and isinstance(value, PerReplica): - raise ValueError( - "Cannot update non-float variables with " - "tf.VariableAggregation.MEAN aggregation in replica context. 
" - "Either change the variable dtype to float or update it in " - "cross-replica context.") - - assert strategy == self.distribute_strategy - v = values_util.apply_aggregation(strategy, value, self.aggregation, self) - return self._update_cross_replica(update_fn, v, **kwargs) - - return ds_context.get_replica_context().merge_call( - merge_fn, args=(value,), kwargs=kwargs) + return _on_write_update_replica(self, update_fn, value, **kwargs) def scatter_min(self, *args, **kwargs): if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and self._aggregation != vs.VariableAggregation.NONE): - raise NotImplementedError("scatter_min is only supported for mirrored " - "variable (variable created within certain " - "`tf.distribute.Strategy` scope) with NONE or " - "`ONLY_FIRST_REPLICA` aggregation, got: %s" % - self._aggregation) + raise NotImplementedError(values_util.scatter_error_msg.format( + op_name="scatter_min", aggregation=self._aggregation)) return super(MirroredVariable, self).scatter_min(*args, **kwargs) def scatter_max(self, *args, **kwargs): if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and self._aggregation != vs.VariableAggregation.NONE): - raise NotImplementedError("scatter_max is only supported for mirrored " - "variable (variable created within certain " - "`tf.distribute.Strategy` scope) with NONE or " - "`ONLY_FIRST_REPLICA` aggregation, got: %s" % - self._aggregation) + raise NotImplementedError(values_util.scatter_error_msg.format( + op_name="scatter_min", aggregation=self._aggregation)) return super(MirroredVariable, self).scatter_max(*args, **kwargs) def scatter_update(self, *args, **kwargs): if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and self._aggregation != vs.VariableAggregation.NONE): - raise NotImplementedError("scatter_update is only supported for mirrored " - "variable (variable created within certain " - "`tf.distribute.Strategy` scope) with NONE or " - "`ONLY_FIRST_REPLICA` aggregation, got: %s" % - self._aggregation) + raise NotImplementedError(values_util.scatter_error_msg.format( + op_name="scatter_min", aggregation=self._aggregation)) return super(MirroredVariable, self).scatter_update(*args, **kwargs) def _get_cross_replica(self): @@ -893,28 +935,13 @@ class SyncOnReadVariable(DistributedVariable): def _update_replica(self, update_fn, value, **kwargs): return update_fn(self._get_on_device_or_primary(), value, **kwargs) - def _assign_on_each_device(self, assign_func, value, read_value): - update = control_flow_ops.group( - tuple( - assign_func(v.device, v, value) - for v in self._values)) - if not read_value: - return update - with ops.control_dependencies([update] if update else []): - return self.read_value() - # TODO(b/154017756): Make assign behaivor in cross replica context consistent # with MirroredVariable. 
def assign_sub(self, value, use_locking=False, name=None, read_value=True): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if ds_context.in_cross_replica_context(): - if self._aggregation == vs.VariableAggregation.SUM: - raise ValueError( - "SyncOnReadVariable does not support `assign_sub` in " - "cross-replica context when aggregation is set to " - "`tf.VariableAggregation.SUM`.") - return self._assign_on_each_device(values_util.assign_sub_on_device, - value, read_value) + return values_util.on_read_assign_sub_cross_replica( + self, value, read_value=read_value) else: return super(SyncOnReadVariable, self).assign_sub(value, use_locking, name, read_value) @@ -922,13 +949,8 @@ class SyncOnReadVariable(DistributedVariable): def assign_add(self, value, use_locking=False, name=None, read_value=True): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if ds_context.in_cross_replica_context(): - if self._aggregation == vs.VariableAggregation.SUM: - raise ValueError( - "SyncOnReadVariable does not support `assign_add` in " - "cross-replica context when aggregation is set to " - "`tf.VariableAggregation.SUM`.") - return self._assign_on_each_device(values_util.assign_add_on_device, - value, read_value) + return values_util.on_read_assign_add_cross_replica( + self, value, read_value=read_value) else: return super(SyncOnReadVariable, self).assign_add(value, use_locking, name, read_value) @@ -936,13 +958,8 @@ class SyncOnReadVariable(DistributedVariable): def assign(self, value, use_locking=False, name=None, read_value=True): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if ds_context.in_cross_replica_context(): - # To preserve the sum across save and restore, we have to divide the - # total across all devices when restoring a variable that was summed - # when saving. - if self._aggregation == vs.VariableAggregation.SUM: - value = math_ops.cast(value / len(self._values), self.dtype) - return self._assign_on_each_device(values_util.assign_on_device, value, - read_value) + return values_util.on_read_assign_cross_replica( + self, value, read_value=read_value) else: return super(SyncOnReadVariable, self).assign(value, use_locking, name, read_value) @@ -987,7 +1004,7 @@ class SyncOnReadVariable(DistributedVariable): with ds_context.enter_or_assert_strategy(self._distribute_strategy): return self._distribute_strategy.reduce( - reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), + reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), self, axis=None) @@ -1022,6 +1039,16 @@ class SyncOnReadVariable(DistributedVariable): # Register a conversion functions which reads the value of the variable, # allowing instances of the class to be used as tensors. 
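The cross-replica `assign*` behavior that moves into `values_util.on_read_assign*_cross_replica` above is easiest to see with a SUM-aggregated sync-on-read variable (the usual metric-accumulator pattern). A minimal sketch, assuming a `tf.distribute.MirroredStrategy` and using only public APIs:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  total = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

@tf.function
def accumulate():
  def replica_fn():
    total.assign_add(1.0)  # replica context: purely local, no synchronization
  strategy.run(replica_fn)

accumulate()
# Reading in cross-replica context reduces across replicas with SUM.
print(total.read_value().numpy())  # == number of in-sync replicas

# Assigning in cross-replica context splits the value across replicas so a
# later read sums back to what was assigned (the division that used to live in
# SyncOnReadVariable.assign and now sits in on_read_assign_cross_replica).
total.assign(10.0)
print(total.read_value().numpy())  # 10.0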
+# DistributedVariable +def _tensor_conversion_distributed_var(var, dtype=None, name=None, + as_ref=False): + return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access + + +ops.register_tensor_conversion_function(DistributedVariable, + _tensor_conversion_distributed_var) + + # MirroredVariables def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access @@ -1048,3 +1075,299 @@ def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False): ops.register_tensor_conversion_function(SyncOnReadVariable, _tensor_conversion_sync_on_read) + + +class VariablePolicy(object): + """Policy defining synchronization and aggregation of a distributed variable. + + Given `synchronization` and `aggregation` parameters set on a `tf.Variable` + during variable creation within `tf.distribute` scope, `tf.distribute` creates + an appropriate policy object and assigns it to the distributed variable. All + variable operations are delegated to the respective policy object. + """ + + def __init__(self, aggregation): + self._aggregation = aggregation + + def value(self): + raise NotImplementedError( + "This method should be overridden by sub-classes.") + + def _is_mirrored(self): + raise NotImplementedError( + "This method should be overridden by sub-classes.") + + def _as_graph_element(self, _): + raise NotImplementedError( + "This method should be overridden by sub-classes.") + + def _get_cross_replica(self, var): + raise NotImplementedError( + "This method should be overridden by sub-classes.") + + def _update_replica(self, var, update_fn, value, **kwargs): + raise NotImplementedError( + "This method should be overridden by sub-classes.") + + +class OnReadPolicy(VariablePolicy): + """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization. + + This policy is created when `synchronization` is set to + `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the + values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`, + `MEAN` or `ONLY_FIRST_REPLICA`when creating a `tf.Variable` in `tf.distribute` + scope. 
+ """ + + def _is_mirrored(self): + return False + + def value(self, var): + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + return var._get_cross_replica() # pylint: disable=protected-access + else: + return var._get_on_device_or_primary().value() # pylint: disable=protected-access + + def _as_graph_element(self, var): + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + return ops.convert_to_tensor(var._get_cross_replica()) # pylint: disable=protected-access + return var._get()._as_graph_element() # pylint: disable=protected-access + + def _get_cross_replica(self, var): + if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: + return var._primary # pylint: disable=protected-access + + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + return var.distribute_strategy.reduce( + reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), + var, + axis=None) + + def _update_replica(self, var, update_fn, value, **kwargs): + return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access + + def _scatter_not_implemented(self, method): + raise NotImplementedError( + "ON_READ variables doesn't support `%s` in cross replica context" % + method) + + def assign_sub(self, var, value, use_locking=False, name=None, + read_value=True): + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + return values_util.on_read_assign_sub_cross_replica( + var, value, read_value=read_value) + else: + return values_util.on_write_assign_sub( + var, value, use_locking=use_locking, name=name, + read_value=read_value) + + def assign_add(self, var, value, use_locking=False, name=None, + read_value=True): + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + return values_util.on_read_assign_add_cross_replica( + var, value, read_value=read_value) + else: + return values_util.on_write_assign_add( + var, value, use_locking=use_locking, name=name, + read_value=read_value) + + def assign(self, var, value, use_locking=False, name=None, read_value=True): + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + return values_util.on_read_assign_cross_replica(var, value, + read_value=read_value) + else: + return values_util.on_write_assign(var, value, + use_locking=use_locking, + name=name, + read_value=read_value) + + def scatter_sub(self, *args, **kwargs): + del args, kwargs + self._scatter_not_implemented("scatter_sub") + + def scatter_add(self, *args, **kwargs): + del args, kwargs + self._scatter_not_implemented("scatter_add") + + def scatter_mul(self, *args, **kwargs): + del args, kwargs + self._scatter_not_implemented("scatter_mul") + + def scatter_div(self, *args, **kwargs): + del args, kwargs + self._scatter_not_implemented("scatter_div") + + def scatter_min(self, *args, **kwargs): + del args, kwargs + self._scatter_not_implemented("scatter_min") + + def scatter_max(self, *args, **kwargs): + del args, kwargs + self._scatter_not_implemented("scatter_max") + + def scatter_update(self, *args, **kwargs): + del args, kwargs + self._scatter_not_implemented("scatter_update") + + def get_saveable(self, var, primary_var, name): + """Create a saveable object for the given variable.""" + # We use a callable so that we don't have to evaluate this expression + 
# in the case where we are trying to restore instead of save. + def tensor(): + strategy = var.distribute_strategy + return strategy.extended.read_var(var) + + spec = saveable_object.SaveSpec( + tensor=tensor, + slice_spec="", + name=name, + dtype=var.dtype, + device=primary_var.device) + + return tensor, [spec] + + def get_restore_ops(self, var, tensor): + """Restore the same value into all variables.""" + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + if self._aggregation == vs.VariableAggregation.SUM: + tensor = math_ops.cast(tensor / len(var._devices), # pylint: disable=protected-access + var.dtype) + return control_flow_ops.group( + tuple( + values_util.assign_on_device(v.device, v, tensor) + for v in var.values)) + + +class AutoPolicy(VariablePolicy): + """Policy defined for `tf.VariableSynchronization.AUTO` synchronization. + + This policy is created when `synchronization` is set to + `tf.VariableSynchronization.AUTO` and `aggregation` is set to + `tf.VariableAggregation.NONE` when creating a `tf.Variable` in `tf.distribute` + scope. + """ + + def _is_mirrored(self): + return True + + def value(self, var): + return var._get_on_device_or_primary().value() # pylint: disable=protected-access + + def _as_graph_element(self, var): + return var._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access + + def _get_cross_replica(self, var): + # Return identity, to avoid directly exposing the variable to the user and + # allowing it to be modified by mistake. + return array_ops.identity(Mirrored._get_cross_replica(var)) # pylint: disable=protected-access + + def _update_replica(self, var, update_fn, value, **kwargs): + return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access + + def assign(self, var, value, use_locking=False, name=None, read_value=True): + return values_util.on_write_assign(var, value, use_locking=use_locking, + name=name, read_value=read_value) + + def assign_add(self, var, value, use_locking=False, name=None, + read_value=True): + return values_util.on_write_assign_add(var, value, use_locking=use_locking, + name=name, read_value=read_value) + + def assign_sub(self, var, value, use_locking=False, name=None, + read_value=True): + return values_util.on_write_assign_sub(var, value, use_locking=use_locking, + name=name, read_value=read_value) + + def scatter_sub(self, var, sparse_delta, use_locking=False, name=None): + return values_util.scatter_sub(var, sparse_delta, use_locking=use_locking, + name=name) + + def scatter_add(self, var, sparse_delta, use_locking=False, name=None): + return values_util.scatter_add(var, sparse_delta, use_locking=use_locking, + name=name) + + def scatter_mul(self, var, sparse_delta, use_locking=False, name=None): + return values_util.scatter_mul(var, sparse_delta, use_locking=use_locking, + name=name) + + def scatter_div(self, var, sparse_delta, use_locking=False, name=None): + return values_util.scatter_div(var, sparse_delta, use_locking=use_locking, + name=name) + + def scatter_min(self, var, sparse_delta, use_locking=False, name=None): + if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and + self._aggregation != vs.VariableAggregation.NONE): + raise NotImplementedError(values_util.scatter_error_msg.format( + op_name="scatter_min", aggregation=self._aggregation)) + return values_util.scatter_min(var, sparse_delta, use_locking=use_locking, + 
name=name) + + def scatter_max(self, var, sparse_delta, use_locking=False, name=None): + if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and + self._aggregation != vs.VariableAggregation.NONE): + raise NotImplementedError(values_util.scatter_error_msg.format( + op_name="scatter_max", aggregation=self._aggregation)) + return values_util.scatter_max(var, sparse_delta, use_locking=use_locking, + name=name) + + def scatter_update(self, var, sparse_delta, use_locking=False, name=None): + if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and + self._aggregation != vs.VariableAggregation.NONE): + raise NotImplementedError(values_util.scatter_error_msg.format( + op_name="scatter_update", aggregation=self._aggregation)) + return values_util.scatter_update(var, sparse_delta, + use_locking=use_locking, + name=name) + + def get_saveable(self, var, primary_var, name): + del var, name + return primary_var, "" + + def get_restore_ops(self, var, tensor): + return control_flow_ops.group( + tuple( + values_util.assign_on_device(v.device, v, tensor) + for v in var.values)) + + +class OnWritePolicy(AutoPolicy): + """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization. + + This policy is created when the following `synchronization` and + `aggregation` parameters are specified when creating a `tf.Variable` in + `tf.distribute` scope: + * `synchronization` is equal to `tf.VariableSynchronization.AUTO` and + aggregation can be any of the following `tf.VariableAggregation` enum + values such as `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`. + * `synchronization` is equal to `tf.VariableSynchronization.ON_WRITE` and + aggregation can be any of the following `tf.VariableAggregation` enum + values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`. + """ + + def _update_replica(self, var, update_fn, value, **kwargs): + return _on_write_update_replica(var, update_fn, value, **kwargs) + + +# Utility functions +# Return True if the Value is Mirrored or the Variable is replicated and kept in +# sync. 
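For reference, the synchronization/aggregation combinations described in the three policy docstrings above map onto the public variable-creation API as sketched below; see the sketch before the utility functions that follow. This only illustrates the documented mapping; whether a given strategy actually attaches one of these policy objects to the variable (rather than relying on the existing MirroredVariable/SyncOnReadVariable subclasses) is decided outside this file.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # AUTO synchronization + NONE aggregation: the AutoPolicy case.
  plain = tf.Variable(1.0)

  # ON_WRITE (or AUTO with a non-NONE aggregation): the OnWritePolicy case.
  mirrored = tf.Variable(
      1.0,
      synchronization=tf.VariableSynchronization.ON_WRITE,
      aggregation=tf.VariableAggregation.MEAN)

  # ON_READ with any allowed aggregation: the OnReadPolicy case.
  sync_on_read = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)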
+def _is_mirrored(val): + if isinstance(val, DistributedVariable): + if val._var_policy: # pylint: disable=protected-access + return val._var_policy._is_mirrored() # pylint: disable=protected-access + return isinstance(val, Mirrored) + + +def _is_sync_on_read(val): + if isinstance(val, DistributedVariable): + if val._var_policy: # pylint: disable=protected-access + return not val._var_policy._is_mirrored() # pylint: disable=protected-access + return not isinstance(val, Mirrored) diff --git a/tensorflow/python/distribute/values_util.py b/tensorflow/python/distribute/values_util.py index c42ac9e4de1..ddb0d2d0401 100644 --- a/tensorflow/python/distribute/values_util.py +++ b/tensorflow/python/distribute/values_util.py @@ -23,9 +23,155 @@ from tensorflow.python.distribute import distribution_strategy_context as ds_con from tensorflow.python.distribute import reduce_util from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +def on_write_assign(var, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=assign_fn, + value=value, + use_locking=use_locking, + name=name, + read_value=read_value) + + +def on_write_assign_add(var, value, use_locking=False, name=None, + read_value=True): + assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=assign_add_fn, + value=value, + use_locking=use_locking, + name=name, + read_value=read_value) + + +def on_write_assign_sub(var, value, use_locking=False, name=None, + read_value=True): + assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=assign_sub_fn, + value=value, + use_locking=use_locking, + name=name, + read_value=read_value) + + +def assign_on_each_device(var, assign_func, value, read_value): + update = control_flow_ops.group( + tuple(assign_func(v.device, v, value) for v in var._values)) # pylint: disable=protected-access + if not read_value: + return update + with ops.control_dependencies([update] if update else []): + return var.read_value() + + +def on_read_assign_sub_cross_replica(var, value, read_value=True): + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + if var.aggregation == vs.VariableAggregation.SUM: + raise ValueError( + "SyncOnReadVariable does not support `assign_sub` in " + "cross-replica context when aggregation is set to " + "`tf.VariableAggregation.SUM`.") + return assign_on_each_device(var, assign_sub_on_device, + value, read_value) + + +def on_read_assign_add_cross_replica(var, value, read_value=True): + with ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + if var.aggregation == vs.VariableAggregation.SUM: + raise ValueError( + "SyncOnReadVariable does not support `assign_add` in " + "cross-replica context when aggregation is set to " + "`tf.VariableAggregation.SUM`.") + return assign_on_each_device(var, assign_add_on_device, + value, read_value) + + +def on_read_assign_cross_replica(var, value, read_value=True): + """Return the value of the variable in cross replica context.""" + with 
ds_context.enter_or_assert_strategy(var.distribute_strategy): + if ds_context.in_cross_replica_context(): + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + tensor = value + # TODO(anjs): Should this be over all the replicas in sync since we + # call `reduce` on the variable during read? + if var.aggregation == vs.VariableAggregation.SUM: + tensor = math_ops.cast(tensor / len(var._values), var.dtype) # pylint: disable=protected-access + return assign_on_each_device(var, assign_on_device, tensor, + read_value) + + +def scatter_sub(var, sparse_delta, use_locking=False, name=None): + scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=scatter_sub_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + +def scatter_add(var, sparse_delta, use_locking=False, name=None): + scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=scatter_add_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + +def scatter_mul(var, sparse_delta, use_locking=False, name=None): + scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=scatter_mul_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + +def scatter_div(var, sparse_delta, use_locking=False, name=None): + scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=scatter_div_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + +def scatter_min(var, sparse_delta, use_locking=False, name=None): + scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=scatter_min_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + +def scatter_max(var, sparse_delta, use_locking=False, name=None): + scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=scatter_max_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + +def scatter_update(var, sparse_delta, use_locking=False, name=None): + scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw) + return var._update( # pylint: disable=protected-access + update_fn=scatter_update_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def get_current_replica_id_as_int(): """Returns the current replica ID as an integer, or `None`.""" replica_context = ds_context.get_replica_context() @@ -89,3 +235,9 @@ aggregation_error_msg = ( "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`." "Inside `merge_fn`, you can then update the {variable_type} " "using `tf.distribute.StrategyExtended.update()`.") + + +scatter_error_msg = ("{op_name} is only supported for mirrored " + "variable (variable created within certain " + "`tf.distribute.Strategy` scope) with NONE or " + "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")
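Finally, a small sketch of the restriction that `scatter_error_msg` reports: `scatter_min`, `scatter_max` and `scatter_update` on a mirrored variable are only allowed with NONE or ONLY_FIRST_REPLICA aggregation. It assumes a `tf.distribute.MirroredStrategy`; the printed message text comes from the template above.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  ok = tf.Variable([1.0, 2.0, 3.0])  # default NONE aggregation
  mean = tf.Variable([1.0, 2.0, 3.0], aggregation=tf.VariableAggregation.MEAN)

# NONE aggregation: scatter_update is supported.
ok.scatter_update(tf.IndexedSlices(tf.constant([9.0]), tf.constant([1])))
print(ok.numpy())  # [1. 9. 3.]

# MEAN aggregation: rejected with the message built from scatter_error_msg.
try:
  mean.scatter_update(tf.IndexedSlices(tf.constant([9.0]), tf.constant([1])))
except NotImplementedError as e:
  print(e)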