Another round of refactoring of values.py: split out the utility functions that use the distributed Variable types defined in values.py.

PiperOrigin-RevId: 316147517
Change-Id: I72e17b02e8f41c9cee40f4ec7f56fec2f7d860a9
This commit is contained in:
parent 1cd053d409
commit 0f3562ba77
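For code inside TensorFlow the change is a mechanical move of the helpers out of values.py into the new distribute_utils module; a minimal sketch of the resulting call-site update (the tensors here are illustrative, the module paths are the ones touched by this commit):

from tensorflow.python.distribute import distribute_utils
from tensorflow.python.framework import constant_op

# One value per replica; previously these helpers were values.regroup()
# and values.select_replica().
per_replica_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
replicas = distribute_utils.regroup(per_replica_list)     # -> PerReplica
replica_0 = distribute_utils.select_replica(0, replicas)  # -> first tensor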
@@ -286,6 +286,36 @@ py_library(
    ],
)

py_library(
    name = "distribute_utils",
    srcs = ["distribute_utils.py"],
    deps = [
        ":device_util",
        ":distribute_lib",
        ":reduce_util",
        ":shared_variable_creator",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:config",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:device",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:pywrap_tfe",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/autograph/core",
        "//tensorflow/python/autograph/impl",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
    ],
)

py_library(
    name = "mirrored_strategy",
    srcs = ["mirrored_strategy.py"],
@@ -293,6 +323,7 @@ py_library(
        ":cross_device_ops",
        ":device_util",
        ":distribute_lib",
        ":distribute_utils",
        ":input_lib",
        ":mirrored_run",
        ":multi_worker_util",
@@ -320,6 +351,7 @@ py_library(
        ":cross_device_ops",
        ":device_util",
        ":distribute_lib",
        ":distribute_utils",
        ":input_lib",
        ":mirrored_run",
        ":multi_worker_util",
@@ -508,6 +540,7 @@ py_library(
    deps = [
        ":device_util",
        ":distribute_lib",
        ":distribute_utils",
        ":input_ops",
        ":values",
        "//tensorflow/python:framework_ops",
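Downstream targets in this package pick the helpers up through a ":distribute_utils" dependency, as the hunks below do for mirrored_strategy and friends; a minimal sketch of such a BUILD target (the target name and source file are hypothetical):

py_library(
    name = "my_strategy_util",       # hypothetical target
    srcs = ["my_strategy_util.py"],  # hypothetical source
    deps = [
        ":distribute_utils",
        ":values",
    ],
)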
@@ -28,11 +28,11 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -362,8 +362,8 @@ class CollectiveAllReduceStrategyTestBase(

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([values.select_replica(r, next_element)
                                   for r in range(len(devices))])
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(list(expected_value), list(computed_value))
        else:
@@ -371,7 +371,7 @@ class CollectiveAllReduceStrategyTestBase(

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        sess.run([values.select_replica(r, next_element)
        sess.run([distribute_utils.select_replica(r, next_element)
                  for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
@@ -380,8 +380,9 @@ class CollectiveAllReduceStrategyTestBase(

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([values.select_replica(r, next_element)
                                   for r in range(len(devices))])
        computed_value = sess.run([
            distribute_utils.select_replica(r, next_element)
            for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(list(expected_value), list(computed_value))
        else:
@@ -28,6 +28,7 @@ from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
@@ -187,7 +188,8 @@ def simple_broadcast(value, destinations, always_mirrored=False):
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return value_lib.regroup(value_updates, wrap_class=value_lib.Mirrored)
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)


def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
@@ -259,7 +261,7 @@ class CrossDeviceOps(object):
        per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return value_lib.regroup((v,), wrap_class=value_lib.Mirrored)
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    if experimental_hints is None:
      experimental_hints = collective_util.Hints()
@@ -309,7 +311,7 @@ class CrossDeviceOps(object):
        value_destination_pairs) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          value_lib.regroup(v.values, wrap_class=value_lib.Mirrored)
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

@@ -510,7 +512,8 @@ def _ungroup_and_make_mirrored(grouped_reduced,
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [value_lib.regroup(v, wrap_class=value_lib.Mirrored) for v in index]
  return [distribute_utils.regroup(
      v, wrap_class=value_lib.Mirrored) for v in index]


class _ConcatAndSplitPacker(object):
@@ -1000,7 +1003,7 @@ class CollectiveAllReduce(CrossDeviceOps):
          # TODO(josh11b): Once we add support for model parallelism, get the
          # copy from the corresponding replica instead of the primary.
          index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  experimental_hints):
@@ -1104,7 +1107,8 @@ class CollectiveAllReduce(CrossDeviceOps):
        for i, v in enumerate(value):
          with ops.device(v.device):
            value[i] = v / num_replicas
      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
      mirrored.append(distribute_utils.regroup(value,
                                               wrap_class=value_lib.Mirrored))
    return mirrored

  def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
@@ -1140,7 +1144,8 @@ class CollectiveAllReduce(CrossDeviceOps):
        for i, v in enumerate(value):
          with ops.device(v.device):
            value[i].values = value[i].values / num_replicas
      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
      mirrored.append(distribute_utils.regroup(value,
                                               wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
@@ -32,6 +32,7 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
@@ -73,7 +74,7 @@ def _make_per_replica(values, devices, regroup=False):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index.append(placed_v)
  return value_lib.regroup(index)
  return distribute_utils.regroup(index)


# pylint: disable=g-doc-args,g-doc-return-or-yield
@@ -88,7 +89,7 @@ def _fake_mirrored(value, devices):
  for d in devices:
    with ops.device(d):
      values.append(array_ops.identity(value))
  return value_lib.regroup(
  return distribute_utils.regroup(
      values,
      wrap_class=value_lib.Mirrored)

@@ -105,7 +106,7 @@ def _make_indexed_slices(values, indices, dense_shape, device):
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  values = [_make_indexed_slices(values, indices, dense_shape, d)
            for d in devices]
  return value_lib.regroup(
  return distribute_utils.regroup(
      values,
      wrap_class=value_lib.Mirrored)
tensorflow/python/distribute/distribute_utils.py (new file, 322 lines)
@@ -0,0 +1,322 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest


def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.

  Args:
    values: Values to regroup
    wrap_class: Class that `values` will be wrapped in.
    always_wrap: Always wrap the `values` in `wrap_class` even if the values
      are the same except for DistributedVariable.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(type(v0), "_make")
      return type(v0)._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, dict):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
                                            (set(v0keys), set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values), wrap_class)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap` is
  #   true.
  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container
    return distributed_container
    # pylint: enable=protected-access

  return wrap_class(values)


def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""

  def _get(x):
    # `DistributedValues` would be sliced according to replica unless it is a
    # `DistributedVariable` because `DistributedVariable` can be handled
    # directly in the replica context.
    if (isinstance(x, values_lib.DistributedVariable) or
        not isinstance(x, values_lib.DistributedValues)):
      return x
    else:
      return x.values[replica_id]

  return nest.map_structure(_get, structured)


def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica."""

  def _get_mirrored(x):
    if isinstance(x, values_lib.DistributedValues):
      if not isinstance(x, values_lib.Mirrored):
        raise TypeError(
            "Expected value to be mirrored across replicas: %s in %s." %
            (x, structured))
      return x.values[replica_id]
    else:
      return x

  return nest.map_structure(_get_mirrored, structured)


def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if not all(tensor_util.is_tensor(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)


def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, values_lib.DistributedVariable)):
    container = val._distributed_container  # pylint: disable=protected-access
    if container is not None:
      return container
  return val


def is_distributed_variable(v):
  """Determine if a variable is ds variable or TPU mirrored variable."""
  return isinstance(v, values_lib.DistributedVariable)


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, values_lib.DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


# Variable creation function for sync strategies.
def create_mirrored_variable(  # pylint: disable=missing-docstring
    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.ON_WRITE)

  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with `Mirrored` "
        "distribution strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  elif synchronization == vs.VariableSynchronization.ON_READ:
    is_sync_on_read = True
  elif synchronization in (vs.VariableSynchronization.ON_WRITE,
                           vs.VariableSynchronization.AUTO):
    # `AUTO` synchronization defaults to `ON_WRITE`.
    is_sync_on_read = False
  else:
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))

  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
    result = var_cls(strategy, value_list, aggregation)
  # Install the created DistributedVariable as _distributed_container property
  # of the underlying variables, to make it easy to map back to the container.
  for v in result.values:
    # Hold a strong reference to avoid the container from being GC-ed. After
    # v = v.assign(), the user code may no longer hold references to the
    # original container, since v.assign() returns a new DistributedVariable.
    v._distributed_container = result  # pylint: disable=protected-access

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
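The relocated helpers keep their previous behaviour; only the module changes. A minimal sketch of regroup/select_replica on a small nest (the tensor values and the two-replica setup are illustrative):

from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.framework import constant_op

# One structure per replica (two replicas in this sketch).
replica_0 = {"x": constant_op.constant(1.0), "y": constant_op.constant(10.0)}
replica_1 = {"x": constant_op.constant(2.0), "y": constant_op.constant(20.0)}

# regroup turns "one structure per replica" into "one structure of
# PerReplica values", preserving the dict layout.
per_replica = distribute_utils.regroup((replica_0, replica_1))

# select_replica specializes the whole nest back to a single replica.
assert distribute_utils.select_replica(1, per_replica)["x"] is replica_1["x"]

# wrap_class switches the wrapper type, e.g. Mirrored for already-reduced
# values, which is how the cross_device_ops call sites above use it.
a, b = constant_op.constant(3.0), constant_op.constant(3.0)
mirrored = distribute_utils.regroup((a, b), wrap_class=value_lib.Mirrored)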
@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
@@ -316,7 +317,7 @@ class DistributedIteratorBase(distribute_types.Iterator):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(
            self._iterators[i].get_next_as_list_static_shapes(new_name))
      return values.regroup(replicas)
      return distribute_utils.regroup(replicas)

    out_of_range_replicas = []
    def out_of_range_fn(worker_index, device):
@@ -349,7 +350,7 @@ class DistributedIteratorBase(distribute_types.Iterator):
          results.append(result)
      replicas = results

    return values.regroup(replicas)
    return distribute_utils.regroup(replicas)


class DistributedIteratorV1(DistributedIteratorBase):
@@ -577,7 +578,7 @@ class _IterableInput(distribute_types.Iterable):
      else:
        raise ValueError("Dataset iteration within a tf.function is"
                         " not supported for multiple workers.")
      state = reduce_fn(state, values.regroup(data))
      state = reduce_fn(state, distribute_utils.regroup(data))
      has_data, data = _get_next_as_optional(iterator, self._strategy)
      return has_data, data, state
@@ -34,13 +34,13 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
@@ -188,8 +188,8 @@ class DistributedIteratorTestBase(test.TestCase):
      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
            [distribute_utils.select_replica(r, next_element)
             for r in range(len(devices))])
        self.assertEqual(len(expected_value), len(computed_value))
        for i in range(len(expected_value)):
          self.assertAllEqual(expected_value[i], computed_value[i])
@@ -197,8 +197,8 @@ class DistributedIteratorTestBase(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
            [distribute_utils.select_replica(r, next_element)
             for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
      if not ops.executing_eagerly_outside_functions():
@@ -212,8 +212,8 @@ class DistributedIteratorTestBase(test.TestCase):
        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate(
              [values.select_replica(r,
                                     next_element) for r in range(len(devices))])
              [distribute_utils.select_replica(r, next_element)
               for r in range(len(devices))])
          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])
@@ -222,7 +222,8 @@ class DistributedIteratorTestBase(test.TestCase):
      actual_values = []
      for x in dataset:
        computed_value = self.evaluate(
            [values.select_replica(r, x) for r in range(len(devices))])
            [distribute_utils.select_replica(r, x)
             for r in range(len(devices))])
        actual_values.append(computed_value)
      for i, expected_value in enumerate(expected_values):
        self.assertEqual(len(expected_value), len(actual_values[i]))
@@ -699,24 +700,29 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        values.select_replica(0, per_replica_batch["dense"]),
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        values.select_replica(1, per_replica_batch["dense"]),
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
          values.select_replica(i, per_replica_batch["ragged"]).values, 6)
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          values.select_replica(i, per_replica_batch["ragged"]).to_tensor(),
          values.select_replica(i, per_replica_batch["dense"]))
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          values.select_replica(i, per_replica_batch["sparse"]).indices, 6)
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              values.select_replica(i, per_replica_batch["sparse"])),
          values.select_replica(i, per_replica_batch["dense"]))
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""
@@ -22,6 +22,7 @@ import threading

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import values
@@ -123,16 +124,18 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
          # use a collective op. This is a particular concern with eager
          # execution.
          with context.execution_mode(context.ASYNC):
            return_values.append(fn(*values.select_replica(index, args),
                                    **values.select_replica(index, kwargs)))
            return_values.append(
                fn(*distribute_utils.select_replica(index, args),
                   **distribute_utils.select_replica(index, kwargs)))
        else:
          return_values.append(fn(*values.select_replica(index, args),
                                  **values.select_replica(index, kwargs)))
          return_values.append(
              fn(*distribute_utils.select_replica(index, args),
                 **distribute_utils.select_replica(index, kwargs)))
    finally:
      _replica_index.graph_outside_run = None
      _replica_index.current = None

    return values.regroup(return_values)
    return distribute_utils.regroup(return_values)

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
@@ -27,8 +27,8 @@ from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@@ -160,8 +160,8 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
        shared_variable_store, index)
    t = _MirroredReplicaThread(
        distribution, coord, index, devices, variable_creator_fn, fn,
        values.select_replica(index, args),
        values.select_replica(index, kwargs))
        distribute_utils.select_replica(index, args),
        distribute_utils.select_replica(index, kwargs))
    threads.append(t)

  for t in threads:
@@ -209,8 +209,10 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
          raise RuntimeError("Some replicas made a different number of "
                             "replica_context().merge_call() calls.")
        # get_replica_context().merge_call() case
        merge_args = values.regroup(tuple(t.merge_args for t in threads))
        merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads))
        merge_args = distribute_utils.regroup(
            tuple(t.merge_args for t in threads))
        merge_kwargs = distribute_utils.regroup(
            tuple(t.merge_kwargs for t in threads))
        # We capture the name_scope of the MRT when we call merge_fn
        # to ensure that if we have opened a name scope in the MRT,
        # it will be respected when executing the merge function. We only
@@ -228,13 +230,13 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
          merge_result = threads[0].merge_fn(distribution, *merge_args,
                                             **merge_kwargs)
        for r, t in enumerate(threads):
          t.merge_result = values.select_replica(r, merge_result)
          t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup(tuple(t.main_result for t in threads))
  return distribute_utils.regroup(tuple(t.main_result for t in threads))


class _MirroredReplicaThread(threading.Thread):
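mirrored_run peels each replica's slice of args/kwargs with select_replica before calling the user fn on that replica's thread, then regroups the per-replica results. A minimal single-threaded sketch of that pattern (step_fn and its inputs are illustrative):

from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.framework import constant_op


def step_fn(features, scale):
  # Sees plain per-replica tensors once select_replica has been applied.
  return features * scale


# A PerReplica input such as a distributed iterator would produce (2 replicas).
features = value_lib.PerReplica(
    (constant_op.constant([1.0, 2.0]), constant_op.constant([3.0, 4.0])))
args = (features,)
kwargs = {"scale": constant_op.constant(0.5)}  # regular value, passed through

results = []
for index in range(2):
  # Regular tensors pass through unchanged; PerReplica values are sliced.
  results.append(
      step_fn(*distribute_utils.select_replica(index, args),
              **distribute_utils.select_replica(index, kwargs)))

# Regroup the per-replica results into a single PerReplica value, mirroring
# the `return distribute_utils.regroup(...)` in the hunk above.
per_replica_result = distribute_utils.regroup(tuple(results))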
@@ -23,6 +23,7 @@ import copy

from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
@@ -445,13 +446,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
          value_list.append(v)
      return value_list

    return values.create_mirrored_variable(self._container_strategy(),
                                           _real_mirrored_creator,
                                           values.MirroredVariable,
                                           values.SyncOnReadVariable, **kwargs)
    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        values.MirroredVariable, values.SyncOnReadVariable, **kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate_distributed_variable(colocate_with_variable, self)
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
@@ -507,7 +508,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return values.regroup(per_replica_values, always_wrap=True)
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
@@ -557,7 +558,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = values.regroup(output)
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]
@@ -580,8 +581,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    return self._get_cross_device_ops().broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
@@ -643,10 +644,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
          distribute_lib.UpdateContext(i), \
          ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(fn(v,
                          *values.select_replica_mirrored(i, args),
                          **values.select_replica_mirrored(i, kwargs)))
    return values.update_regroup(self, updates, group)
        updates.append(
            fn(v, *distribute_utils.select_replica_mirrored(i, args),
               **distribute_utils.select_replica_mirrored(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
@@ -655,9 +656,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(fn(*values.select_replica_mirrored(i, args),
                          **values.select_replica_mirrored(i, kwargs)))
    return values.update_regroup(self, updates, group)
        updates.append(
            fn(*distribute_utils.select_replica_mirrored(i, args),
               **distribute_utils.select_replica_mirrored(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
@@ -672,7 +674,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
      return (val,)

  def value_container(self, val):
    return values.value_container(val)
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
@@ -31,6 +31,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
@@ -1026,8 +1027,9 @@ class MirroredStrategyDefunTest(test.TestCase):
      result = distribution.extended.call_for_each_replica(
          model_fn, args=[mock_model] + inputs)
      for r in range(len(devices)):
        device_result = values.select_replica(r, result)
        device_expected_result = values.select_replica(r, expected_result)
        device_result = distribute_utils.select_replica(r, result)
        device_expected_result = distribute_utils.select_replica(
            r, expected_result)
        self.assertAllClose(device_expected_result,
                            self.evaluate(device_result))
@@ -20,9 +20,9 @@ from __future__ import print_function

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -233,7 +233,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
    return super(OneDeviceStrategy, self).scope()


@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=missing-docstring
@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = OneDeviceStrategy.__doc__.replace(
@@ -272,7 +272,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
    return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
@@ -24,6 +24,7 @@ import copy

from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
@@ -334,7 +335,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
        compute_devices, self._variable_device)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset):
    return input_lib.get_distributed_dataset(
@@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
@@ -34,7 +35,6 @@ from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -534,8 +534,8 @@ class ParameterServerStrategyTestBase(

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([values.select_replica(r, next_element)
                                   for r in range(len(devices))])
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
@@ -543,7 +543,7 @@ class ParameterServerStrategyTestBase(

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        sess.run([values.select_replica(r, next_element)
        sess.run([distribute_utils.select_replica(r, next_element)
                  for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
@@ -552,8 +552,8 @@ class ParameterServerStrategyTestBase(

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([values.select_replica(r, next_element)
                                   for r in range(len(devices))])
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
@@ -27,9 +27,9 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@@ -348,7 +348,8 @@ class DistributionTestBase(test.TestCase):
      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r, next_element) for r in range(len(devices))])
            [distribute_utils.select_replica(r, next_element) for r in
             range(len(devices))])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
@@ -357,7 +358,8 @@ class DistributionTestBase(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        evaluate(
            [values.select_replica(r, next_element) for r in range(len(devices))])
            [distribute_utils.select_replica(r, next_element) for r in
             range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
      if test_reinitialize:
@@ -366,7 +368,8 @@ class DistributionTestBase(test.TestCase):
        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              values.select_replica(r, next_element) for r in range(len(devices))
              distribute_utils.select_replica(r, next_element) for r in
              range(len(devices))
          ])
          if ignore_order:
            self.assertCountEqual(expected_value, computed_value)
@@ -32,6 +32,7 @@ from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
@@ -363,7 +364,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    return self._input_workers_obj

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
@@ -423,7 +424,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return values.regroup(per_replica_values, always_wrap=True)
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
@@ -462,7 +463,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        select_replica = lambda x: distribute_utils.select_replica(  # pylint: disable=g-long-lambda
            replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

@@ -648,11 +650,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
          value_list.append(v)
      return value_list

    return values.create_mirrored_variable(self._container_strategy(),
                                           _real_mirrored_creator,
                                           tpu_values.TPUMirroredVariable,
                                           tpu_values.TPUSyncOnReadVariable,
                                           **kwargs)
    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        tpu_values.TPUMirroredVariable, tpu_values.TPUSyncOnReadVariable,
        **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
    if (isinstance(value, values.DistributedValues) or
@@ -722,10 +723,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
          distribute_lib.UpdateContext(i), \
          ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(fn(v,
                          *values.select_replica_mirrored(i, args),
                          **values.select_replica_mirrored(i, kwargs)))
    return values.update_regroup(self, updates, group)
        updates.append(
            fn(v, *distribute_utils.select_replica_mirrored(i, args),
               **distribute_utils.select_replica_mirrored(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, var):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
@@ -893,8 +894,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    for i in range(strategy.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
           values.select_replica(i, args),
           values.select_replica(i, kwargs)])
           distribute_utils.select_replica(i, args),
           distribute_utils.select_replica(i, kwargs)])

    # Construct and pass `maximum_shapes` so that we could support dynamic
    # shapes using dynamic padder.
@@ -941,7 +942,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]
    return values.regroup(replicate_outputs)
    return distribute_utils.regroup(replicate_outputs)

  if context.executing_eagerly():
    tpu_function = def_function.function(tpu_function)
@ -26,10 +26,8 @@ from tensorflow.python.distribute import packed_distributed_variable as packed
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -40,7 +38,6 @@ from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.types import core
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -1023,88 +1020,6 @@ class SyncOnReadVariable(DistributedVariable):
|
||||
self._get(), dtype=dtype, name=name, as_ref=as_ref)


# Variable creation function for sync strategies.
def create_mirrored_variable(  # pylint: disable=missing-docstring
    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.ON_WRITE)

  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with `Mirrored` "
        "distribution strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  elif synchronization == vs.VariableSynchronization.ON_READ:
    is_sync_on_read = True
  elif synchronization in (vs.VariableSynchronization.ON_WRITE,
                           vs.VariableSynchronization.AUTO):
    # `AUTO` synchronization defaults to `ON_WRITE`.
    is_sync_on_read = False
  else:
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))

  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
    result = var_cls(strategy, value_list, aggregation)
    # Install the created DistributedVariable as the _distributed_container
    # property of the underlying variables, to make it easy to map back to the
    # container.
    for v in result.values:
      # Hold a strong reference to keep the container from being GC-ed. After
      # v = v.assign(), the user code may no longer hold references to the
      # original container, since v.assign() returns a new DistributedVariable.
      v._distributed_container = result  # pylint: disable=protected-access

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
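
# A minimal illustrative sketch (an editor's example, not code from this
# commit): sync strategies such as tf.distribute.MirroredStrategy route
# variable creation through a creator like create_mirrored_variable() above,
# so variables built under strategy.scope() come back as distributed wrappers.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # all local GPUs, else one device
with strategy.scope():
  # AUTO synchronization defaults to ON_WRITE, yielding a MirroredVariable.
  v = tf.Variable(1.0)
  # ON_READ synchronization yields a SyncOnReadVariable instead.
  s = tf.Variable(0.0,
                  synchronization=tf.VariableSynchronization.ON_READ,
                  aggregation=tf.VariableAggregation.MEAN)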


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
# MirroredVariables
@ -1133,214 +1048,3 @@ def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):

ops.register_tensor_conversion_function(SyncOnReadVariable,
                                        _tensor_conversion_sync_on_read)


def regroup(values, wrap_class=PerReplica, always_wrap=False):
  """Makes a nest of per-replica values into a nest of PerReplica/Mirrored values.

  Args:
    values: Values to regroup.
    wrap_class: Class that `values` will be wrapped in.
    always_wrap: Always wrap the `values` in `wrap_class` even if the values
        are the same, except for DistributedVariable.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(type(v0), "_make")
      return type(v0)._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, dict):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
                                            (set(v0keys), set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values), wrap_class)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap`
  #   is true.
  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container
    return distributed_container
    # pylint: enable=protected-access

  return wrap_class(values)
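
# A minimal illustrative sketch (an editor's example, not code from this
# commit): regroup() merges the per-replica copies of a nested structure into
# one structure of wrapped values, and select_replica() below undoes it. The
# behaviour matches RegroupAndSelectDeviceTest further down; the tuples here
# are made up for illustration.
per_replica_results = (("features_0", "labels_0"),  # replica 0
                       ("features_1", "labels_1"))  # replica 1
merged = regroup(per_replica_results)
# `merged` is a single (features, labels) tuple whose leaves are PerReplica
# objects, e.g. merged[0].values == ("features_0", "features_1").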


def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""

  def _get(x):
    # `DistributedValues` would be sliced according to replica unless it is a
    # `DistributedVariable` because `DistributedVariable` can be handled
    # directly in the replica context.
    if (isinstance(x, DistributedVariable) or
        not isinstance(x, DistributedValues)):
      return x
    else:
      return x.values[replica_id]

  return nest.map_structure(_get, structured)
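
# Editor's example (not code from this commit), continuing the made-up
# per-replica sketch above: select_replica() undoes regroup() for a single
# replica, so a merged nest can be specialized before being dispatched to
# each device.
assert select_replica(0, merged) == ("features_0", "labels_0")
assert select_replica(1, merged) == ("features_1", "labels_1")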


def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica."""

  def _get_mirrored(x):
    if isinstance(x, DistributedValues):
      if not isinstance(x, Mirrored):
        raise TypeError(
            "Expected value to be mirrored across replicas: %s in %s." %
            (x, structured))
      return x.values[replica_id]
    else:
      return x

  return nest.map_structure(_get_mirrored, structured)
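
# Editor's example (not code from this commit), again using the made-up
# per-replica sketch above: unlike select_replica(), select_replica_mirrored()
# insists that every distributed value in the nest is Mirrored, which is how
# callers assert that a structure is replica-invariant.
mirrored = regroup(per_replica_results, Mirrored)
assert select_replica_mirrored(1, mirrored) == ("features_1", "labels_1")
try:
  select_replica_mirrored(0, merged)  # leaves are PerReplica, not Mirrored
except TypeError:
  pass  # expected: "Expected value to be mirrored across replicas"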


def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if not all(tensor_util.is_tensor(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)


def value_container(val):
  """Returns the container that this per-replica `val` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    A container that `val` belongs to.
    If `val` does not belong to any container (including the case of the
    container having been destroyed), returns the value itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, DistributedVariable)):
    container = val._distributed_container  # pylint: disable=protected-access
    if container is not None:
      return container
  return val
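
# Editor's illustrative sketch (not code from this commit): value_container()
# maps a per-replica component back to the DistributedVariable that owns it,
# which is what the _distributed_container bookkeeping above enables. The
# strategy and variable below exist only for illustration.
import tensorflow as tf

_strategy = tf.distribute.MirroredStrategy()  # one device is enough here
with _strategy.scope():
  _mv = tf.Variable(3.0)        # wrapped as a MirroredVariable
_component = _mv.values[0]      # the underlying per-device variable
assert value_container(_component) is _mv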


def is_distributed_variable(v):
  """Determine if a variable is a distributed variable or a TPU mirrored variable."""
  return isinstance(v, DistributedVariable)


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)

@ -30,6 +30,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import strategy_combinations
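
# Editor's note (not code from this commit): the call-site updates in the test
# hunks below all follow the same pattern. Helpers such as regroup(),
# select_replica(), select_replica_mirrored() and is_distributed_variable()
# now live in distribute_utils, so callers switch modules (names here are
# illustrative):
#
#   # before
#   from tensorflow.python.distribute import values
#   per_replica = values.select_replica(replica_id, distributed_values)
#
#   # after
#   from tensorflow.python.distribute import distribute_utils
#   per_replica = distribute_utils.select_replica(replica_id, distributed_values)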

@ -157,7 +158,7 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          values.select_replica(i, distributed_values),
          distribute_utils.select_replica(i, distributed_values),
          (1. * i, 2. * i, 3. * i))

  @combinations.generate(
@ -391,7 +392,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
      self.assertEqual(exp, result.values[i])

  def testNested(self):
    result = values.regroup((_nested_value("1"), _nested_value("2")))
    result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")))
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 3)
    self._is_per_replica(result[0], ["a1", "a2"])
@ -409,20 +410,20 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):

    # Also test that we can undo the merge using select_replica()
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))
                     distribute_utils.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     values.select_replica(1, result))
                     distribute_utils.select_replica(1, result))
    # select_replica_mirrored() should fail due to non-mirrored values
    with self.assertRaises(TypeError):
      values.select_replica_mirrored(0, result)
      distribute_utils.select_replica_mirrored(0, result)
    with self.assertRaises(TypeError):
      values.select_replica_mirrored(1, result)
      distribute_utils.select_replica_mirrored(1, result)

  def testRegroupKeepsDictBasedClass(self):
    class DictBasedClass(dict):
      """Dummy class inherited from a dict."""

    result = values.regroup(
    result = distribute_utils.regroup(
        (DictBasedClass(a="a1", b="b1"), DictBasedClass(a="a2", b="b2")))
    self.assertIsInstance(result, DictBasedClass)
    self._is_per_replica(result["a"], ["a1", "a2"])
@ -431,8 +432,8 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
  def testWrapClass(self):
    # Normally a mirrored value would be the same across devices, but
    # for a test it is convenient to be able to tell the values apart.
    result = values.regroup((_nested_value("1"), _nested_value("2")),
                            values.Mirrored)
    result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")),
                                      values.Mirrored)
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 3)
    self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
@ -450,17 +451,17 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):

    # Also test that we can undo the merge using select_replica()
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))
                     distribute_utils.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     values.select_replica(1, result))
                     distribute_utils.select_replica(1, result))
    # Values are marked as mirrored, so select_replica_mirrored() is allowed.
    self.assertEqual(_nested_value("1"),
                     values.select_replica_mirrored(0, result))
                     distribute_utils.select_replica_mirrored(0, result))
    self.assertEqual(_nested_value("2"),
                     values.select_replica_mirrored(1, result))
                     distribute_utils.select_replica_mirrored(1, result))

  def testWrapAListOfTwoTuples(self):
    result = values.regroup([("1", "2"), ("3", "4")])
    result = distribute_utils.regroup([("1", "2"), ("3", "4")])
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 2)
    self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
@ -478,34 +479,36 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
    with distribution.scope():
      v = variable_scope.variable(
          1., aggregation=variable_scope.VariableAggregation.SUM)
    self.assertTrue(values.is_distributed_variable(v))
    self.assertTrue(values.is_distributed_variable(values.regroup(v.values)))
    self.assertTrue(distribute_utils.is_distributed_variable(v))
    self.assertTrue(distribute_utils.is_distributed_variable(
        distribute_utils.regroup(v.values)))

  def testSameId(self):
    foo = object()
    result = values.regroup((("a", foo), ("b", foo)))
    result = distribute_utils.regroup((("a", foo), ("b", foo)))
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 2)
    self._is_per_replica(result[0], ["a", "b"])
    self.assertIs(foo, result[1])

    # Test select_replica(), should undo the merge done by regroup().
    result_0 = values.select_replica(0, result)
    result_0 = distribute_utils.select_replica(0, result)
    self.assertIsInstance(result_0, tuple)
    self.assertLen(result_0, 2)
    self.assertEqual("a", result_0[0])
    self.assertIs(foo, result_0[1])
    result_1 = values.select_replica(1, result)
    result_1 = distribute_utils.select_replica(1, result)
    self.assertIsInstance(result_1, tuple)
    self.assertLen(result_1, 2)
    self.assertEqual("b", result_1[0])
    self.assertIs(foo, result_1[1])

  def testOneDevice(self):
    result = values.regroup((_nested_value("1"),))
    result = distribute_utils.regroup((_nested_value("1"),))
    # On one device regroup() and select_replica() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"), values.select_replica(0, result))
    self.assertEqual(_nested_value("1"),
                     distribute_utils.select_replica(0, result))

  def testNamedTuple(self):

@ -533,7 +536,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
          train_op=array_ops.identity(constant_op.constant(device_id)))
      created_estimator_specs.append(spec)

    merged_estimator_spec = values.regroup(created_estimator_specs)
    merged_estimator_spec = distribute_utils.regroup(created_estimator_specs)

    self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
    self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
@ -550,8 +553,8 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
                          Scaffold)
      # Also test that we can undo the merge using select_replica()
      self.assertEqual(created_estimator_specs[device_id],
                       values.select_replica(device_id,
                                             merged_estimator_spec))
                       distribute_utils.select_replica(
                           device_id, merged_estimator_spec))


  @combinations.generate(
@ -630,7 +633,7 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, values.select_replica(0, v))
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    if isinstance(distribution.extended,
@ -973,8 +976,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
      return v.assign(ctx.replica_id_in_sync_group)

    # disallow assign() with distributed value in replica context.
    with self.assertRaisesRegexp(ValueError,
                                 "Cannot update non-float variables"):
    with self.assertRaisesRegex(ValueError,
                                "Cannot update non-float variables"):
      self.evaluate(
          distribution.experimental_local_results(
              distribution.run(assign)))
@ -2174,11 +2177,11 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
        condition, lambda: per_replica_1, lambda: per_replica_2)


def _make_index_slices(values, indices, dense_shape=None):
def _make_index_slices(vals, indices, dense_shape=None):
  if dense_shape:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)
      array_ops.identity(vals), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":

@ -363,6 +363,7 @@ py_library(
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:distribute_utils",
        "//tensorflow/python/distribute:values",
        "//tensorflow/python/training/tracking",
        "//tensorflow/python/training/tracking:base",

@ -22,8 +22,8 @@ import functools
import os

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
@ -87,10 +87,11 @@ class _WrapperFunction(function.ConcreteFunction):
  def _call_flat(self, args, captured_inputs, cancellation_manager=None):

    def get_in_replica_handle(x):
      return x.handle if ds_values.is_distributed_variable(x) else x
      return x.handle if distribute_utils.is_distributed_variable(x) else x

    def get_cross_replica_handle(x):
      return _unused_handle() if ds_values.is_distributed_variable(x) else x
      return _unused_handle() if distribute_utils.is_distributed_variable(x) \
          else x

    if ds_context.get_replica_context() is not None:  # in-replica context
      captured_inputs = list(map(get_in_replica_handle, captured_inputs))
@ -201,7 +202,7 @@ class Loader(object):
    if bound_inputs:
      for bound_input, internal_capture in zip(
          bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
        if ds_values.is_distributed_variable(bound_input):
        if distribute_utils.is_distributed_variable(bound_input):
          concrete_function.graph.capture_distributed_variable(
              bound_input, internal_capture)
        else:
@ -227,7 +228,7 @@ class Loader(object):
    """Resolves a node id into a tensor to be captured for a function."""
    with ops.init_scope():
      obj = self._nodes[node_id]
      if ds_values.is_distributed_variable(obj):
      if distribute_utils.is_distributed_variable(obj):
        return obj
      elif resource_variable_ops.is_resource_variable(obj):
        return obj.handle

@ -26,7 +26,7 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.distribute import distribute_utils as ds_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@ -273,13 +273,13 @@ class _SaveableView(object):
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif (ds_values.is_distributed_variable(obj) or
      elif (ds_utils.is_distributed_variable(obj) or
            resource_variable_ops.is_resource_variable(obj)):
        obj_to_copy = obj._primary if ds_values.is_distributed_variable(  # pylint: disable=protected-access
        obj_to_copy = obj._primary if ds_utils.is_distributed_variable(  # pylint: disable=protected-access
            obj) else obj
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(
            obj_to_copy)
        if ds_values.is_distributed_variable(obj):
        if ds_utils.is_distributed_variable(obj):
          self.captured_tensor_node_ids[obj] = node_id
          for v in obj.values:
            object_map[v] = new_variable

@ -25,10 +25,10 @@ from absl import logging
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values as tf_values
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -1131,8 +1131,10 @@ class TPUEmbedding(tracking.AutoTrackable):
    # in the same (standard) order as self._strategy.extended.worker_devices.
    enqueue_ops = []
    for replica_id in range(self._strategy.num_replicas_in_sync):
      replica_inputs = tf_values.select_replica(replica_id, flat_inputs)
      replica_weights = tf_values.select_replica(replica_id, flat_weights)
      replica_inputs = distribute_utils.select_replica(replica_id,
                                                       flat_inputs)
      replica_weights = distribute_utils.select_replica(replica_id,
                                                        flat_weights)
      tpu_device = self._strategy.extended.worker_devices[replica_id]
      # TPU device strings are like /job:worker/replica:0/task:0/device:TPU:0;
      # the device ordinal is the last number.