Ran Chen 0ac07a2fc4 Retire AutoPolicy
This is part of the effort to refactor distributed variables. Auto is somewhat
confusing and adds additional implementation complexity.

PiperOrigin-RevId: 355733796
Change-Id: I7446c3ed706624178fcb26c9b992632a93b939f6
2021-02-04 16:16:59 -08:00

386 lines
15 KiB

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing utilities used by tf.distribute.Strategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import abc
from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.

  The structure of `values[0]` drives the recursion: lists, tuples (including
  namedtuples) and mappings are regrouped element by element, and every other
  value is treated as a leaf.

  Args:
    values: Values to regroup. A non-empty sequence with one entry per
      replica; all entries must share the same nested structure.
    wrap_class: Class that `values` be wrapped in.
    always_wrap: Always wrap the `values` in `wrap_class` even if the values
      are the same except for DistributeVariable.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(v0, "_make")
      return v0._make(regrouped_tuple)
    return regrouped_tuple

  if isinstance(v0, abc.Mapping):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
                                            (set(v0keys), set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values),
                     wrap_class, always_wrap)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = all(v is v0 for v in values[1:])

  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap` is
  #   true.
  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container()
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(values)
def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica.

  Args:
    replica_id: Index used to pick the component out of each
      `DistributedValues` leaf via `leaf.values[replica_id]`.
    structured: A nest whose leaves are regular or distributed values.

  Returns:
    A nest with the same structure as `structured`, produced by
    `nest.map_structure`.
  """

  def _get(x):
    # `DistributedValues` would be sliced according to replica unless it is a
    # `DistributedVariable` because `DistributedVariable` can be handled
    # directly in the replica context.
    if (isinstance(x, values_lib.DistributedVariable) or
        not isinstance(x, values_lib.DistributedValues)):
      return x
    return x.values[replica_id]

  return nest.map_structure(_get, structured)
def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica.

  Delegates directly to `select_replica` with the same arguments.
  """
  return select_replica(replica_id, structured)
def assert_mirrored(structured):
  """Raises if the structured is not composed of mirrored or regular values.

  Args:
    structured: A nest of values to check.

  Raises:
    TypeError: If any leaf is a `DistributedValues` for which `is_mirrored`
      returns False.
  """

  def _assert_mirrored(x):
    if isinstance(x, values_lib.DistributedValues) and not is_mirrored(x):
      raise TypeError(
          "Expected value to be mirrored across replicas: %s in %s." %
          (x, structured))

  nest.map_structure(_assert_mirrored, structured)
def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute.

  Args:
    extended: Object providing `_local_results`, which is mapped over the
      regrouped result when `group` is falsy.
    updates: Per-replica update results, in the form accepted by `regroup`.
    group: If falsy, the updates are regrouped as `Mirrored` and converted to
      local results. Otherwise the per-replica leaves are tied together so
      that fetching the result runs every replica's update.

  Returns:
    The regrouped updates.
  """
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    # NOTE(review): the right-hand side of this assignment was missing from the
    # extracted source; reconstructed as a group op over all values -- confirm.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if not all(tensor_util.is_tf_type(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        # NOTE(review): the body of this `with` was missing from the extracted
        # source; reconstructed as an identity copy of `v` -- confirm.
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)
def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, values_lib.DistributedVariable)):
    container = val._distributed_container()  # pylint: disable=protected-access
    if container is not None:
      return container
  return val
def is_distributed_variable(v):
  """Determine if a variable is ds variable or TPU mirrored variable.

  Returns True exactly when `v` is an instance of
  `values_lib.DistributedVariable`.
  """
  return isinstance(v, values_lib.DistributedVariable)
def _validate_colocate_extended(v, extended):
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
if variable_strategy.extended is not extended:
raise ValueError(
"`colocate_vars_with` must only be passed a variable created in this "
"tf.distribute.Strategy.scope(), not %s created in scope: %s" %
(v, variable_strategy))
def validate_colocate_distributed_variable(v, extended):
  """Validates a `colocate_vars_with` argument that must be distributed.

  Args:
    v: The candidate variable.
    extended: The extended strategy object `v` is expected to belong to.

  Raises:
    ValueError: If `v` is not a `values_lib.DistributedVariable`, or if it was
      created under a different strategy.
  """
  if not isinstance(v, values_lib.DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)
def validate_colocate(v, extended):
  """Validates a `colocate_vars_with` argument.

  Args:
    v: The candidate variable.
    extended: The extended strategy object `v` is expected to belong to.

  Raises:
    ValueError: If `v` has no `_distribute_strategy` attribute, or if it was
      created under a different strategy.
  """
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)
# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validate that given synchronization value is valid.

  Args:
    kwargs: Variable-creation keyword arguments. Reads "synchronization" and,
      for error messages only, "name".

  Returns:
    The synchronization mode to use; `AUTO` is converted to `ON_WRITE`.

  Raises:
    ValueError: If the mode is `NONE` or is not one of the accepted modes.
  """
  # NOTE(review): the default and the last two tuple members below were
  # truncated in the extracted source; reconstructed from the AUTO/ON_WRITE
  # handling at the end of this function -- confirm.
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  # `kwargs.get` keeps a missing "name" from masking the real error with a
  # KeyError.
  name = kwargs.get("name")
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(name))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, name))
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization
def _validate_aggregation(kwargs):
  """Validate that given aggregation value is valid.

  Args:
    kwargs: Variable-creation keyword arguments. Reads "aggregation"
      (defaulting to `VariableAggregation.NONE`) and, for the error message
      only, "name".

  Returns:
    The validated aggregation mode.

  Raises:
    ValueError: If the aggregation mode is not one of the accepted modes.
  """
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)

  # NOTE(review): all tuple members after NONE were truncated in the extracted
  # source; reconstructed from the NONE,SUM,MEAN,ONLY_FIRST_REPLICA list in
  # the policy-mapping comment elsewhere in this file -- confirm.
  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    # `kwargs.get` keeps a missing "name" from masking the real error with a
    # KeyError.
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs.get("name")))
  return aggregation
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation.

  Args:
    strategy: The strategy passed to the variable class constructor;
      `strategy.extended._use_var_policy` selects the policy-based path.
    real_mirrored_creator: Callable invoked with `**kwargs` that returns the
      list of per-replica component variables.
    class_mapping: Mapping from synchronization mode (and the key
      "VariableClass") to the variable class to instantiate.
    policy_mapping: Mapping from synchronization mode to a policy class; only
      used when the strategy uses variable policies.
    **kwargs: Variable-creation keyword arguments. "collections",
      "synchronization" and "caching_device" are consumed or rewritten here.

  Returns:
    The wrapping distributed variable.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      # Without this `else`, the policy-based result above would always be
      # overwritten.
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      # NOTE(review): this append was missing from the extracted source; it is
      # reconstructed from the "replace with the MirroredVariable" comment
      # above -- confirm.
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            # Stop scanning: `l` was just mutated under the enumerate().
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
# Utility functions
def is_mirrored(val):
  """Return True if the Value is Mirrored or the Variable is kept in sync.

  For a `DistributedVariable` with a truthy `_policy`, the answer comes from
  the policy's `_is_mirrored()`. Otherwise it is whether `val` is an instance
  of `values_lib.Mirrored`.
  """
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return val._policy._is_mirrored()  # pylint: disable=protected-access
  return isinstance(val, values_lib.Mirrored)
def is_sync_on_read(val):
  """Return True if `val` is not mirrored.

  For a `DistributedVariable` with a truthy `_policy`, this is the negation of
  the policy's `_is_mirrored()`. Otherwise it is whether `val` is not an
  instance of `values_lib.Mirrored`.
  """
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return not val._policy._is_mirrored()  # pylint: disable=protected-access
  return not isinstance(val, values_lib.Mirrored)
# The following mapping indicates the policy that you must use for a given
# variable `synchronization` and `aggregation` pair.
# OnWritePolicy is used for:
# (synchronization=Auto, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
#
# NOTE(review): the four dict headers and closing braces were missing from the
# extracted source, leaving bare `key: value` lines. The constant names below
# are reconstructed and must be confirmed against the modules that import
# them.
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

# "VariableClass" is the class used when a variable policy is in effect; the
# synchronization keys select the class used otherwise.
VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}