Another round of refactoring of values.py: move the utility functions that operate on the distributed Variable types defined in values.py into a separate distribute_utils module.

PiperOrigin-RevId: 316147517
Change-Id: I72e17b02e8f41c9cee40f4ec7f56fec2f7d860a9
Anjali Sridhar 2020-06-12 11:54:05 -07:00 committed by TensorFlower Gardener
parent 1cd053d409
commit 0f3562ba77
22 changed files with 526 additions and 432 deletions
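
In practice the change at each call site is just a module rename: helpers such as regroup, select_replica, select_replica_mirrored, update_regroup, value_container, the colocation validators, and create_mirrored_variable now live in the new distribute_utils module instead of values. A minimal, hedged sketch of the migration (the variable names and values are made up; only the module path and function names come from the diff below):

# Before this commit the helpers were imported from values:
#   from tensorflow.python.distribute import values
#   per_replica = values.regroup(replica_results)
#   first = values.select_replica(0, per_replica)

# After this commit the same helpers come from distribute_utils:
from tensorflow.python.distribute import distribute_utils

replica_results = ("out_replica_0", "out_replica_1")      # made-up per-replica outputs
per_replica = distribute_utils.regroup(replica_results)   # wraps them in a PerReplica
first = distribute_utils.select_replica(0, per_replica)   # "out_replica_0"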

View File

@ -286,6 +286,36 @@ py_library(
],
)
py_library(
name = "distribute_utils",
srcs = ["distribute_utils.py"],
deps = [
":device_util",
":distribute_lib",
":reduce_util",
":shared_variable_creator",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:config",
"//tensorflow/python:constant_op",
"//tensorflow/python:device",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python:tensor_util",
"//tensorflow/python:tf_export",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/autograph/core",
"//tensorflow/python/autograph/impl",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
],
)
py_library(
name = "mirrored_strategy",
srcs = ["mirrored_strategy.py"],
@ -293,6 +323,7 @@ py_library(
":cross_device_ops",
":device_util",
":distribute_lib",
":distribute_utils",
":input_lib",
":mirrored_run",
":multi_worker_util",
@ -320,6 +351,7 @@ py_library(
":cross_device_ops",
":device_util",
":distribute_lib",
":distribute_utils",
":input_lib",
":mirrored_run",
":multi_worker_util",
@ -508,6 +540,7 @@ py_library(
deps = [
":device_util",
":distribute_lib",
":distribute_utils",
":input_ops",
":values",
"//tensorflow/python:framework_ops",

View File

@ -28,11 +28,11 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@ -362,8 +362,8 @@ class CollectiveAllReduceStrategyTestBase(
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = sess.run([values.select_replica(r, next_element)
for r in range(len(devices))])
computed_value = sess.run([distribute_utils.select_replica(
r, next_element) for r in range(len(devices))])
if ignore_order:
self.assertCountEqual(list(expected_value), list(computed_value))
else:
@ -371,7 +371,7 @@ class CollectiveAllReduceStrategyTestBase(
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
sess.run([values.select_replica(r, next_element)
sess.run([distribute_utils.select_replica(r, next_element)
for r in range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
@ -380,8 +380,9 @@ class CollectiveAllReduceStrategyTestBase(
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = sess.run([values.select_replica(r, next_element)
for r in range(len(devices))])
computed_value = sess.run([
distribute_utils.select_replica(r, next_element)
for r in range(len(devices))])
if ignore_order:
self.assertCountEqual(list(expected_value), list(computed_value))
else:

View File

@ -28,6 +28,7 @@ from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
@ -187,7 +188,8 @@ def simple_broadcast(value, destinations, always_mirrored=False):
for d in devices:
value_updates.append(
cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
return value_lib.regroup(value_updates, wrap_class=value_lib.Mirrored)
return distribute_utils.regroup(value_updates,
wrap_class=value_lib.Mirrored)
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
@ -259,7 +261,7 @@ class CrossDeviceOps(object):
per_replica_value, destinations):
with ops.device(per_replica_value.values[0].device):
v = array_ops.identity(per_replica_value.values[0])
return value_lib.regroup((v,), wrap_class=value_lib.Mirrored)
return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
if experimental_hints is None:
experimental_hints = collective_util.Hints()
@ -309,7 +311,7 @@ class CrossDeviceOps(object):
value_destination_pairs) and len(
value_destination_pairs[0][0].values) == 1:
return [
value_lib.regroup(v.values, wrap_class=value_lib.Mirrored)
distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
for v, _ in value_destination_pairs
]
@ -510,7 +512,8 @@ def _ungroup_and_make_mirrored(grouped_reduced,
index[i].append(v / num_replicas)
else:
index[i].append(v)
return [value_lib.regroup(v, wrap_class=value_lib.Mirrored) for v in index]
return [distribute_utils.regroup(
v, wrap_class=value_lib.Mirrored) for v in index]
class _ConcatAndSplitPacker(object):
@ -1000,7 +1003,7 @@ class CollectiveAllReduce(CrossDeviceOps):
# TODO(josh11b): Once we add support for model parallelism, get the
# copy from the corresponding replica instead of the primary.
index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access
return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
@ -1104,7 +1107,8 @@ class CollectiveAllReduce(CrossDeviceOps):
for i, v in enumerate(value):
with ops.device(v.device):
value[i] = v / num_replicas
mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
mirrored.append(distribute_utils.regroup(value,
wrap_class=value_lib.Mirrored))
return mirrored
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
@ -1140,7 +1144,8 @@ class CollectiveAllReduce(CrossDeviceOps):
for i, v in enumerate(value):
with ops.device(v.device):
value[i].values = value[i].values / num_replicas
mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
mirrored.append(distribute_utils.regroup(value,
wrap_class=value_lib.Mirrored))
return mirrored
def __deepcopy__(self, memo):

View File

@ -32,6 +32,7 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
@ -73,7 +74,7 @@ def _make_per_replica(values, devices, regroup=False):
with ops.device(d):
placed_v = array_ops.identity(v)
index.append(placed_v)
return value_lib.regroup(index)
return distribute_utils.regroup(index)
# pylint: disable=g-doc-args,g-doc-return-or-yield
@ -88,7 +89,7 @@ def _fake_mirrored(value, devices):
for d in devices:
with ops.device(d):
values.append(array_ops.identity(value))
return value_lib.regroup(
return distribute_utils.regroup(
values,
wrap_class=value_lib.Mirrored)
@ -105,7 +106,7 @@ def _make_indexed_slices(values, indices, dense_shape, device):
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
values = [_make_indexed_slices(values, indices, dense_shape, d)
for d in devices]
return value_lib.regroup(
return distribute_utils.regroup(
values,
wrap_class=value_lib.Mirrored)

View File

@ -0,0 +1,322 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing utilities used by tf.distribute.Strategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values.
Args:
values: Values to regroup
wrap_class: Class that `values` will be wrapped in.
always_wrap: Always wrap the `values` in `wrap_class`, even if the values
are the same except for DistributedVariable.
Returns:
Wrapped `values`.
"""
v0 = values[0]
if isinstance(v0, list):
for v in values[1:]:
assert isinstance(v, list)
assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
(len(v), len(v0), v, v0))
return [
regroup(tuple(v[i] for v in values), wrap_class)
for i in range(len(v0))
]
if isinstance(v0, tuple):
for v in values[1:]:
assert isinstance(v, tuple)
assert len(v) == len(v0)
regrouped_tuple = tuple(
regroup(tuple(v[i] for v in values), wrap_class)
for i in range(len(v0)))
if hasattr(v0, "_fields"):
# This tuple is in fact a namedtuple! Create a new namedtuple instance
# and initialize it with the regrouped values:
assert hasattr(type(v0), "_make")
return type(v0)._make(regrouped_tuple)
else:
return regrouped_tuple
if isinstance(v0, dict):
v0keys = v0.keys()
for v in values[1:]:
assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v))
assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
(set(v0keys), set(v.keys())))
# Use the actual type in case it is a class inherited from a dict.
return type(v0)({
key: regroup(tuple(v[key] for v in values), wrap_class)
for key in v0keys
})
# If exactly the same object across all devices, return it unwrapped.
same_id = True
for v in values[1:]:
if v is not v0:
same_id = False
break
# Consider three cases where same_id is true:
# * If v0 is a DistributedVariable (a MirroredVariable or
# SyncOnReadVariable, and same_id means it is the same across all
# devices), we want to return it. We check DistributedVariable
# specifically since it can look like it has a
# _distributed_container member since its members do.
if same_id and isinstance(v0, values_lib.DistributedVariable):
return v0
# * If v0 is a member of a distributed variable, in which case
# hasattr(v0, "_distributed_container") is true, we want to
# return the DistributedVariable that contains it using the
# _distributed_container logic below. This case can trigger
# same_id when there is only one device.
# * In any other situation, same_id means we return v0 unless `always_wrap` is
# true.
if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
return v0
# Detect the case where each device has a parallel component of the
# same MirroredVariable (or SyncOnReadVariable). In this case we
# want to return the containing MirroredVariable, after a bunch of
# sanity checking. In particular, each component should have the
# same container, and the devices of the variables should match the
# keys of the per-replica dictionary.
if hasattr(v0, "_distributed_container"):
# pylint: disable=protected-access
assert not isinstance(v0, values_lib.MirroredVariable), (
"ids = %s, values = %s" % ([id(v) for v in values], values))
distributed_container = v0._distributed_container
assert distributed_container is not None
for v in values[1:]:
assert distributed_container is v._distributed_container
return distributed_container
# pylint: enable=protected-access
return wrap_class(values)
def select_replica(replica_id, structured):
"""Specialize a nest of regular & per-replica values for one replica."""
def _get(x):
# `DistributedValues` would be sliced according to replica unless it is a
# `DistributedVariable` because `DistributedVariable` can be handled
# directly in the replica context.
if (isinstance(x, values_lib.DistributedVariable) or
not isinstance(x, values_lib.DistributedValues)):
return x
else:
return x.values[replica_id]
return nest.map_structure(_get, structured)
def select_replica_mirrored(replica_id, structured):
"""Specialize a nest of regular & mirrored values for one replica."""
def _get_mirrored(x):
if isinstance(x, values_lib.DistributedValues):
if not isinstance(x, values_lib.Mirrored):
raise TypeError(
"Expected value to be mirrored across replicas: %s in %s." %
(x, structured))
return x.values[replica_id]
else:
return x
return nest.map_structure(_get_mirrored, structured)
def update_regroup(extended, updates, group):
"""Regroup for an update, with dependencies to ensure all updates execute."""
if not group:
regrouped = regroup(updates, values_lib.Mirrored)
return nest.map_structure(extended._local_results, regrouped) # pylint: disable=protected-access
def _make_grouped_mirrored(values):
"""Convert per-replica list `values` into Mirrored type with grouping."""
if len(values) == 1:
return values_lib.Mirrored(values)
# Make sure we run all updates. Without this, something like
# session.run(extended.update(...)) may only update one replica.
g = control_flow_ops.group(values)
# If values is just ops, the grouping is enough. Everything in values
# should have the same type, since we expect every replica to be performing
# the same computation.
if not all(tensor_util.is_tensor(v) for v in values):
return g
# Otherwise we need tensors with the same values as `values`, but
# that have a dependency on `g`.
with_dep = []
for v in values:
with ops.device(v.device), ops.control_dependencies([g]):
with_dep.append(array_ops.identity(v))
return values_lib.Mirrored(with_dep)
return regroup(updates, _make_grouped_mirrored)
def value_container(val):
"""Returns the container that this per-replica `value` belongs to.
Args:
val: A value returned by `call_for_each_replica()` or a variable created in
`scope()`.
Returns:
A container that `value` belongs to.
If value does not belong to any container (including the case of
container having been destroyed), returns the value itself.
"""
if (hasattr(val, "_distributed_container") and
# DistributedVariable has _distributed_container defined
# but we don't want to return it.
not isinstance(val, values_lib.DistributedVariable)):
container = val._distributed_container # pylint: disable=protected-access
if container is not None:
return container
return val
def is_distributed_variable(v):
"""Determine if a variable is ds variable or TPU mirrored variable."""
return isinstance(v, values_lib.DistributedVariable)
def _validate_colocate_extended(v, extended):
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
if variable_strategy.extended is not extended:
raise ValueError(
"`colocate_vars_with` must only be passed a variable created in this "
"tf.distribute.Strategy.scope(), not %s created in scope: %s" %
(v, variable_strategy))
def validate_colocate_distributed_variable(v, extended):
if not isinstance(v, values_lib.DistributedVariable):
raise ValueError(
"`colocate_vars_with` must only be passed a variable created in this "
"tf.distribute.Strategy.scope(), not: %r" % (v,))
_validate_colocate_extended(v, extended)
def validate_colocate(v, extended):
if not hasattr(v, "_distribute_strategy"):
raise ValueError(
"`colocate_vars_with` must only be passed a variable created in this "
"tf.distribute.Strategy.scope(), not: %r" % (v,))
_validate_colocate_extended(v, extended)
# Variable creation function for sync strategies.
def create_mirrored_variable( # pylint: disable=missing-docstring
strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
# Figure out what collections this variable should be added to.
# We'll add the MirroredVariable to those collections instead.
var_collections = kwargs.pop("collections", None)
if var_collections is None:
var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = []
synchronization = kwargs.get("synchronization",
vs.VariableSynchronization.ON_WRITE)
if synchronization == vs.VariableSynchronization.NONE:
raise ValueError(
"`NONE` variable synchronization mode is not supported with `Mirrored` "
"distribution strategy. Please change the `synchronization` for "
"variable: " + str(kwargs["name"]))
elif synchronization == vs.VariableSynchronization.ON_READ:
is_sync_on_read = True
elif synchronization in (vs.VariableSynchronization.ON_WRITE,
vs.VariableSynchronization.AUTO):
# `AUTO` synchronization defaults to `ON_WRITE`.
is_sync_on_read = False
else:
raise ValueError(
"Invalid variable synchronization mode: %s for variable: %s" %
(synchronization, kwargs["name"]))
aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
if aggregation not in (vs.VariableAggregation.NONE,
vs.VariableAggregation.SUM,
vs.VariableAggregation.MEAN,
vs.VariableAggregation.ONLY_FIRST_REPLICA):
raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
(aggregation, kwargs["name"]))
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
value_list = real_mirrored_creator(**kwargs)
var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
result = var_cls(strategy, value_list, aggregation)
# Install the created DistributedVariable as _distributed_container property
# of the underlying variables, to make it easy to map back to the container.
for v in result.values:
# Hold a strong reference to prevent the container from being GC-ed. After
# v = v.assign(), the user code may no longer hold references to the
# original container, since v.assign() returns a new DistributedVariable.
v._distributed_container = result # pylint: disable=protected-access
# Add the wrapped variable to the requested collections.
# The handling of eager mode and the global step matches
# ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
# to the TRAINABLE_VARIABLES collection, so we manually remove
# them and replace with the MirroredVariable. We can't set
# "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables.
if kwargs.get("trainable", True):
var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for value in value_list:
for i, trainable_variable in enumerate(l):
if value is trainable_variable:
del l[i]
break
g.add_to_collections(var_collections, result)
elif ops.GraphKeys.GLOBAL_STEP in var_collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
return result
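
To make the behavior of the moved helpers concrete, here is a small usage sketch patterned on the regroup/select tests further down in this diff; the nested structure and string values are illustrative only, not part of the commit:

from tensorflow.python.distribute import distribute_utils, values

# Two replicas produce structurally identical nests of plain values.
replica_0 = {"logits": "logits_0", "labels": ("y_0",)}
replica_1 = {"logits": "logits_1", "labels": ("y_1",)}

# regroup() zips the nests across replicas, wrapping differing leaves in
# PerReplica (or another wrap_class such as values.Mirrored).
per_replica = distribute_utils.regroup((replica_0, replica_1))
assert isinstance(per_replica["logits"], values.PerReplica)

# select_replica() undoes the zipping for one replica; regular
# (non-distributed) values pass through unchanged.
assert distribute_utils.select_replica(1, per_replica) == replica_1

# With Mirrored wrapping, select_replica_mirrored() is allowed as well.
mirrored = distribute_utils.regroup((replica_0, replica_1), values.Mirrored)
assert distribute_utils.select_replica_mirrored(0, mirrored) == replica_0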

View File

@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
@ -316,7 +317,7 @@ class DistributedIteratorBase(distribute_types.Iterator):
# Make `replicas` a flat list of values across all replicas.
replicas.extend(
self._iterators[i].get_next_as_list_static_shapes(new_name))
return values.regroup(replicas)
return distribute_utils.regroup(replicas)
out_of_range_replicas = []
def out_of_range_fn(worker_index, device):
@ -349,7 +350,7 @@ class DistributedIteratorBase(distribute_types.Iterator):
results.append(result)
replicas = results
return values.regroup(replicas)
return distribute_utils.regroup(replicas)
class DistributedIteratorV1(DistributedIteratorBase):
@ -577,7 +578,7 @@ class _IterableInput(distribute_types.Iterable):
else:
raise ValueError("Dataset iteration within a tf.function is"
" not supported for multiple workers.")
state = reduce_fn(state, values.regroup(data))
state = reduce_fn(state, distribute_utils.regroup(data))
has_data, data = _get_next_as_optional(iterator, self._strategy)
return has_data, data, state

View File

@ -34,13 +34,13 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
@ -188,8 +188,8 @@ class DistributedIteratorTestBase(test.TestCase):
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r,
next_element) for r in range(len(devices))])
[distribute_utils.select_replica(r, next_element)
for r in range(len(devices))])
self.assertEqual(len(expected_value), len(computed_value))
for i in range(len(expected_value)):
self.assertAllEqual(expected_value[i], computed_value[i])
@ -197,8 +197,8 @@ class DistributedIteratorTestBase(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
evaluate(
[values.select_replica(r,
next_element) for r in range(len(devices))])
[distribute_utils.select_replica(r, next_element)
for r in range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
if not ops.executing_eagerly_outside_functions():
@ -212,8 +212,8 @@ class DistributedIteratorTestBase(test.TestCase):
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r,
next_element) for r in range(len(devices))])
[distribute_utils.select_replica(r, next_element)
for r in range(len(devices))])
self.assertEqual(len(expected_value), len(computed_value))
for i in range(len(expected_value)):
self.assertAllEqual(expected_value[i], computed_value[i])
@ -222,7 +222,8 @@ class DistributedIteratorTestBase(test.TestCase):
actual_values = []
for x in dataset:
computed_value = self.evaluate(
[values.select_replica(r, x) for r in range(len(devices))])
[distribute_utils.select_replica(r, x)
for r in range(len(devices))])
actual_values.append(computed_value)
for i, expected_value in enumerate(expected_values):
self.assertEqual(len(expected_value), len(actual_values[i]))
@ -699,24 +700,29 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
# Assert that the tensors are rebatched and sparsity is preserved.
per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
self.assertAllEqual(
values.select_replica(0, per_replica_batch["dense"]),
distribute_utils.select_replica(0, per_replica_batch["dense"]),
[[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
self.assertAllEqual(
values.select_replica(1, per_replica_batch["dense"]),
distribute_utils.select_replica(1, per_replica_batch["dense"]),
[[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
# Transitively check the ragged and sparse tensors by densification.
for i in range(2):
self.assertLen(
values.select_replica(i, per_replica_batch["ragged"]).values, 6)
distribute_utils.select_replica(i,
per_replica_batch["ragged"]).values,
6)
self.assertAllEqual(
values.select_replica(i, per_replica_batch["ragged"]).to_tensor(),
values.select_replica(i, per_replica_batch["dense"]))
distribute_utils.select_replica(
i, per_replica_batch["ragged"]).to_tensor(),
distribute_utils.select_replica(i, per_replica_batch["dense"]))
self.assertLen(
values.select_replica(i, per_replica_batch["sparse"]).indices, 6)
distribute_utils.select_replica(i,
per_replica_batch["sparse"]).indices,
6)
self.assertAllEqual(
sparse_ops.sparse_tensor_to_dense(
values.select_replica(i, per_replica_batch["sparse"])),
values.select_replica(i, per_replica_batch["dense"]))
distribute_utils.select_replica(i, per_replica_batch["sparse"])),
distribute_utils.select_replica(i, per_replica_batch["dense"]))
# Iterate through all the batches and sum them up.
def sum_batch(per_replica_features):
"""Sums the `PerReplica` values in the `per_replica_features` map."""

View File

@ -22,6 +22,7 @@ import threading
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import values
@ -123,16 +124,18 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
# use a collective op. This is a particular concern with eager
# execution.
with context.execution_mode(context.ASYNC):
return_values.append(fn(*values.select_replica(index, args),
**values.select_replica(index, kwargs)))
return_values.append(
fn(*distribute_utils.select_replica(index, args),
**distribute_utils.select_replica(index, kwargs)))
else:
return_values.append(fn(*values.select_replica(index, args),
**values.select_replica(index, kwargs)))
return_values.append(
fn(*distribute_utils.select_replica(index, args),
**distribute_utils.select_replica(index, kwargs)))
finally:
_replica_index.graph_outside_run = None
_replica_index.current = None
return values.regroup(return_values)
return distribute_utils.regroup(return_values)
def _local_results(self, val):
if isinstance(val, values.DistributedValues):

View File

@ -27,8 +27,8 @@ from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@ -160,8 +160,8 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
shared_variable_store, index)
t = _MirroredReplicaThread(
distribution, coord, index, devices, variable_creator_fn, fn,
values.select_replica(index, args),
values.select_replica(index, kwargs))
distribute_utils.select_replica(index, args),
distribute_utils.select_replica(index, kwargs))
threads.append(t)
for t in threads:
@ -209,8 +209,10 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
raise RuntimeError("Some replicas made a different number of "
"replica_context().merge_call() calls.")
# get_replica_context().merge_call() case
merge_args = values.regroup(tuple(t.merge_args for t in threads))
merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads))
merge_args = distribute_utils.regroup(
tuple(t.merge_args for t in threads))
merge_kwargs = distribute_utils.regroup(
tuple(t.merge_kwargs for t in threads))
# We capture the name_scope of the MRT when we call merge_fn
# to ensure that if we have opened a name scope in the MRT,
# it will be respected when executing the merge function. We only
@ -228,13 +230,13 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
merge_result = threads[0].merge_fn(distribution, *merge_args,
**merge_kwargs)
for r, t in enumerate(threads):
t.merge_result = values.select_replica(r, merge_result)
t.merge_result = distribute_utils.select_replica(r, merge_result)
finally:
for t in threads:
t.should_run.set()
coord.join(threads)
return values.regroup(tuple(t.main_result for t in threads))
return distribute_utils.regroup(tuple(t.main_result for t in threads))
class _MirroredReplicaThread(threading.Thread):
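
The pattern above (and the analogous tpu_strategy.py change later in this diff) fans the per-replica slices of args/kwargs out to each replica's call of fn. A hedged, self-contained sketch of that fan-out, with made-up argument values:

from tensorflow.python.distribute import distribute_utils

# A nest of args where the first entry differs per replica and the second is a
# regular value shared by all replicas (both values are invented for the sketch).
per_replica_x = distribute_utils.regroup(("x_for_replica_0", "x_for_replica_1"))
args = (per_replica_x, "shared_config")

for replica_id in range(2):
  # select_replica() slices the PerReplica leaf and passes the regular value
  # through, which is what mirrored_run does before calling fn on each replica.
  replica_args = distribute_utils.select_replica(replica_id, args)
  assert replica_args == ("x_for_replica_%d" % replica_id, "shared_config")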

View File

@ -23,6 +23,7 @@ import copy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
@ -445,13 +446,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
value_list.append(v)
return value_list
return values.create_mirrored_variable(self._container_strategy(),
_real_mirrored_creator,
values.MirroredVariable,
values.SyncOnReadVariable, **kwargs)
return distribute_utils.create_mirrored_variable(
self._container_strategy(), _real_mirrored_creator,
values.MirroredVariable, values.SyncOnReadVariable, **kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate_distributed_variable(colocate_with_variable, self)
distribute_utils.validate_colocate_distributed_variable(
colocate_with_variable, self)
def _make_dataset_iterator(self, dataset):
return input_lib.DatasetIterator(
@ -507,7 +508,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
per_replica_values.append(value_fn(
distribute_lib.ValueContext(replica_id,
self._num_replicas_in_sync)))
return values.regroup(per_replica_values, always_wrap=True)
return distribute_utils.regroup(per_replica_values, always_wrap=True)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
@ -557,7 +558,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# For outputs that have already been reduced, wrap them in a Mirrored
# container, else in a PerReplica container.
if reduce_op is None:
last_step_tensor_outputs_dict[name] = values.regroup(output)
last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
else:
assert len(output) == 1
last_step_tensor_outputs_dict[name] = output[0]
@ -580,8 +581,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return self._get_cross_device_ops().broadcast(tensor, destinations)
def _call_for_each_replica(self, fn, args, kwargs):
return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
args, kwargs)
return mirrored_run.call_for_each_replica(
self._container_strategy(), fn, args, kwargs)
def _configure(self,
session_config=None,
@ -643,10 +644,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
distribute_lib.UpdateContext(i), \
ops.name_scope(name):
# If args and kwargs are not mirrored, the value is returned as is.
updates.append(fn(v,
*values.select_replica_mirrored(i, args),
**values.select_replica_mirrored(i, kwargs)))
return values.update_regroup(self, updates, group)
updates.append(
fn(v, *distribute_utils.select_replica_mirrored(i, args),
**distribute_utils.select_replica_mirrored(i, kwargs)))
return distribute_utils.update_regroup(self, updates, group)
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
assert isinstance(colocate_with, tuple)
@ -655,9 +656,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
for i, d in enumerate(colocate_with):
name = "update_%d" % i
with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
updates.append(fn(*values.select_replica_mirrored(i, args),
**values.select_replica_mirrored(i, kwargs)))
return values.update_regroup(self, updates, group)
updates.append(
fn(*distribute_utils.select_replica_mirrored(i, args),
**distribute_utils.select_replica_mirrored(i, kwargs)))
return distribute_utils.update_regroup(self, updates, group)
def read_var(self, replica_local_var):
"""Read the aggregate value of a replica-local variable."""
@ -672,7 +674,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return (val,)
def value_container(self, val):
return values.value_container(val)
return distribute_utils.value_container(val)
@property
def _num_replicas_in_sync(self):

View File

@ -31,6 +31,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
@ -1026,8 +1027,9 @@ class MirroredStrategyDefunTest(test.TestCase):
result = distribution.extended.call_for_each_replica(
model_fn, args=[mock_model] + inputs)
for r in range(len(devices)):
device_result = values.select_replica(r, result)
device_expected_result = values.select_replica(r, expected_result)
device_result = distribute_utils.select_replica(r, result)
device_expected_result = distribute_utils.select_replica(
r, expected_result)
self.assertAllClose(device_expected_result,
self.evaluate(device_result))

View File

@ -20,9 +20,9 @@ from __future__ import print_function
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -233,7 +233,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
return super(OneDeviceStrategy, self).scope()
@tf_export(v1=["distribute.OneDeviceStrategy"]) # pylint: disable=missing-docstring
@tf_export(v1=["distribute.OneDeviceStrategy"]) # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):
__doc__ = OneDeviceStrategy.__doc__.replace(
@ -272,7 +272,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
return next_creator(**kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate(colocate_with_variable, self)
distribute_utils.validate_colocate(colocate_with_variable, self)
def _make_dataset_iterator(self, dataset):
"""Make iterator from dataset without splitting the batch."""

View File

@ -24,6 +24,7 @@ import copy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
@ -334,7 +335,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
compute_devices, self._variable_device)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate(colocate_with_variable, self)
distribute_utils.validate_colocate(colocate_with_variable, self)
def _experimental_distribute_dataset(self, dataset):
return input_lib.get_distributed_dataset(

View File

@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
@ -34,7 +35,6 @@ from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@ -534,8 +534,8 @@ class ParameterServerStrategyTestBase(
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = sess.run([values.select_replica(r, next_element)
for r in range(len(devices))])
computed_value = sess.run([distribute_utils.select_replica(
r, next_element) for r in range(len(devices))])
if ignore_order:
self.assertCountEqual(expected_value, computed_value)
else:
@ -543,7 +543,7 @@ class ParameterServerStrategyTestBase(
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
sess.run([values.select_replica(r, next_element)
sess.run([distribute_utils.select_replica(r, next_element)
for r in range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
@ -552,8 +552,8 @@ class ParameterServerStrategyTestBase(
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = sess.run([values.select_replica(r, next_element)
for r in range(len(devices))])
computed_value = sess.run([distribute_utils.select_replica(
r, next_element) for r in range(len(devices))])
if ignore_order:
self.assertCountEqual(expected_value, computed_value)
else:

View File

@ -27,9 +27,9 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -348,7 +348,8 @@ class DistributionTestBase(test.TestCase):
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
[distribute_utils.select_replica(r, next_element) for r in
range(len(devices))])
if ignore_order:
self.assertCountEqual(expected_value, computed_value)
else:
@ -357,7 +358,8 @@ class DistributionTestBase(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
evaluate(
[values.select_replica(r, next_element) for r in range(len(devices))])
[distribute_utils.select_replica(r, next_element) for r in
range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
if test_reinitialize:
@ -366,7 +368,8 @@ class DistributionTestBase(test.TestCase):
for expected_value in expected_values:
next_element = iterator.get_next()
computed_value = evaluate([
values.select_replica(r, next_element) for r in range(len(devices))
distribute_utils.select_replica(r, next_element) for r in
range(len(devices))
])
if ignore_order:
self.assertCountEqual(expected_value, computed_value)

View File

@ -32,6 +32,7 @@ from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
@ -363,7 +364,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
return self._input_workers_obj
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate(colocate_with_variable, self)
distribute_utils.validate_colocate(colocate_with_variable, self)
def _make_dataset_iterator(self, dataset):
"""Make iterators for each of the TPU hosts."""
@ -423,7 +424,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
per_replica_values.append(
value_fn(distribute_lib.ValueContext(replica_id,
self._num_replicas_in_sync)))
return values.regroup(per_replica_values, always_wrap=True)
return distribute_utils.regroup(per_replica_values, always_wrap=True)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
@ -462,7 +463,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
per_replica_inputs = multi_worker_iterator.get_next()
replicate_inputs = []
for replica_id in range(self._num_replicas_in_sync):
select_replica = lambda x: values.select_replica(replica_id, x) # pylint: disable=cell-var-from-loop
select_replica = lambda x: distribute_utils.select_replica( # pylint: disable=g-long-lambda
replica_id, x) # pylint: disable=cell-var-from-loop
replicate_inputs.append((nest.map_structure(
select_replica, per_replica_inputs),))
@ -648,11 +650,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
value_list.append(v)
return value_list
return values.create_mirrored_variable(self._container_strategy(),
_real_mirrored_creator,
tpu_values.TPUMirroredVariable,
tpu_values.TPUSyncOnReadVariable,
**kwargs)
return distribute_utils.create_mirrored_variable(
self._container_strategy(), _real_mirrored_creator,
tpu_values.TPUMirroredVariable, tpu_values.TPUSyncOnReadVariable,
**kwargs)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
if (isinstance(value, values.DistributedValues) or
@ -722,10 +723,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
distribute_lib.UpdateContext(i), \
ops.name_scope(name):
# If args and kwargs are not mirrored, the value is returned as is.
updates.append(fn(v,
*values.select_replica_mirrored(i, args),
**values.select_replica_mirrored(i, kwargs)))
return values.update_regroup(self, updates, group)
updates.append(
fn(v, *distribute_utils.select_replica_mirrored(i, args),
**distribute_utils.select_replica_mirrored(i, kwargs)))
return distribute_utils.update_regroup(self, updates, group)
def read_var(self, var):
assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
@ -893,8 +894,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
for i in range(strategy.num_replicas_in_sync):
replicate_inputs.append(
[constant_op.constant(i, dtype=dtypes.int32),
values.select_replica(i, args),
values.select_replica(i, kwargs)])
distribute_utils.select_replica(i, args),
distribute_utils.select_replica(i, kwargs)])
# Construct and pass `maximum_shapes` so that we could support dynamic
# shapes using dynamic padder.
@ -941,7 +942,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
nest.pack_sequence_as(result[0], nest.flatten(replica_output))
for replica_output in replicate_outputs
]
return values.regroup(replicate_outputs)
return distribute_utils.regroup(replicate_outputs)
if context.executing_eagerly():
tpu_function = def_function.function(tpu_function)

View File

@ -26,10 +26,8 @@ from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -40,7 +38,6 @@ from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -1023,88 +1020,6 @@ class SyncOnReadVariable(DistributedVariable):
self._get(), dtype=dtype, name=name, as_ref=as_ref)
# Variable creation function for sync strategies.
def create_mirrored_variable( # pylint: disable=missing-docstring
strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
# Figure out what collections this variable should be added to.
# We'll add the MirroredVariable to those collections instead.
var_collections = kwargs.pop("collections", None)
if var_collections is None:
var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = []
synchronization = kwargs.get("synchronization",
vs.VariableSynchronization.ON_WRITE)
if synchronization == vs.VariableSynchronization.NONE:
raise ValueError(
"`NONE` variable synchronization mode is not supported with `Mirrored` "
"distribution strategy. Please change the `synchronization` for "
"variable: " + str(kwargs["name"]))
elif synchronization == vs.VariableSynchronization.ON_READ:
is_sync_on_read = True
elif synchronization in (vs.VariableSynchronization.ON_WRITE,
vs.VariableSynchronization.AUTO):
# `AUTO` synchronization defaults to `ON_WRITE`.
is_sync_on_read = False
else:
raise ValueError(
"Invalid variable synchronization mode: %s for variable: %s" %
(synchronization, kwargs["name"]))
aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
if aggregation not in (vs.VariableAggregation.NONE,
vs.VariableAggregation.SUM,
vs.VariableAggregation.MEAN,
vs.VariableAggregation.ONLY_FIRST_REPLICA):
raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
(aggregation, kwargs["name"]))
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
value_list = real_mirrored_creator(**kwargs)
var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
result = var_cls(strategy, value_list, aggregation)
# Install the created DistributedVariable as _distributed_container property
# of the underlying variables, to make it easy to map back to the container.
for v in result.values:
# Hold a strong reference to prevent the container from being GC-ed. After
# v = v.assign(), the user code may no longer hold references to the
# original container, since v.assign() returns a new DistributedVariable.
v._distributed_container = result # pylint: disable=protected-access
# Add the wrapped variable to the requested collections.
# The handling of eager mode and the global step matches
# ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
# to the TRAINABLE_VARIABLES collection, so we manually remove
# them and replace with the MirroredVariable. We can't set
# "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables.
if kwargs.get("trainable", True):
var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for value in value_list:
for i, trainable_variable in enumerate(l):
if value is trainable_variable:
del l[i]
break
g.add_to_collections(var_collections, result)
elif ops.GraphKeys.GLOBAL_STEP in var_collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
return result
# Register a conversion functions which reads the value of the variable,
# allowing instances of the class to be used as tensors.
# MirroredVariables
@ -1133,214 +1048,3 @@ def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
ops.register_tensor_conversion_function(SyncOnReadVariable,
_tensor_conversion_sync_on_read)
def regroup(values, wrap_class=PerReplica, always_wrap=False):
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values.
Args:
values: Values to regroup
wrap_class: Class that `values` will be wrapped in.
always_wrap: Always wrap the `values` in `wrap_class`, even if the values
are the same except for DistributedVariable.
Returns:
Wrapped `values`.
"""
v0 = values[0]
if isinstance(v0, list):
for v in values[1:]:
assert isinstance(v, list)
assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
(len(v), len(v0), v, v0))
return [
regroup(tuple(v[i] for v in values), wrap_class)
for i in range(len(v0))
]
if isinstance(v0, tuple):
for v in values[1:]:
assert isinstance(v, tuple)
assert len(v) == len(v0)
regrouped_tuple = tuple(
regroup(tuple(v[i] for v in values), wrap_class)
for i in range(len(v0)))
if hasattr(v0, "_fields"):
# This tuple is in fact a namedtuple! Create a new namedtuple instance
# and initialize it with the regrouped values:
assert hasattr(type(v0), "_make")
return type(v0)._make(regrouped_tuple)
else:
return regrouped_tuple
if isinstance(v0, dict):
v0keys = v0.keys()
for v in values[1:]:
assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v))
assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
(set(v0keys), set(v.keys())))
# Use the actual type in case it is a class inherited from a dict.
return type(v0)({
key: regroup(tuple(v[key] for v in values), wrap_class)
for key in v0keys
})
# If exactly the same object across all devices, return it unwrapped.
same_id = True
for v in values[1:]:
if v is not v0:
same_id = False
break
# Consider three cases where same_id is true:
# * If v0 is a DistributedVariable (a MirroredVariable or
# SyncOnReadVariable, and same_id means it is the same across all
# devices), we want to return it. We check DistributedVariable
# specifically since it can look like it has a
# _distributed_container member since its members do.
if same_id and isinstance(v0, DistributedVariable):
return v0
# * If v0 is a member of a distributed variable, in which case
# hasattr(v0, "_distributed_container") is true, we want to
# return the DistributedVariable that contains it using the
# _distributed_container logic below. This case can trigger
# same_id when there is only one device.
# * In any other situation, same_id means we return v0 unless `always_wrap` is
# true.
if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
return v0
# Detect the case where each device has a parallel component of the
# same MirroredVariable (or SyncOnReadVariable). In this case we
# want to return the containing MirroredVariable, after a bunch of
# sanity checking. In particular, each component should have the
# same container, and the devices of the variables should match the
# keys of the per-replica dictionary.
if hasattr(v0, "_distributed_container"):
# pylint: disable=protected-access
assert not isinstance(v0, MirroredVariable), (
"ids = %s, values = %s" % ([id(v) for v in values], values))
distributed_container = v0._distributed_container
assert distributed_container is not None
for v in values[1:]:
assert distributed_container is v._distributed_container
return distributed_container
# pylint: enable=protected-access
return wrap_class(values)
def select_replica(replica_id, structured):
"""Specialize a nest of regular & per-replica values for one replica."""
def _get(x):
# `DistributedValues` would be sliced according to replica unless it is a
# `DistributedVariable` because `DistributedVariable` can be handled
# directly in the replica context.
if (isinstance(x, DistributedVariable) or
not isinstance(x, DistributedValues)):
return x
else:
return x.values[replica_id]
return nest.map_structure(_get, structured)
def select_replica_mirrored(replica_id, structured):
"""Specialize a nest of regular & mirrored values for one replica."""
def _get_mirrored(x):
if isinstance(x, DistributedValues):
if not isinstance(x, Mirrored):
raise TypeError(
"Expected value to be mirrored across replicas: %s in %s." %
(x, structured))
return x.values[replica_id]
else:
return x
return nest.map_structure(_get_mirrored, structured)
def update_regroup(extended, updates, group):
"""Regroup for an update, with dependencies to ensure all updates execute."""
if not group:
regrouped = regroup(updates, Mirrored)
return nest.map_structure(extended._local_results, regrouped) # pylint: disable=protected-access
def _make_grouped_mirrored(values):
"""Convert per-replica list `values` into Mirrored type with grouping."""
if len(values) == 1:
return Mirrored(values)
# Make sure we run all updates. Without this, something like
# session.run(extended.update(...)) may only update one replica.
g = control_flow_ops.group(values)
# If values is just ops, the grouping is enough. Everything in values
# should have the same type, since we expect every replica to be performing
# the same computation.
if not all(tensor_util.is_tensor(v) for v in values):
return g
# Otherwise we need tensors with the same values as `values`, but
# that have a dependency on `g`.
with_dep = []
for v in values:
with ops.device(v.device), ops.control_dependencies([g]):
with_dep.append(array_ops.identity(v))
return Mirrored(with_dep)
return regroup(updates, _make_grouped_mirrored)
def value_container(val):
"""Returns the container that this per-replica `value` belongs to.
Args:
val: A value returned by `call_for_each_replica()` or a variable created in
`scope()`.
Returns:
A container that `value` belongs to.
If value does not belong to any container (including the case of
container having been destroyed), returns the value itself.
"""
if (hasattr(val, "_distributed_container") and
# DistributedVariable has _distributed_container defined
# but we don't want to return it.
not isinstance(val, DistributedVariable)):
container = val._distributed_container # pylint: disable=protected-access
if container is not None:
return container
return val
def is_distributed_variable(v):
"""Determine if a variable is ds variable or TPU mirrored variable."""
return isinstance(v, DistributedVariable)
def _validate_colocate_extended(v, extended):
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
if variable_strategy.extended is not extended:
raise ValueError(
"`colocate_vars_with` must only be passed a variable created in this "
"tf.distribute.Strategy.scope(), not %s created in scope: %s" %
(v, variable_strategy))
def validate_colocate_distributed_variable(v, extended):
if not isinstance(v, DistributedVariable):
raise ValueError(
"`colocate_vars_with` must only be passed a variable created in this "
"tf.distribute.Strategy.scope(), not: %r" % (v,))
_validate_colocate_extended(v, extended)
def validate_colocate(v, extended):
if not hasattr(v, "_distribute_strategy"):
raise ValueError(
"`colocate_vars_with` must only be passed a variable created in this "
"tf.distribute.Strategy.scope(), not: %r" % (v,))
_validate_colocate_extended(v, extended)

View File

@ -30,6 +30,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import strategy_combinations
@ -157,7 +158,7 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
distribution.experimental_distribute_values_from_function(value_fn))
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(
values.select_replica(i, distributed_values),
distribute_utils.select_replica(i, distributed_values),
(1. * i, 2. * i, 3. * i))
@combinations.generate(
@ -391,7 +392,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
self.assertEqual(exp, result.values[i])
def testNested(self):
result = values.regroup((_nested_value("1"), _nested_value("2")))
result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")))
self.assertIsInstance(result, tuple)
self.assertLen(result, 3)
self._is_per_replica(result[0], ["a1", "a2"])
@ -409,20 +410,20 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
# Also test that we can undo the merge using select_replica()
self.assertEqual(_nested_value("1"),
values.select_replica(0, result))
distribute_utils.select_replica(0, result))
self.assertEqual(_nested_value("2"),
values.select_replica(1, result))
distribute_utils.select_replica(1, result))
    # select_replica_mirrored() should fail due to non-mirrored values
with self.assertRaises(TypeError):
values.select_replica_mirrored(0, result)
distribute_utils.select_replica_mirrored(0, result)
with self.assertRaises(TypeError):
values.select_replica_mirrored(1, result)
distribute_utils.select_replica_mirrored(1, result)
def testRegroupKeepsDictBasedClass(self):
class DictBasedClass(dict):
"""Dummy class inherited from a dict."""
result = values.regroup(
result = distribute_utils.regroup(
(DictBasedClass(a="a1", b="b1"), DictBasedClass(a="a2", b="b2")))
self.assertIsInstance(result, DictBasedClass)
self._is_per_replica(result["a"], ["a1", "a2"])
@ -431,8 +432,8 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
def testWrapClass(self):
# Normally a mirrored value would be the same across devices, but
# for a test it is convenient to be able to tell the values apart.
result = values.regroup((_nested_value("1"), _nested_value("2")),
values.Mirrored)
result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")),
values.Mirrored)
self.assertIsInstance(result, tuple)
self.assertLen(result, 3)
self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
@ -450,17 +451,17 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
# Also test that we can undo the merge using select_replica()
self.assertEqual(_nested_value("1"),
values.select_replica(0, result))
distribute_utils.select_replica(0, result))
self.assertEqual(_nested_value("2"),
values.select_replica(1, result))
distribute_utils.select_replica(1, result))
    # Values are marked as mirrored, so select_replica_mirrored() is allowed.
self.assertEqual(_nested_value("1"),
values.select_replica_mirrored(0, result))
distribute_utils.select_replica_mirrored(0, result))
self.assertEqual(_nested_value("2"),
values.select_replica_mirrored(1, result))
distribute_utils.select_replica_mirrored(1, result))
def testWrapAListOfTwoTuples(self):
result = values.regroup([("1", "2"), ("3", "4")])
result = distribute_utils.regroup([("1", "2"), ("3", "4")])
self.assertIsInstance(result, tuple)
self.assertLen(result, 2)
self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
@ -478,34 +479,36 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
with distribution.scope():
v = variable_scope.variable(
1., aggregation=variable_scope.VariableAggregation.SUM)
self.assertTrue(values.is_distributed_variable(v))
self.assertTrue(values.is_distributed_variable(values.regroup(v.values)))
self.assertTrue(distribute_utils.is_distributed_variable(v))
self.assertTrue(distribute_utils.is_distributed_variable(
distribute_utils.regroup(v.values)))
def testSameId(self):
foo = object()
result = values.regroup((("a", foo), ("b", foo)))
result = distribute_utils.regroup((("a", foo), ("b", foo)))
self.assertIsInstance(result, tuple)
self.assertLen(result, 2)
self._is_per_replica(result[0], ["a", "b"])
self.assertIs(foo, result[1])
# Test select_replica(), should undo the merge done by regroup().
result_0 = values.select_replica(0, result)
result_0 = distribute_utils.select_replica(0, result)
self.assertIsInstance(result_0, tuple)
self.assertLen(result_0, 2)
self.assertEqual("a", result_0[0])
self.assertIs(foo, result_0[1])
result_1 = values.select_replica(1, result)
result_1 = distribute_utils.select_replica(1, result)
self.assertIsInstance(result_1, tuple)
self.assertLen(result_1, 2)
self.assertEqual("b", result_1[0])
self.assertIs(foo, result_1[1])
def testOneDevice(self):
result = values.regroup((_nested_value("1"),))
result = distribute_utils.regroup((_nested_value("1"),))
# On one device regroup() and select_replica() are basically identity.
self.assertEqual(_nested_value("1"), result)
self.assertEqual(_nested_value("1"), values.select_replica(0, result))
self.assertEqual(_nested_value("1"),
distribute_utils.select_replica(0, result))
def testNamedTuple(self):
@ -533,7 +536,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
train_op=array_ops.identity(constant_op.constant(device_id)))
created_estimator_specs.append(spec)
merged_estimator_spec = values.regroup(created_estimator_specs)
merged_estimator_spec = distribute_utils.regroup(created_estimator_specs)
self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
@ -550,8 +553,8 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
Scaffold)
# Also test that we can undo the merge using select_replica()
self.assertEqual(created_estimator_specs[device_id],
values.select_replica(device_id,
merged_estimator_spec))
distribute_utils.select_replica(
device_id, merged_estimator_spec))
@combinations.generate(
@ -630,7 +633,7 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
with distribution.scope():
v = variables_lib.Variable(
1., synchronization=synchronization, aggregation=aggregation)
self.assertIs(v, values.select_replica(0, v))
self.assertIs(v, distribute_utils.select_replica(0, v))
def testIsTensorLike(self, distribution, synchronization, aggregation):
if isinstance(distribution.extended,
@ -973,8 +976,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
return v.assign(ctx.replica_id_in_sync_group)
# disallow assign() with distributed value in replica context.
with self.assertRaisesRegexp(ValueError,
"Cannot update non-float variables"):
with self.assertRaisesRegex(ValueError,
"Cannot update non-float variables"):
self.evaluate(
distribution.experimental_local_results(
distribution.run(assign)))
@ -2174,11 +2177,11 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
condition, lambda: per_replica_1, lambda: per_replica_2)
def _make_index_slices(values, indices, dense_shape=None):
def _make_index_slices(vals, indices, dense_shape=None):
if dense_shape:
dense_shape = array_ops.identity(dense_shape)
return indexed_slices.IndexedSlices(
array_ops.identity(values), array_ops.identity(indices), dense_shape)
array_ops.identity(vals), array_ops.identity(indices), dense_shape)
if __name__ == "__main__":

View File

@ -363,6 +363,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:distribute_utils",
"//tensorflow/python/distribute:values",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:base",

View File

@ -22,8 +22,8 @@ import functools
import os
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
@ -87,10 +87,11 @@ class _WrapperFunction(function.ConcreteFunction):
def _call_flat(self, args, captured_inputs, cancellation_manager=None):
def get_in_replica_handle(x):
return x.handle if ds_values.is_distributed_variable(x) else x
return x.handle if distribute_utils.is_distributed_variable(x) else x
def get_cross_replica_handle(x):
return _unused_handle() if ds_values.is_distributed_variable(x) else x
return _unused_handle() if distribute_utils.is_distributed_variable(x) \
else x
if ds_context.get_replica_context() is not None: # in-replica context
captured_inputs = list(map(get_in_replica_handle, captured_inputs))
@ -201,7 +202,7 @@ class Loader(object):
if bound_inputs:
for bound_input, internal_capture in zip(
bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
if ds_values.is_distributed_variable(bound_input):
if distribute_utils.is_distributed_variable(bound_input):
concrete_function.graph.capture_distributed_variable(
bound_input, internal_capture)
else:
@ -227,7 +228,7 @@ class Loader(object):
"""Resolves a node id into a tensor to be captured for a function."""
with ops.init_scope():
obj = self._nodes[node_id]
if ds_values.is_distributed_variable(obj):
if distribute_utils.is_distributed_variable(obj):
return obj
elif resource_variable_ops.is_resource_variable(obj):
return obj.handle

View File

@ -26,7 +26,7 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.distribute import distribute_utils as ds_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@ -273,13 +273,13 @@ class _SaveableView(object):
# pylint: enable=protected-access
resource_map[obj.resource_handle] = new_resource
self.captured_tensor_node_ids[obj.resource_handle] = node_id
elif (ds_values.is_distributed_variable(obj) or
elif (ds_utils.is_distributed_variable(obj) or
resource_variable_ops.is_resource_variable(obj)):
obj_to_copy = obj._primary if ds_values.is_distributed_variable( # pylint: disable=protected-access
obj_to_copy = obj._primary if ds_utils.is_distributed_variable( # pylint: disable=protected-access
obj) else obj
new_variable = resource_variable_ops.copy_to_graph_uninitialized(
obj_to_copy)
if ds_values.is_distributed_variable(obj):
if ds_utils.is_distributed_variable(obj):
self.captured_tensor_node_ids[obj] = node_id
for v in obj.values:
object_map[v] = new_variable

View File

@ -25,10 +25,10 @@ from absl import logging
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values as tf_values
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -1131,8 +1131,10 @@ class TPUEmbedding(tracking.AutoTrackable):
# in the same (standard) order as self._strategy.extended.worker_devices.
enqueue_ops = []
for replica_id in range(self._strategy.num_replicas_in_sync):
replica_inputs = tf_values.select_replica(replica_id, flat_inputs)
replica_weights = tf_values.select_replica(replica_id, flat_weights)
replica_inputs = distribute_utils.select_replica(replica_id,
flat_inputs)
replica_weights = distribute_utils.select_replica(replica_id,
flat_weights)
tpu_device = self._strategy.extended.worker_devices[replica_id]
      # TPU device strings are like /job:worker/replica:0/task:0/device:TPU:0
# the device ordinal is the last number