diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 9ce72479f82..9673d57a4aa 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -286,6 +286,36 @@ py_library(
     ],
 )
 
+py_library(
+    name = "distribute_utils",
+    srcs = ["distribute_utils.py"],
+    deps = [
+        ":device_util",
+        ":distribute_lib",
+        ":reduce_util",
+        ":shared_variable_creator",
+        ":values",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:config",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:device",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:pywrap_tfe",
+        "//tensorflow/python:summary_ops_v2",
+        "//tensorflow/python:tensor_util",
+        "//tensorflow/python:tf_export",
+        "//tensorflow/python:training",
+        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/autograph/core",
+        "//tensorflow/python/autograph/impl",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:def_function",
+    ],
+)
+
 py_library(
     name = "mirrored_strategy",
     srcs = ["mirrored_strategy.py"],
@@ -293,6 +323,7 @@ py_library(
         ":cross_device_ops",
         ":device_util",
         ":distribute_lib",
+        ":distribute_utils",
         ":input_lib",
         ":mirrored_run",
         ":multi_worker_util",
@@ -320,6 +351,7 @@ py_library(
         ":cross_device_ops",
         ":device_util",
         ":distribute_lib",
+        ":distribute_utils",
        ":input_lib",
         ":mirrored_run",
         ":multi_worker_util",
@@ -508,6 +540,7 @@ py_library(
     deps = [
         ":device_util",
         ":distribute_lib",
+        ":distribute_utils",
         ":input_ops",
         ":values",
         "//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
index ea7a90504d2..63c0db4c3b3 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
@@ -28,11 +28,11 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_test_lib
-from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -362,8 +362,8 @@ class CollectiveAllReduceStrategyTestBase(
 
     for expected_value in expected_values:
       next_element = iterator.get_next()
-      computed_value = sess.run([values.select_replica(r, next_element)
-                                 for r in range(len(devices))])
+      computed_value = sess.run([distribute_utils.select_replica(
+          r, next_element) for r in range(len(devices))])
       if ignore_order:
         self.assertCountEqual(list(expected_value), list(computed_value))
       else:
@@ -371,7 +371,7 @@ class CollectiveAllReduceStrategyTestBase(
 
     with self.assertRaises(errors.OutOfRangeError):
       next_element = iterator.get_next()
-      sess.run([values.select_replica(r, next_element)
+      sess.run([distribute_utils.select_replica(r, next_element)
                 for r in range(len(devices))])
 
     # After re-initializing the iterator, should be able to iterate again.
@@ -380,8 +380,9 @@ class CollectiveAllReduceStrategyTestBase(
 
     for expected_value in expected_values:
       next_element = iterator.get_next()
-      computed_value = sess.run([values.select_replica(r, next_element)
-                                 for r in range(len(devices))])
+      computed_value = sess.run([
+          distribute_utils.select_replica(r, next_element)
+          for r in range(len(devices))])
       if ignore_order:
         self.assertCountEqual(list(expected_value), list(computed_value))
       else:
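Note on the pattern above: `distribute_utils.select_replica` is a drop-in replacement for the old `values.select_replica`, so these test loops keep fetching one replica's slice of `iterator.get_next()` per device. A minimal sketch of what that call does, assuming a hypothetical two-device `next_element` (the tensor and device names here are illustrative, not from the patch):

    # next_element is a nest whose distributed leaves are PerReplica values,
    # e.g. PerReplica((batch_on_dev0, batch_on_dev1)).
    # select_replica(r, ...) maps over the nest and picks leaf.values[r],
    # leaving non-distributed leaves untouched.
    fetched = [distribute_utils.select_replica(r, next_element)
               for r in range(len(devices))]
    computed_value = sess.run(fetched)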
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index aaca66833e0..b88357e0ea6 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.client import device_lib
 from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import ps_values
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import tpu_values
@@ -187,7 +188,8 @@ def simple_broadcast(value, destinations, always_mirrored=False):
   for d in devices:
     value_updates.append(
         cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
-  return value_lib.regroup(value_updates, wrap_class=value_lib.Mirrored)
+  return distribute_utils.regroup(value_updates,
+                                  wrap_class=value_lib.Mirrored)
 
 
 def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
@@ -259,7 +261,7 @@ class CrossDeviceOps(object):
         per_replica_value, destinations):
       with ops.device(per_replica_value.values[0].device):
         v = array_ops.identity(per_replica_value.values[0])
-      return value_lib.regroup((v,), wrap_class=value_lib.Mirrored)
+      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
 
     if experimental_hints is None:
       experimental_hints = collective_util.Hints()
@@ -309,7 +311,7 @@ class CrossDeviceOps(object):
         value_destination_pairs) and len(
             value_destination_pairs[0][0].values) == 1:
       return [
-          value_lib.regroup(v.values, wrap_class=value_lib.Mirrored)
+          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
           for v, _ in value_destination_pairs
       ]
 
@@ -510,7 +512,8 @@ def _ungroup_and_make_mirrored(grouped_reduced,
         index[i].append(v / num_replicas)
       else:
         index[i].append(v)
-  return [value_lib.regroup(v, wrap_class=value_lib.Mirrored) for v in index]
+  return [distribute_utils.regroup(
+      v, wrap_class=value_lib.Mirrored) for v in index]
 
 
 class _ConcatAndSplitPacker(object):
@@ -1000,7 +1003,7 @@ class CollectiveAllReduce(CrossDeviceOps):
         # TODO(josh11b): Once we add support for model parallelism, get the
         # copy from the corresponding replica instead of the primary.
         index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
-    return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
+    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
 
   def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                   experimental_hints):
@@ -1104,7 +1107,8 @@ class CollectiveAllReduce(CrossDeviceOps):
         for i, v in enumerate(value):
           with ops.device(v.device):
             value[i] = v / num_replicas
-      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
+      mirrored.append(distribute_utils.regroup(value,
+                                               wrap_class=value_lib.Mirrored))
     return mirrored
 
   def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
@@ -1140,7 +1144,8 @@ class CollectiveAllReduce(CrossDeviceOps):
         for i, v in enumerate(value):
           with ops.device(v.device):
             value[i].values = value[i].values / num_replicas
-      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
+      mirrored.append(distribute_utils.regroup(value,
+                                               wrap_class=value_lib.Mirrored))
     return mirrored
 
   def __deepcopy__(self, memo):
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 09de4306199..9554de41a6e 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
@@ -73,7 +74,7 @@ def _make_per_replica(values, devices, regroup=False):
     with ops.device(d):
       placed_v = array_ops.identity(v)
     index.append(placed_v)
-  return value_lib.regroup(index)
+  return distribute_utils.regroup(index)
 
 
 # pylint: disable=g-doc-args,g-doc-return-or-yield
@@ -88,7 +89,7 @@ def _fake_mirrored(value, devices):
   for d in devices:
     with ops.device(d):
       values.append(array_ops.identity(value))
-  return value_lib.regroup(
+  return distribute_utils.regroup(
       values, wrap_class=value_lib.Mirrored)
 
 
@@ -105,7 +106,7 @@ def _make_indexed_slices(values, indices, dense_shape, device):
 def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
   values = [_make_indexed_slices(values, indices, dense_shape, d)
             for d in devices]
-  return value_lib.regroup(
+  return distribute_utils.regroup(
       values, wrap_class=value_lib.Mirrored)
 
 
diff --git a/tensorflow/python/distribute/distribute_utils.py b/tensorflow/python/distribute/distribute_utils.py
new file mode 100644
index 00000000000..ccf19521718
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_utils.py
@@ -0,0 +1,322 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class implementing utilities used by tf.distribute.Strategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.distribute import values as values_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.util import nest
+
+
+def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
+  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.
+
+  Args:
+    values: Values to regroup.
+    wrap_class: Class that `values` will be wrapped in.
+    always_wrap: Always wrap the `values` in `wrap_class` even if the values
+      are the same except for DistributedVariable.
+  Returns:
+    Wrapped `values`.
+  """
+  v0 = values[0]
+
+  if isinstance(v0, list):
+    for v in values[1:]:
+      assert isinstance(v, list)
+      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
+                                 (len(v), len(v0), v, v0))
+    return [
+        regroup(tuple(v[i] for v in values), wrap_class)
+        for i in range(len(v0))
+    ]
+
+  if isinstance(v0, tuple):
+    for v in values[1:]:
+      assert isinstance(v, tuple)
+      assert len(v) == len(v0)
+    regrouped_tuple = tuple(
+        regroup(tuple(v[i] for v in values), wrap_class)
+        for i in range(len(v0)))
+    if hasattr(v0, "_fields"):
+      # This tuple is in fact a namedtuple! Create a new namedtuple instance
+      # and initialize it with the regrouped values:
+      assert hasattr(type(v0), "_make")
+      return type(v0)._make(regrouped_tuple)
+    else:
+      return regrouped_tuple
+
+  if isinstance(v0, dict):
+    v0keys = v0.keys()
+    for v in values[1:]:
+      assert isinstance(v, dict), ("v[0]: %r  v[i]: %r" % (v0, v))
+      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s  v[i].keys: %s" %
+                                            (set(v0keys), set(v.keys())))
+    # Use the actual type in case it is a class inherited from a dict.
+    return type(v0)({
+        key: regroup(tuple(v[key] for v in values), wrap_class)
+        for key in v0keys
+    })
+
+  # If exactly the same object across all devices, return it unwrapped.
+  same_id = True
+  for v in values[1:]:
+    if v is not v0:
+      same_id = False
+      break
+  # Consider three cases where same_id is true:
+  # * If v0 is a DistributedVariable (a MirroredVariable or
+  #   SyncOnReadVariable, and same_id means it is the same across all
+  #   devices), we want to return it. We check DistributedVariable
+  #   specifically since it can look like it has a
+  #   _distributed_container member since its members do.
+  if same_id and isinstance(v0, values_lib.DistributedVariable):
+    return v0
+  # * If v0 is a member of a distributed variable, in which case
+  #   hasattr(v0, "_distributed_container") is true, we want to
+  #   return the DistributedVariable that contains it using the
+  #   _distributed_container logic below. This case can trigger
+  #   same_id when there is only one device.
+  # * In any other situation, same_id means we return v0 unless `always_wrap` is
+  #   true.
+  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
+    return v0
+
+  # Detect the case where each device has a parallel component of the
+  # same MirroredVariable (or SyncOnReadVariable). In this case we
+  # want to return the containing MirroredVariable, after a bunch of
+  # sanity checking. In particular, each component should have the
+  # same container, and the devices of the variables should match the
+  # keys of the per-replica dictionary.
+  if hasattr(v0, "_distributed_container"):
+    # pylint: disable=protected-access
+    assert not isinstance(v0, values_lib.MirroredVariable), (
+        "ids = %s, values = %s" % ([id(v) for v in values], values))
+    distributed_container = v0._distributed_container
+    assert distributed_container is not None
+    for v in values[1:]:
+      assert distributed_container is v._distributed_container
+    return distributed_container
+  # pylint: enable=protected-access
+
+  return wrap_class(values)
+
+
+def select_replica(replica_id, structured):
+  """Specialize a nest of regular & per-replica values for one replica."""
+
+  def _get(x):
+    # `DistributedValues` would be sliced according to replica unless it is a
+    # `DistributedVariable` because `DistributedVariable` can be handled
+    # directly in the replica context.
+    if (isinstance(x, values_lib.DistributedVariable) or
+        not isinstance(x, values_lib.DistributedValues)):
+      return x
+    else:
+      return x.values[replica_id]
+
+  return nest.map_structure(_get, structured)
+
+
+def select_replica_mirrored(replica_id, structured):
+  """Specialize a nest of regular & mirrored values for one replica."""
+
+  def _get_mirrored(x):
+    if isinstance(x, values_lib.DistributedValues):
+      if not isinstance(x, values_lib.Mirrored):
+        raise TypeError(
+            "Expected value to be mirrored across replicas: %s in %s." %
+            (x, structured))
+      return x.values[replica_id]
+    else:
+      return x
+
+  return nest.map_structure(_get_mirrored, structured)
+
+
+def update_regroup(extended, updates, group):
+  """Regroup for an update, with dependencies to ensure all updates execute."""
+  if not group:
+    regrouped = regroup(updates, values_lib.Mirrored)
+    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access
+
+  def _make_grouped_mirrored(values):
+    """Convert per-replica list `values` into Mirrored type with grouping."""
+    if len(values) == 1:
+      return values_lib.Mirrored(values)
+
+    # Make sure we run all updates. Without this, something like
+    # session.run(extended.update(...)) may only update one replica.
+    g = control_flow_ops.group(values)
+
+    # If values is just ops, the grouping is enough. Everything in values
+    # should have the same type, since we expect every replica to be performing
+    # the same computation.
+    if not all(tensor_util.is_tensor(v) for v in values):
+      return g
+
+    # Otherwise we need tensors with the same values as `values`, but
+    # that have a dependency on `g`.
+    with_dep = []
+    for v in values:
+      with ops.device(v.device), ops.control_dependencies([g]):
+        with_dep.append(array_ops.identity(v))
+
+    return values_lib.Mirrored(with_dep)
+
+  return regroup(updates, _make_grouped_mirrored)
+
+
+def value_container(val):
+  """Returns the container that this per-replica `value` belongs to.
+
+  Args:
+    val: A value returned by `call_for_each_replica()` or a variable created in
+      `scope()`.
+
+  Returns:
+    A container that `value` belongs to.
+    If value does not belong to any container (including the case of
+    container having been destroyed), returns the value itself.
+ """ + if (hasattr(val, "_distributed_container") and + # DistributedVariable has _distributed_container defined + # but we don't want to return it. + not isinstance(val, values_lib.DistributedVariable)): + container = val._distributed_container # pylint: disable=protected-access + if container is not None: + return container + return val + + +def is_distributed_variable(v): + """Determine if a variable is ds variable or TPU mirrored variable.""" + return isinstance(v, values_lib.DistributedVariable) + + +def _validate_colocate_extended(v, extended): + variable_strategy = v._distribute_strategy # pylint: disable=protected-access + if variable_strategy.extended is not extended: + raise ValueError( + "`colocate_vars_with` must only be passed a variable created in this " + "tf.distribute.Strategy.scope(), not %s created in scope: %s" % + (v, variable_strategy)) + + +def validate_colocate_distributed_variable(v, extended): + if not isinstance(v, values_lib.DistributedVariable): + raise ValueError( + "`colocate_vars_with` must only be passed a variable created in this " + "tf.distribute.Strategy.scope(), not: %r" % (v,)) + _validate_colocate_extended(v, extended) + + +def validate_colocate(v, extended): + if not hasattr(v, "_distribute_strategy"): + raise ValueError( + "`colocate_vars_with` must only be passed a variable created in this " + "tf.distribute.Strategy.scope(), not: %r" % (v,)) + _validate_colocate_extended(v, extended) + + +# Variable creation function for sync strategies. +def create_mirrored_variable( # pylint: disable=missing-docstring + strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs): + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + var_collections = kwargs.pop("collections", None) + if var_collections is None: + var_collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + synchronization = kwargs.get("synchronization", + vs.VariableSynchronization.ON_WRITE) + + if synchronization == vs.VariableSynchronization.NONE: + raise ValueError( + "`NONE` variable synchronization mode is not supported with `Mirrored` " + "distribution strategy. Please change the `synchronization` for " + "variable: " + str(kwargs["name"])) + elif synchronization == vs.VariableSynchronization.ON_READ: + is_sync_on_read = True + elif synchronization in (vs.VariableSynchronization.ON_WRITE, + vs.VariableSynchronization.AUTO): + # `AUTO` synchronization defaults to `ON_WRITE`. + is_sync_on_read = False + else: + raise ValueError( + "Invalid variable synchronization mode: %s for variable: %s" % + (synchronization, kwargs["name"])) + + aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) + + if aggregation not in (vs.VariableAggregation.NONE, + vs.VariableAggregation.SUM, + vs.VariableAggregation.MEAN, + vs.VariableAggregation.ONLY_FIRST_REPLICA): + raise ValueError("Invalid variable aggregation mode: %s for variable: %s" % + (aggregation, kwargs["name"])) + + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. 
+  with tape.stop_recording():
+    value_list = real_mirrored_creator(**kwargs)
+    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
+    result = var_cls(strategy, value_list, aggregation)
+  # Install the created DistributedVariable as _distributed_container property
+  # of the underlying variables, to make it easy to map back to the container.
+  for v in result.values:
+    # Hold a strong reference to avoid the container from being GC-ed. After
+    # v = v.assign(), the user code may no longer hold references to the
+    # original container, since v.assign() returns a new DistributedVariable.
+    v._distributed_container = result  # pylint: disable=protected-access
+
+  # Add the wrapped variable to the requested collections.
+  # The handling of eager mode and the global step matches
+  # ResourceVariable._init_from_args().
+  if not context.executing_eagerly():
+    g = ops.get_default_graph()
+    # If "trainable" is True, next_creator() will add the member variables
+    # to the TRAINABLE_VARIABLES collection, so we manually remove
+    # them and replace with the MirroredVariable. We can't set
+    # "trainable" to False for next_creator() since that causes functions
+    # like implicit_gradients to skip those variables.
+    if kwargs.get("trainable", True):
+      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+      for value in value_list:
+        for i, trainable_variable in enumerate(l):
+          if value is trainable_variable:
+            del l[i]
+            break
+
+    g.add_to_collections(var_collections, result)
+  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
+    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
+  return result
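Taken together, `regroup` and `select_replica` in the new module act as near-inverses on a nest of per-replica values. A small sketch under assumed inputs (`loss_0`, `loss_1`, and `step` are placeholders, not part of the patch):

    # regroup: one structure per replica -> one structure with PerReplica leaves.
    replica_results = ({"loss": loss_0, "step": step},   # replica 0
                       {"loss": loss_1, "step": step})   # replica 1
    grouped = distribute_utils.regroup(replica_results)
    # grouped["loss"] is PerReplica((loss_0, loss_1)); grouped["step"] is just
    # step, returned unwrapped because the same object appears on every replica.

    # select_replica specializes the nest back to a single replica.
    assert distribute_utils.select_replica(1, grouped)["loss"] is loss_1

`update_regroup` and `create_mirrored_variable` follow the same idea, but wrap results as `Mirrored` values and as `MirroredVariable`/`SyncOnReadVariable` containers respectively.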
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 153d4c2603b..85e2dac1c1d 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_ops
 from tensorflow.python.distribute import reduce_util
@@ -316,7 +317,7 @@ class DistributedIteratorBase(distribute_types.Iterator):
         # Make `replicas` a flat list of values across all replicas.
         replicas.extend(
             self._iterators[i].get_next_as_list_static_shapes(new_name))
-      return values.regroup(replicas)
+      return distribute_utils.regroup(replicas)
 
     out_of_range_replicas = []
     def out_of_range_fn(worker_index, device):
@@ -349,7 +350,7 @@ class DistributedIteratorBase(distribute_types.Iterator):
         results.append(result)
     replicas = results
 
-    return values.regroup(replicas)
+    return distribute_utils.regroup(replicas)
 
 
 class DistributedIteratorV1(DistributedIteratorBase):
@@ -577,7 +578,7 @@ class _IterableInput(distribute_types.Iterable):
       else:
         raise ValueError("Dataset iteration within a tf.function is"
                          " not supported for multiple workers.")
-      state = reduce_fn(state, values.regroup(data))
+      state = reduce_fn(state, distribute_utils.regroup(data))
       has_data, data = _get_next_as_optional(iterator, self._strategy)
       return has_data, data, state
 
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 2114c4e6bda..ff4436c4c8c 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -34,13 +34,13 @@ from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
-from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
@@ -188,8 +188,8 @@ class DistributedIteratorTestBase(test.TestCase):
       for expected_value in expected_values:
         next_element = iterator.get_next()
         computed_value = evaluate(
-            [values.select_replica(r,
-                                   next_element) for r in range(len(devices))])
+            [distribute_utils.select_replica(r, next_element)
+             for r in range(len(devices))])
         self.assertEqual(len(expected_value), len(computed_value))
         for i in range(len(expected_value)):
           self.assertAllEqual(expected_value[i], computed_value[i])
@@ -197,8 +197,8 @@ class DistributedIteratorTestBase(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         next_element = iterator.get_next()
         evaluate(
-            [values.select_replica(r,
-                                   next_element) for r in range(len(devices))])
+            [distribute_utils.select_replica(r, next_element)
+             for r in range(len(devices))])
 
       # After re-initializing the iterator, should be able to iterate again.
      if not ops.executing_eagerly_outside_functions():
@@ -212,8 +212,8 @@ class DistributedIteratorTestBase(test.TestCase):
        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate(
-             [values.select_replica(r,
-                                    next_element) for r in range(len(devices))])
+             [distribute_utils.select_replica(r, next_element)
+              for r in range(len(devices))])
          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])
@@ -222,7 +222,8 @@ class DistributedIteratorTestBase(test.TestCase):
      actual_values = []
      for x in dataset:
        computed_value = self.evaluate(
-           [values.select_replica(r, x) for r in range(len(devices))])
+           [distribute_utils.select_replica(r, x)
+            for r in range(len(devices))])
        actual_values.append(computed_value)
      for i, expected_value in enumerate(expected_values):
        self.assertEqual(len(expected_value), len(actual_values[i]))
@@ -699,24 +700,29 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
-        values.select_replica(0, per_replica_batch["dense"]),
+        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
-        values.select_replica(1, per_replica_batch["dense"]),
+        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
-          values.select_replica(i, per_replica_batch["ragged"]).values, 6)
+          distribute_utils.select_replica(i,
+                                          per_replica_batch["ragged"]).values,
+          6)
      self.assertAllEqual(
-          values.select_replica(i, per_replica_batch["ragged"]).to_tensor(),
-          values.select_replica(i, per_replica_batch["dense"]))
+          distribute_utils.select_replica(
+              i, per_replica_batch["ragged"]).to_tensor(),
+          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
-          values.select_replica(i, per_replica_batch["sparse"]).indices, 6)
+          distribute_utils.select_replica(i,
+                                          per_replica_batch["sparse"]).indices,
+          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
-              values.select_replica(i, per_replica_batch["sparse"])),
-          values.select_replica(i, per_replica_batch["dense"]))
+              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
+          distribute_utils.select_replica(i, per_replica_batch["dense"]))
    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""
diff --git a/tensorflow/python/distribute/mirrored_function_strategy.py b/tensorflow/python/distribute/mirrored_function_strategy.py
index aa9ecfa1fc4..bbe52984d1e 100644
--- a/tensorflow/python/distribute/mirrored_function_strategy.py
+++ b/tensorflow/python/distribute/mirrored_function_strategy.py
@@ -22,6 +22,7 @@ import threading
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import values
@@ -123,16 +124,18 @@ class MirroredFunctionExtended(distribute_lib.StrategyExtendedV1):
          # use a collective op.  This is a particular concern with eager
          # execution.
          with context.execution_mode(context.ASYNC):
-            return_values.append(fn(*values.select_replica(index, args),
-                                    **values.select_replica(index, kwargs)))
+            return_values.append(
+                fn(*distribute_utils.select_replica(index, args),
+                   **distribute_utils.select_replica(index, kwargs)))
        else:
-          return_values.append(fn(*values.select_replica(index, args),
-                                  **values.select_replica(index, kwargs)))
+          return_values.append(
+              fn(*distribute_utils.select_replica(index, args),
+                 **distribute_utils.select_replica(index, kwargs)))
    finally:
      _replica_index.graph_outside_run = None
      _replica_index.current = None
 
-    return values.regroup(return_values)
+    return distribute_utils.regroup(return_values)
 
  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/python/distribute/mirrored_run.py b/tensorflow/python/distribute/mirrored_run.py
index aed7b363b81..ed338b05a4c 100644
--- a/tensorflow/python/distribute/mirrored_run.py
+++ b/tensorflow/python/distribute/mirrored_run.py
@@ -27,8 +27,8 @@ from tensorflow.python import pywrap_tfe
 from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import shared_variable_creator
-from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -160,8 +160,8 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
        shared_variable_store, index)
    t = _MirroredReplicaThread(
        distribution, coord, index, devices, variable_creator_fn, fn,
-        values.select_replica(index, args),
-        values.select_replica(index, kwargs))
+        distribute_utils.select_replica(index, args),
+        distribute_utils.select_replica(index, kwargs))
    threads.append(t)
 
  for t in threads:
@@ -209,8 +209,10 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
        raise RuntimeError("Some replicas made a different number of "
                           "replica_context().merge_call() calls.")
      # get_replica_context().merge_call() case
-      merge_args = values.regroup(tuple(t.merge_args for t in threads))
-      merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads))
+      merge_args = distribute_utils.regroup(
+          tuple(t.merge_args for t in threads))
+      merge_kwargs = distribute_utils.regroup(
+          tuple(t.merge_kwargs for t in threads))
      # We capture the name_scope of the MRT when we call merge_fn
      # to ensure that if we have opened a name scope in the MRT,
      # it will be respected when executing the merge function. We only
@@ -228,13 +230,13 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
          merge_result = threads[0].merge_fn(distribution, *merge_args,
                                             **merge_kwargs)
          for r, t in enumerate(threads):
-            t.merge_result = values.select_replica(r, merge_result)
+            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)
 
-  return values.regroup(tuple(t.main_result for t in threads))
+  return distribute_utils.regroup(tuple(t.main_result for t in threads))
 
 
 class _MirroredReplicaThread(threading.Thread):
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index de66128ce37..fe565261f16 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -23,6 +23,7 @@ import copy
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_run
 from tensorflow.python.distribute import multi_worker_util
@@ -445,13 +446,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
        value_list.append(v)
      return value_list
 
-    return values.create_mirrored_variable(self._container_strategy(),
-                                           _real_mirrored_creator,
-                                           values.MirroredVariable,
-                                           values.SyncOnReadVariable, **kwargs)
+    return distribute_utils.create_mirrored_variable(
+        self._container_strategy(), _real_mirrored_creator,
+        values.MirroredVariable, values.SyncOnReadVariable, **kwargs)
 
  def _validate_colocate_with_variable(self, colocate_with_variable):
-    values.validate_colocate_distributed_variable(colocate_with_variable, self)
+    distribute_utils.validate_colocate_distributed_variable(
+        colocate_with_variable, self)
 
  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
@@ -507,7 +508,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
-    return values.regroup(per_replica_values, always_wrap=True)
+    return distribute_utils.regroup(per_replica_values, always_wrap=True)
 
  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
@@ -557,7 +558,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
-        last_step_tensor_outputs_dict[name] = values.regroup(output)
+        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]
@@ -580,8 +581,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    return self._get_cross_device_ops().broadcast(tensor, destinations)
 
  def _call_for_each_replica(self, fn, args, kwargs):
-    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
-                                              args, kwargs)
+    return mirrored_run.call_for_each_replica(
+        self._container_strategy(), fn, args, kwargs)
 
  def _configure(self,
                 session_config=None,
@@ -643,10 +644,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
          distribute_lib.UpdateContext(i), \
          ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
-        updates.append(fn(v,
-                          *values.select_replica_mirrored(i, args),
-                          **values.select_replica_mirrored(i, kwargs)))
-    return values.update_regroup(self, updates, group)
+        updates.append(
+            fn(v, *distribute_utils.select_replica_mirrored(i, args),
+               **distribute_utils.select_replica_mirrored(i, kwargs)))
+    return distribute_utils.update_regroup(self, updates, group)
 
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
@@ -655,9 +656,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
-        updates.append(fn(*values.select_replica_mirrored(i, args),
-                          **values.select_replica_mirrored(i, kwargs)))
-    return values.update_regroup(self, updates, group)
+        updates.append(
+            fn(*distribute_utils.select_replica_mirrored(i, args),
+               **distribute_utils.select_replica_mirrored(i, kwargs)))
+    return distribute_utils.update_regroup(self, updates, group)
 
  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
@@ -672,7 +674,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    return (val,)
 
  def value_container(self, val):
-    return values.value_container(val)
+    return distribute_utils.value_container(val)
 
  @property
  def _num_replicas_in_sync(self):
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index 42f808589db..6009eece14e 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_worker_test_base
@@ -1026,8 +1027,9 @@ class MirroredStrategyDefunTest(test.TestCase):
    result = distribution.extended.call_for_each_replica(
        model_fn, args=[mock_model] + inputs)
    for r in range(len(devices)):
-      device_result = values.select_replica(r, result)
-      device_expected_result = values.select_replica(r, expected_result)
+      device_result = distribute_utils.select_replica(r, result)
+      device_expected_result = distribute_utils.select_replica(
+          r, expected_result)
      self.assertAllClose(device_expected_result,
                          self.evaluate(device_result))
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index a1c4ada6c52..9a74832cd9d 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -20,9 +20,9 @@ from __future__ import print_function
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
-from tensorflow.python.distribute import values
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -233,7 +233,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
    return super(OneDeviceStrategy, self).scope()
 
 
-@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=missing-docstring
+@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
 class OneDeviceStrategyV1(distribute_lib.StrategyV1):
 
  __doc__ = OneDeviceStrategy.__doc__.replace(
@@ -272,7 +272,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
    return next_creator(**kwargs)
 
  def _validate_colocate_with_variable(self, colocate_with_variable):
-    values.validate_colocate(colocate_with_variable, self)
+    distribute_utils.validate_colocate(colocate_with_variable, self)
 
  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index 142684bb3e9..42fc327351c 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -24,6 +24,7 @@ import copy
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_run
 from tensorflow.python.distribute import multi_worker_util
@@ -334,7 +335,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
        compute_devices, self._variable_device)
 
  def _validate_colocate_with_variable(self, colocate_with_variable):
-    values.validate_colocate(colocate_with_variable, self)
+    distribute_utils.validate_colocate(colocate_with_variable, self)
 
  def _experimental_distribute_dataset(self, dataset):
    return input_lib.get_distributed_dataset(
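The `_validate_colocate_with_variable` overrides above all delegate to the same helpers, so the check itself is unchanged: the variable handed to `colocate_vars_with` must carry a `_distribute_strategy` whose `extended` is the calling strategy. Roughly, in user terms (sketch only; `strategy`, `tf`, and the variables are assumed, not from the patch):

    with strategy.scope():
      v = tf.Variable(1.0)               # created under this strategy
    with strategy.extended.colocate_vars_with(v):
      w = tf.Variable(2.0)               # OK
    # Passing a variable created outside strategy.scope() raises ValueError.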
diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py
index d67ed72a576..24dbd091079 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_test.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
@@ -34,7 +35,6 @@ from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import ps_values
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_test_lib
-from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -534,8 +534,8 @@ class ParameterServerStrategyTestBase(
 
    for expected_value in expected_values:
      next_element = iterator.get_next()
-      computed_value = sess.run([values.select_replica(r, next_element)
-                                 for r in range(len(devices))])
+      computed_value = sess.run([distribute_utils.select_replica(
+          r, next_element) for r in range(len(devices))])
      if ignore_order:
        self.assertCountEqual(expected_value, computed_value)
      else:
@@ -543,7 +543,7 @@ class ParameterServerStrategyTestBase(
 
    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
-      sess.run([values.select_replica(r, next_element)
+      sess.run([distribute_utils.select_replica(r, next_element)
                for r in range(len(devices))])
 
    # After re-initializing the iterator, should be able to iterate again.
@@ -552,8 +552,8 @@ class ParameterServerStrategyTestBase(
 
    for expected_value in expected_values:
      next_element = iterator.get_next()
-      computed_value = sess.run([values.select_replica(r, next_element)
-                                 for r in range(len(devices))])
+      computed_value = sess.run([distribute_utils.select_replica(
+          r, next_element) for r in range(len(devices))])
      if ignore_order:
        self.assertCountEqual(expected_value, computed_value)
      else:
diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py
index b3ececcdcba..0845391ce3b 100644
--- a/tensorflow/python/distribute/strategy_test_lib.py
+++ b/tensorflow/python/distribute/strategy_test_lib.py
@@ -27,9 +27,9 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.util import event_pb2
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import values
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -348,7 +348,8 @@ class DistributionTestBase(test.TestCase):
    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
-          [values.select_replica(r, next_element) for r in range(len(devices))])
+          [distribute_utils.select_replica(r, next_element) for r in
+           range(len(devices))])
      if ignore_order:
        self.assertCountEqual(expected_value, computed_value)
      else:
@@ -357,7 +358,8 @@ class DistributionTestBase(test.TestCase):
    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
-          [values.select_replica(r, next_element) for r in range(len(devices))])
+          [distribute_utils.select_replica(r, next_element) for r in
+           range(len(devices))])
 
    # After re-initializing the iterator, should be able to iterate again.
    if test_reinitialize:
@@ -366,7 +368,8 @@ class DistributionTestBase(test.TestCase):
      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate([
-            values.select_replica(r, next_element) for r in range(len(devices))
+            distribute_utils.select_replica(r, next_element) for r in
+            range(len(devices))
        ])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index b1fea714dd1..c605abd9eae 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -32,6 +32,7 @@ from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
@@ -363,7 +364,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    return self._input_workers_obj
 
  def _validate_colocate_with_variable(self, colocate_with_variable):
-    values.validate_colocate(colocate_with_variable, self)
+    distribute_utils.validate_colocate(colocate_with_variable, self)
 
  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
@@ -423,7 +424,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
-    return values.regroup(per_replica_values, always_wrap=True)
+    return distribute_utils.regroup(per_replica_values, always_wrap=True)
 
  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
@@ -462,7 +463,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
-        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
+        select_replica = lambda x: distribute_utils.select_replica(  # pylint: disable=g-long-lambda
+            replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))
 
@@ -648,11 +650,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
        value_list.append(v)
      return value_list
 
-    return values.create_mirrored_variable(self._container_strategy(),
-                                           _real_mirrored_creator,
-                                           tpu_values.TPUMirroredVariable,
-                                           tpu_values.TPUSyncOnReadVariable,
-                                           **kwargs)
+    return distribute_utils.create_mirrored_variable(
+        self._container_strategy(), _real_mirrored_creator,
+        tpu_values.TPUMirroredVariable, tpu_values.TPUSyncOnReadVariable,
+        **kwargs)
 
  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
    if (isinstance(value, values.DistributedValues) or
@@ -722,10 +723,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
          distribute_lib.UpdateContext(i), \
          ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
-        updates.append(fn(v,
-                          *values.select_replica_mirrored(i, args),
-                          **values.select_replica_mirrored(i, kwargs)))
-    return values.update_regroup(self, updates, group)
+        updates.append(
+            fn(v, *distribute_utils.select_replica_mirrored(i, args),
+               **distribute_utils.select_replica_mirrored(i, kwargs)))
+    return distribute_utils.update_regroup(self, updates, group)
 
  def read_var(self, var):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
@@ -893,8 +894,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    for i in range(strategy.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
-           values.select_replica(i, args),
-           values.select_replica(i, kwargs)])
+           distribute_utils.select_replica(i, args),
+           distribute_utils.select_replica(i, kwargs)])
 
    # Construct and pass `maximum_shapes` so that we could support dynamic
    # shapes using dynamic padder.
@@ -941,7 +942,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
          nest.pack_sequence_as(result[0], nest.flatten(replica_output))
          for replica_output in replicate_outputs
      ]
-      return values.regroup(replicate_outputs)
+      return distribute_utils.regroup(replicate_outputs)
 
    if context.executing_eagerly():
      tpu_function = def_function.function(tpu_function)
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 90210e9041e..5c038c01999 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -26,10 +26,8 @@ from tensorflow.python.distribute import packed_distributed_variable as packed
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -40,7 +38,6 @@ from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.types import core
-from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -1023,88 +1020,6 @@ class SyncOnReadVariable(DistributedVariable):
        self._get(), dtype=dtype, name=name, as_ref=as_ref)
 
 
-# Variable creation function for sync strategies.
-def create_mirrored_variable(  # pylint: disable=missing-docstring
-    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
-  # Figure out what collections this variable should be added to.
-  # We'll add the MirroredVariable to those collections instead.
-  var_collections = kwargs.pop("collections", None)
-  if var_collections is None:
-    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
-  kwargs["collections"] = []
-
-  synchronization = kwargs.get("synchronization",
-                               vs.VariableSynchronization.ON_WRITE)
-
-  if synchronization == vs.VariableSynchronization.NONE:
-    raise ValueError(
-        "`NONE` variable synchronization mode is not supported with `Mirrored` "
-        "distribution strategy. Please change the `synchronization` for "
-        "variable: " + str(kwargs["name"]))
-  elif synchronization == vs.VariableSynchronization.ON_READ:
-    is_sync_on_read = True
-  elif synchronization in (vs.VariableSynchronization.ON_WRITE,
-                           vs.VariableSynchronization.AUTO):
-    # `AUTO` synchronization defaults to `ON_WRITE`.
-    is_sync_on_read = False
-  else:
-    raise ValueError(
-        "Invalid variable synchronization mode: %s for variable: %s" %
-        (synchronization, kwargs["name"]))
-
-  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
-
-  if aggregation not in (vs.VariableAggregation.NONE,
-                         vs.VariableAggregation.SUM,
-                         vs.VariableAggregation.MEAN,
-                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
-    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
-                     (aggregation, kwargs["name"]))
-
-  # Ignore user-specified caching device, not needed for mirrored variables.
-  kwargs.pop("caching_device", None)
-
-  # TODO(josh11b,apassos): It would be better if variable initialization
-  # was never recorded on the tape instead of having to do this manually
-  # here.
-  with tape.stop_recording():
-    value_list = real_mirrored_creator(**kwargs)
-    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
-    result = var_cls(strategy, value_list, aggregation)
-  # Install the created DistributedVariable as _distributed_container property
-  # of the underlying variables, to make it easy to map back to the container.
-  for v in result.values:
-    # Hold a strong reference to avoid the container from being GC-ed. After
-    # v = v.assign(), the user code may no longer holds references to the
-    # original container, since v.assign() returns a new DistributedVariable.
-    v._distributed_container = result  # pylint: disable=protected-access
-
-  # Add the wrapped variable to the requested collections.
-  # The handling of eager mode and the global step matches
-  # ResourceVariable._init_from_args().
-  if not context.executing_eagerly():
-    g = ops.get_default_graph()
-    # If "trainable" is True, next_creator() will add the member variables
-    # to the TRAINABLE_VARIABLES collection, so we manually remove
-    # them and replace with the MirroredVariable. We can't set
-    # "trainable" to False for next_creator() since that causes functions
-    # like implicit_gradients to skip those variables.
-    if kwargs.get("trainable", True):
-      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
-      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
-      for value in value_list:
-        for i, trainable_variable in enumerate(l):
-          if value is trainable_variable:
-            del l[i]
-            break
-
-    g.add_to_collections(var_collections, result)
-  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
-    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
-
-  return result
-
-
 # Register a conversion functions which reads the value of the variable,
 # allowing instances of the class to be used as tensors.
 # MirroredVariables
@@ -1133,214 +1048,3 @@ def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
 
 ops.register_tensor_conversion_function(SyncOnReadVariable,
                                         _tensor_conversion_sync_on_read)
-
-
-def regroup(values, wrap_class=PerReplica, always_wrap=False):
-  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.
-
-  Args:
-    values: Values to regroup
-    wrap_class: Class that `values` be wrapped in.
-    always_wrap: Always wrap the `values` in `wrap_class` even if the values
-      are the same except for DistributeVariable.
-  Returns:
-    Wrapped `values`.
- """ - v0 = values[0] - - if isinstance(v0, list): - for v in values[1:]: - assert isinstance(v, list) - assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" % - (len(v), len(v0), v, v0)) - return [ - regroup(tuple(v[i] for v in values), wrap_class) - for i in range(len(v0)) - ] - - if isinstance(v0, tuple): - for v in values[1:]: - assert isinstance(v, tuple) - assert len(v) == len(v0) - regrouped_tuple = tuple( - regroup(tuple(v[i] for v in values), wrap_class) - for i in range(len(v0))) - if hasattr(v0, "_fields"): - # This tuple is in fact a namedtuple! Create a new namedtuple instance - # and initialize it with the regrouped values: - assert hasattr(type(v0), "_make") - return type(v0)._make(regrouped_tuple) - else: - return regrouped_tuple - - if isinstance(v0, dict): - v0keys = v0.keys() - for v in values[1:]: - assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v)) - assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" % - (set(v0keys), set(v.keys()))) - # Use the actual type in case it is a class inherited from a dict. - return type(v0)({ - key: regroup(tuple(v[key] for v in values), wrap_class) - for key in v0keys - }) - - # If exactly the same object across all devices, return it unwrapped. - same_id = True - for v in values[1:]: - if v is not v0: - same_id = False - break - # Consider three cases where same_id is true: - # * If v0 is a DistributedVariable (a MirroredVariable or - # SyncOnReadVariable, and same_id means it is the same across all - # devices), we want to return it. We check DistributedVariable - # specifically since it can look like it has a - # _distributed_container member since its members do. - if same_id and isinstance(v0, DistributedVariable): - return v0 - # * If v0 is a member of a distributed variable, in which case - # hasattr(v0, "_distributed_container") is true, we want to - # return the DistributedVariable that contains it using the - # _distributed_container logic below. This case can trigger - # same_id when there is only one device. - # * In any other situation, same_id means we return v0 unless `always_wrap` is - # true. - if same_id and not always_wrap and not hasattr(v0, "_distributed_container"): - return v0 - - # Detect the case where each device has a parallel component of the - # same MirroredVariable (or SyncOnReadVariable). In this case we - # want to return the containing MirroredVariable, after a bunch of - # sanity checking. In particular, each component should have the - # same container, and the devices of the variables should match the - # keys of the per-replica dictionary. - if hasattr(v0, "_distributed_container"): - # pylint: disable=protected-access - assert not isinstance(v0, MirroredVariable), ( - "ids = %s, values = %s" % ([id(v) for v in values], values)) - distributed_container = v0._distributed_container - assert distributed_container is not None - for v in values[1:]: - assert distributed_container is v._distributed_container - return distributed_container - # pylint: enable=protected-access - - return wrap_class(values) - - -def select_replica(replica_id, structured): - """Specialize a nest of regular & per-replica values for one replica.""" - - def _get(x): - # `DistributedValues` would be sliced according to replica unless it is a - # `DistributedVariable` because `DistributedVariable` can be handled - # directly in the replica context. 
-    if (isinstance(x, DistributedVariable) or
-        not isinstance(x, DistributedValues)):
-      return x
-    else:
-      return x.values[replica_id]
-
-  return nest.map_structure(_get, structured)
-
-
-def select_replica_mirrored(replica_id, structured):
-  """Specialize a nest of regular & mirrored values for one replica."""
-
-  def _get_mirrored(x):
-    if isinstance(x, DistributedValues):
-      if not isinstance(x, Mirrored):
-        raise TypeError(
-            "Expected value to be mirrored across replicas: %s in %s." %
-            (x, structured))
-      return x.values[replica_id]
-    else:
-      return x
-
-  return nest.map_structure(_get_mirrored, structured)
-
-
-def update_regroup(extended, updates, group):
-  """Regroup for an update, with dependencies to ensure all updates execute."""
-  if not group:
-    regrouped = regroup(updates, Mirrored)
-    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access
-
-  def _make_grouped_mirrored(values):
-    """Convert per-replica list `values` into Mirrored type with grouping."""
-    if len(values) == 1:
-      return Mirrored(values)
-
-    # Make sure we run all updates. Without this, something like
-    # session.run(extended.update(...)) may only update one replica.
-    g = control_flow_ops.group(values)
-
-    # If values is just ops, the grouping is enough. Everything in values
-    # should have the same type, since we expect every replica to be performing
-    # the same computation.
-    if not all(tensor_util.is_tensor(v) for v in values):
-      return g
-
-    # Otherwise we need tensors with the same values as `values`, but
-    # that have a dependency on `g`.
-    with_dep = []
-    for v in values:
-      with ops.device(v.device), ops.control_dependencies([g]):
-        with_dep.append(array_ops.identity(v))
-
-    return Mirrored(with_dep)
-
-  return regroup(updates, _make_grouped_mirrored)
-
-
-def value_container(val):
-  """Returns the container that this per-replica `value` belongs to.
-
-  Args:
-    val: A value returned by `call_for_each_replica()` or a variable created in
-      `scope()`.
-
-  Returns:
-    A container that `value` belongs to.
-    If value does not belong to any container (including the case of
-    container having been destroyed), returns the value itself.
-  """
-  if (hasattr(val, "_distributed_container") and
-      # DistributedVariable has _distributed_container defined
-      # but we don't want to return it.
-      not isinstance(val, DistributedVariable)):
-    container = val._distributed_container  # pylint: disable=protected-access
-    if container is not None:
-      return container
-  return val
-
-
-def is_distributed_variable(v):
-  """Determine if a variable is ds variable or TPU mirrored variable."""
-  return isinstance(v, DistributedVariable)
-
-
-def _validate_colocate_extended(v, extended):
-  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
-  if variable_strategy.extended is not extended:
-    raise ValueError(
-        "`colocate_vars_with` must only be passed a variable created in this "
-        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
-        (v, variable_strategy))
-
-
-def validate_colocate_distributed_variable(v, extended):
-  if not isinstance(v, DistributedVariable):
-    raise ValueError(
-        "`colocate_vars_with` must only be passed a variable created in this "
-        "tf.distribute.Strategy.scope(), not: %r" % (v,))
-  _validate_colocate_extended(v, extended)
-
-
-def validate_colocate(v, extended):
-  if not hasattr(v, "_distribute_strategy"):
-    raise ValueError(
-        "`colocate_vars_with` must only be passed a variable created in this "
-        "tf.distribute.Strategy.scope(), not: %r" % (v,))
-  _validate_colocate_extended(v, extended)
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 716bbbf4279..583e6020683 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -30,6 +30,7 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import tf2
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import packed_distributed_variable as packed
 from tensorflow.python.distribute import strategy_combinations
@@ -157,7 +158,7 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
         distribution.experimental_distribute_values_from_function(value_fn))
     for i in range(distribution.num_replicas_in_sync):
       self.assertAllEqual(
-          values.select_replica(i, distributed_values),
+          distribute_utils.select_replica(i, distributed_values),
           (1. * i, 2. * i, 3. * i))
@@ -391,7 +392,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
       self.assertEqual(exp, result.values[i])
 
   def testNested(self):
-    result = values.regroup((_nested_value("1"), _nested_value("2")))
+    result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")))
     self.assertIsInstance(result, tuple)
     self.assertLen(result, 3)
     self._is_per_replica(result[0], ["a1", "a2"])
@@ -409,20 +410,20 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
     # Also test that we can undo the merge using select_replica()
     self.assertEqual(_nested_value("1"),
-                     values.select_replica(0, result))
+                     distribute_utils.select_replica(0, result))
     self.assertEqual(_nested_value("2"),
-                     values.select_replica(1, result))
+                     distribute_utils.select_replica(1, result))
     # select_device_mirrored() should fail due to non-mirrored values
     with self.assertRaises(TypeError):
-      values.select_replica_mirrored(0, result)
+      distribute_utils.select_replica_mirrored(0, result)
     with self.assertRaises(TypeError):
-      values.select_replica_mirrored(1, result)
 
   def testRegroupKeepsDictBasedClass(self):
     class DictBasedClass(dict):
       """Dummy class inherited from a dict."""
 
-    result = values.regroup(
+    result = distribute_utils.regroup(
         (DictBasedClass(a="a1", b="b1"), DictBasedClass(a="a2", b="b2")))
     self.assertIsInstance(result, DictBasedClass)
     self._is_per_replica(result["a"], ["a1", "a2"])
@@ -431,8 +432,8 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
   def testWrapClass(self):
     # Normally a mirrored value would be the same across devices, but
     # for a test it is convenient to be able to tell the values apart.
-    result = values.regroup((_nested_value("1"), _nested_value("2")),
-                            values.Mirrored)
+    result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")),
+                                      values.Mirrored)
     self.assertIsInstance(result, tuple)
     self.assertLen(result, 3)
     self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
@@ -450,17 +451,17 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
     # Also test that we can undo the merge using select_replica()
     self.assertEqual(_nested_value("1"),
-                     values.select_replica(0, result))
+                     distribute_utils.select_replica(0, result))
     self.assertEqual(_nested_value("2"),
-                     values.select_replica(1, result))
+                     distribute_utils.select_replica(1, result))
     # Values are marked as mirrored, so select_device_mirrored() is allowed.
     self.assertEqual(_nested_value("1"),
-                     values.select_replica_mirrored(0, result))
+                     distribute_utils.select_replica_mirrored(0, result))
     self.assertEqual(_nested_value("2"),
-                     values.select_replica_mirrored(1, result))
+                     distribute_utils.select_replica_mirrored(1, result))
 
   def testWrapAListOfTwoTuples(self):
-    result = values.regroup([("1", "2"), ("3", "4")])
+    result = distribute_utils.regroup([("1", "2"), ("3", "4")])
     self.assertIsInstance(result, tuple)
     self.assertLen(result, 2)
     self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
@@ -478,34 +479,36 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       v = variable_scope.variable(
           1., aggregation=variable_scope.VariableAggregation.SUM)
-    self.assertTrue(values.is_distributed_variable(v))
-    self.assertTrue(values.is_distributed_variable(values.regroup(v.values)))
+    self.assertTrue(distribute_utils.is_distributed_variable(v))
+    self.assertTrue(distribute_utils.is_distributed_variable(
+        distribute_utils.regroup(v.values)))
 
   def testSameId(self):
     foo = object()
-    result = values.regroup((("a", foo), ("b", foo)))
+    result = distribute_utils.regroup((("a", foo), ("b", foo)))
    self.assertIsInstance(result, tuple)
     self.assertLen(result, 2)
     self._is_per_replica(result[0], ["a", "b"])
     self.assertIs(foo, result[1])
 
     # Test select_replica(), should undo the merge done by regroup().
-    result_0 = values.select_replica(0, result)
+    result_0 = distribute_utils.select_replica(0, result)
     self.assertIsInstance(result_0, tuple)
     self.assertLen(result_0, 2)
     self.assertEqual("a", result_0[0])
     self.assertIs(foo, result_0[1])
-    result_1 = values.select_replica(1, result)
+    result_1 = distribute_utils.select_replica(1, result)
     self.assertIsInstance(result_1, tuple)
     self.assertLen(result_1, 2)
     self.assertEqual("b", result_1[0])
     self.assertIs(foo, result_1[1])
 
   def testOneDevice(self):
-    result = values.regroup((_nested_value("1"),))
+    result = distribute_utils.regroup((_nested_value("1"),))
     # On one device regroup() and select_replica() are basically identity.
     self.assertEqual(_nested_value("1"), result)
-    self.assertEqual(_nested_value("1"), values.select_replica(0, result))
+    self.assertEqual(_nested_value("1"),
+                     distribute_utils.select_replica(0, result))
 
   def testNamedTuple(self):
@@ -533,7 +536,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
           train_op=array_ops.identity(constant_op.constant(device_id)))
       created_estimator_specs.append(spec)
 
-    merged_estimator_spec = values.regroup(created_estimator_specs)
+    merged_estimator_spec = distribute_utils.regroup(created_estimator_specs)
 
     self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
     self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
@@ -550,8 +553,8 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
                           Scaffold)
       # Also test that we can undo the merge using select_replica()
       self.assertEqual(created_estimator_specs[device_id],
-                       values.select_replica(device_id,
-                                             merged_estimator_spec))
+                       distribute_utils.select_replica(
+                           device_id, merged_estimator_spec))
 
 
 @combinations.generate(
@@ -630,7 +633,7 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
     with distribution.scope():
      v = variables_lib.Variable(
           1., synchronization=synchronization, aggregation=aggregation)
-    self.assertIs(v, values.select_replica(0, v))
+    self.assertIs(v, distribute_utils.select_replica(0, v))
 
   def testIsTensorLike(self, distribution, synchronization, aggregation):
     if isinstance(distribution.extended,
@@ -973,8 +976,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
       return v.assign(ctx.replica_id_in_sync_group)
 
     # disallow assign() with distributed value in replica context.
-    with self.assertRaisesRegexp(ValueError,
-                                 "Cannot update non-float variables"):
+    with self.assertRaisesRegex(ValueError,
+                                "Cannot update non-float variables"):
       self.evaluate(
           distribution.experimental_local_results(
              distribution.run(assign)))
@@ -2174,11 +2177,11 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
           condition, lambda: per_replica_1, lambda: per_replica_2)
 
 
-def _make_index_slices(values, indices, dense_shape=None):
+def _make_index_slices(vals, indices, dense_shape=None):
   if dense_shape:
     dense_shape = array_ops.identity(dense_shape)
   return indexed_slices.IndexedSlices(
-      array_ops.identity(values), array_ops.identity(indices), dense_shape)
+      array_ops.identity(vals), array_ops.identity(indices), dense_shape)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 5c30d320fb7..6e17b8af206 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -363,6 +363,7 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:distribute_utils",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/training/tracking",
         "//tensorflow/python/training/tracking:base",
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index e0fbb7db270..fe2919c88dc 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -22,8 +22,8 @@ import functools
 import os
 
 from tensorflow.core.protobuf import graph_debug_info_pb2
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
-from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
@@ -87,10 +87,11 @@ class _WrapperFunction(function.ConcreteFunction):
   def _call_flat(self, args, captured_inputs, cancellation_manager=None):
 
     def get_in_replica_handle(x):
-      return x.handle if ds_values.is_distributed_variable(x) else x
+      return x.handle if distribute_utils.is_distributed_variable(x) else x
 
     def get_cross_replica_handle(x):
-      return _unused_handle() if ds_values.is_distributed_variable(x) else x
+      return _unused_handle() if distribute_utils.is_distributed_variable(x) \
+          else x
 
     if ds_context.get_replica_context() is not None:  # in-replica context
       captured_inputs = list(map(get_in_replica_handle, captured_inputs))
@@ -201,7 +202,7 @@ class Loader(object):
     if bound_inputs:
       for bound_input, internal_capture in zip(
           bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
-        if ds_values.is_distributed_variable(bound_input):
+        if distribute_utils.is_distributed_variable(bound_input):
           concrete_function.graph.capture_distributed_variable(
               bound_input, internal_capture)
         else:
@@ -227,7 +228,7 @@ class Loader(object):
     """Resolves a node id into a tensor to be captured for a function."""
     with ops.init_scope():
       obj = self._nodes[node_id]
-      if ds_values.is_distributed_variable(obj):
+      if distribute_utils.is_distributed_variable(obj):
         return obj
       elif resource_variable_ops.is_resource_variable(obj):
         return obj.handle
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 9553fb5b196..e22b0129dda 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -26,7 +26,7 @@ from tensorflow.core.framework import versions_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import saved_model_pb2
 from tensorflow.core.protobuf import saved_object_graph_pb2
-from tensorflow.python.distribute import values as ds_values
+from tensorflow.python.distribute import distribute_utils as ds_utils
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
@@ -273,13 +273,13 @@ class _SaveableView(object):
         # pylint: enable=protected-access
         resource_map[obj.resource_handle] = new_resource
         self.captured_tensor_node_ids[obj.resource_handle] = node_id
-      elif (ds_values.is_distributed_variable(obj) or
+      elif (ds_utils.is_distributed_variable(obj) or
            resource_variable_ops.is_resource_variable(obj)):
-        obj_to_copy = obj._primary if ds_values.is_distributed_variable(  # pylint: disable=protected-access
+        obj_to_copy = obj._primary if ds_utils.is_distributed_variable(  # pylint: disable=protected-access
            obj) else obj
         new_variable = resource_variable_ops.copy_to_graph_uninitialized(
             obj_to_copy)
-        if ds_values.is_distributed_variable(obj):
+        if ds_utils.is_distributed_variable(obj):
          self.captured_tensor_node_ids[obj] = node_id
           for v in obj.values:
             object_map[v] = new_variable
diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py
index d82dcc616c6..90b43c1ebf4 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2.py
@@ -25,10 +25,10 @@ from absl import logging
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.distribute import tpu_strategy
-from tensorflow.python.distribute import values as tf_values
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -1131,8 +1131,10 @@ class TPUEmbedding(tracking.AutoTrackable):
     # in the same (standard) order as self._strategy.extended.worker_devices.
     enqueue_ops = []
     for replica_id in range(self._strategy.num_replicas_in_sync):
-      replica_inputs = tf_values.select_replica(replica_id, flat_inputs)
-      replica_weights = tf_values.select_replica(replica_id, flat_weights)
+      replica_inputs = distribute_utils.select_replica(replica_id,
+                                                       flat_inputs)
+      replica_weights = distribute_utils.select_replica(replica_id,
+                                                        flat_weights)
       tpu_device = self._strategy.extended.worker_devices[replica_id]
       # TPU devices string are like /job:worker/replica:0/task:0/device:TPU:0
      # the device ordinal is the last number