diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 1009c3c0124..0261ce43fa8 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import device_util
 
 
-def _validate_destinations(destinations):
+def validate_destinations(destinations):
   if not isinstance(destinations,
                     (value_lib.DistributedValues, six.string_types, list)):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
@@ -55,7 +55,7 @@ def _validate_value_destination_pairs(value_destination_pairs):
 
 
 # TODO(yuefengz): consider calling this function in the caller of CrossTowerOps.
-def _get_devices_from(destinations):
+def get_devices_from(destinations):
   if isinstance(destinations, value_lib.DistributedValues):
     return list(destinations.devices)
   elif isinstance(destinations, six.string_types):
@@ -65,7 +65,7 @@ def _get_devices_from(destinations):
 
 
 def _devices_match(left, right):
-  return set(_get_devices_from(left)) == set(_get_devices_from(right))
+  return set(get_devices_from(left)) == set(get_devices_from(right))
 
 
 def _all_devices_match(value_destination_pairs):
@@ -80,7 +80,7 @@ def _all_devices_match(value_destination_pairs):
 
 def _simple_broadcast(value, destinations):
   index = {}
-  devices = _get_devices_from(destinations)
+  devices = get_devices_from(destinations)
   for d in devices:
     index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
         value, d)
@@ -146,7 +146,7 @@ class CrossTowerOps(object):
     if not isinstance(per_device_value, value_lib.PerDevice):
       raise ValueError("`per_device_value` must be a `PerDevice` object.")
     if destinations is not None:
-      _validate_destinations(destinations)
+      validate_destinations(destinations)
     return self._reduce(method_string, per_device_value, destinations)
 
   def batch_reduce(self, method_string, value_destination_pairs):
@@ -173,7 +173,7 @@ class CrossTowerOps(object):
           "tuples of PerDevice objects and destinations")
     for _, d in value_destination_pairs:
       if d is not None:
-        _validate_destinations(d)
+        validate_destinations(d)
 
     return self._batch_reduce(method_string, value_destination_pairs)
 
@@ -187,7 +187,7 @@ class CrossTowerOps(object):
     Returns:
       a Mirrored object.
     """
-    _validate_destinations(destinations)
+    validate_destinations(destinations)
     return self._broadcast(tensor, destinations)
 
   def _reduce(self, method_string, per_device_value, destinations):
@@ -221,7 +221,7 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
     super(ReductionToOneDeviceCrossTowerOps, self).__init__()
 
   def _reduce(self, method_string, per_device_value, destinations):
-    devices = _get_devices_from(destinations or per_device_value)
+    devices = get_devices_from(destinations or per_device_value)
     reduce_to_device = self.reduce_to_device or devices[0]
     reduced = _simple_reduce(per_device_value, reduce_to_device,
                              self.accumulation_fn, method_string)
@@ -501,7 +501,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
           logging.WARN,
           "Efficient allreduce is not supported for IndexedSlices.", 10)
-      devices = _get_devices_from(destinations or per_device_value)
+      devices = get_devices_from(destinations or per_device_value)
       reduce_to_device = devices[0]
       reduced = _simple_reduce(per_device_value, reduce_to_device,
                                math_ops.add_n, method_string)
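Note on the renames above: dropping the leading underscore makes validate_destinations and get_devices_from part of the module's public surface, so callers outside cross_tower_ops.py (notably mirrored_strategy.py below) can use them directly. The following is a minimal, self-contained sketch of the destination contract these helpers expose; DistributedValues is stubbed here and this is illustrative only, not the TensorFlow implementation.

# Minimal sketch of the destination contract enforced by the now-public
# helpers. `DistributedValues` is a stand-in stub; the real class lives in
# tensorflow/contrib/distribute/python/values.py.
class DistributedValues(object):

  def __init__(self, index):
    self._index = index  # dict mapping device string -> value

  @property
  def devices(self):
    return list(self._index.keys())


def get_devices_from(destinations):
  # destinations may be a DistributedValues, a single device string, or a
  # list of device strings; the result is always a list of device strings.
  if isinstance(destinations, DistributedValues):
    return list(destinations.devices)
  if isinstance(destinations, str):
    return [destinations]
  return list(destinations)


assert get_devices_from("/device:CPU:0") == ["/device:CPU:0"]
assert get_devices_from(["/device:GPU:0", "/device:CPU:0"]) == [
    "/device:GPU:0", "/device:CPU:0"]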
""" - _validate_destinations(destinations) + validate_destinations(destinations) return self._broadcast(tensor, destinations) def _reduce(self, method_string, per_device_value, destinations): @@ -221,7 +221,7 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): super(ReductionToOneDeviceCrossTowerOps, self).__init__() def _reduce(self, method_string, per_device_value, destinations): - devices = _get_devices_from(destinations or per_device_value) + devices = get_devices_from(destinations or per_device_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, self.accumulation_fn, method_string) @@ -501,7 +501,7 @@ class AllReduceCrossTowerOps(CrossTowerOps): logging.WARN, "Efficient allreduce is not supported for IndexedSlices.", 10) - devices = _get_devices_from(destinations or per_device_value) + devices = get_devices_from(destinations or per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, math_ops.add_n, method_string) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index fed5505d92e..b3cfa3c5a5d 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -36,7 +36,7 @@ from tensorflow.python.training import device_util def _make_per_device(values, devices): - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) assert len(values) == len(devices) index = {} for d, v in zip(devices, values): @@ -53,7 +53,7 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) return value_lib.Mirrored( {d: v for d, v in zip(devices, [value] * len(devices))}) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 98fea76b3d5..d269bed1e57 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -309,9 +309,29 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._cross_tower_ops def _reduce(self, method_string, value, destinations): - if len(self._devices) == 1 and not isinstance(value, values.PerDevice): - value = values.PerDevice({self._devices[0]: value}) - assert isinstance(value, values.PerDevice) + assert not isinstance(value, values.Mirrored) + if not isinstance(value, values.PerDevice): + if value == 0: + return 0 + if method_string == "mean": + return self._broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + if len(self._devices) == 1: + if destinations: + # TODO(anjalisridhar): Moves these methods to a device utility file? 
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 647cf953d73..8d474124b7e 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -32,12 +32,14 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.layers import core
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import distribute as distribute_lib
 
+
 GPU_TEST = "test_gpu" in sys.argv[0]
 
 
@@ -118,6 +120,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
     expected = sum(range(len(dist.worker_devices)))
     self.assertEqual(expected, self.evaluate(unwrapped[0]))
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testReduceToMultipleDestinations(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+
+    devices = ["/device:GPU:0"]
+    if GPU_TEST:
+      self.assertGreater(context.num_gpus(), 0)
+    print(self.id().split(".")[-1], "devices:", ", ".join(devices))
+
+    dist = mirrored_strategy.MirroredStrategy(devices)
+    with dist.scope():
+      reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0",
+                                                      "/device:GPU:0"])
+      unwrapped = dist.unwrap(reduced)
+      self.assertEqual(2, len(unwrapped))
+      self.assertEqual(1.0, self.evaluate(unwrapped[0]))
+
 
 class MirroredStrategyVariableCreationTest(test.TestCase):
 
@@ -581,5 +601,201 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     self.assertEquals(10.0, self.evaluate(ret_v_sum))
 
 
+class MirroredVariableUpdateTest(test.TestCase):
+  # The following tests check assign, assign_add and assign_sub on Mirrored
+  # variables in tower and cross tower context.
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  def _skip_eager_if_gpus_less_than(self, num_gpus):
+    if context.num_gpus() < num_gpus and context.executing_eagerly():
+      self.skipTest("Not enough GPUs available for this test in eager mode.")
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContextWithoutAggregationType(self):
+    # Test that assigning to a mirrored variable in tower context fails if an
+    # aggregation method is not set on the variable.
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      v = variable_scope.variable(1.0, name="foo")
+      return v
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+
+      def model_fn():
+        return mirrored_var.assign(5.0)
+
+      with self.assertRaisesRegexp(
+          ValueError, "You must specify an aggregation method to update a "
+          "MirroredVariable in Tower Context."):
+        self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContextWithSum(self):
+    # Test that we don't reduce a non-per-device value with the "sum"
+    # aggregation type.
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      v = variable_scope.variable(1.0, name="foo")
+      return v
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "sum"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+
+      def model_fn():
+        return mirrored_var.assign(5.0)
+
+      with self.assertRaisesRegexp(
+          ValueError, "A non PerDevice value cannot be reduced with the given "
+          "method_string."):
+        self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
+      self.assertEquals(6.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(0.5, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignAddMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+      self.assertEquals(7.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignAddMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign_add(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(1.5, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignSubMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(5.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
+      self.assertEquals(3.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignSubMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(5.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign_sub(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(4.5, self.evaluate(mirrored_var))
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 9a48928a953..ce95b718f67 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -23,7 +23,6 @@ from __future__ import print_function
 
 import collections
 import weakref
-
 import six
 
 from tensorflow.contrib.distribute.python import input_ops
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import saver
@@ -251,21 +251,6 @@ class DistributedVariable(DistributedDelegate):
 ops.register_dense_tensor_like_type(DistributedVariable)
 
 
-class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
-  """Class for defining how to restore a MirroredVariable."""
-
-  def __init__(self, mirrored_variable, primary_variable, name):
-    self._mirrored_variable = mirrored_variable
-    super(_MirroredSaveable, self).__init__(primary_variable, "", name)
-
-  def restore(self, restored_tensors, restored_shapes):
-    """Restore the same value into all variables."""
-    tensor, = restored_tensors
-    return control_flow_ops.group([
-        _assign_on_device(d, v, tensor)
-        for d, v in six.iteritems(self._mirrored_variable._index)])  # pylint: disable=protected-access
-
-
 def _get_update_device():
   """Validate we are in update/update_non_slot() and return current device.
 
@@ -286,30 +271,82 @@ def _get_update_device():
   return device
 
 
+class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
+  """Class for defining how to restore a MirroredVariable."""
+
+  def __init__(self, mirrored_variable, primary_variable, name):
+    self._mirrored_variable = mirrored_variable
+    super(_MirroredSaveable, self).__init__(primary_variable, "", name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Restore the same value into all variables."""
+    tensor, = restored_tensors
+    return control_flow_ops.group([
+        _assign_on_device(d, v, tensor)
+        for d, v in six.iteritems(self._mirrored_variable._index)])  # pylint: disable=protected-access
+
+
 class MirroredVariable(DistributedVariable, Mirrored,
                        checkpointable.CheckpointableBase):
   """Holds a map from device to variables whose values are kept in sync."""
 
-  def __init__(self, index, primary_var):
+  def __init__(self, index, primary_var, aggregation_method=None):
+    # Use a weakref to make it easy to map from the contained values
+    # to the container without introducing a reference cycle.
+    for v in six.itervalues(index):
+      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
     self._primary_var = primary_var
+    self._aggregation_method = aggregation_method
     super(MirroredVariable, self).__init__(index)
 
-  # We use _get_update_device() for the assign* methods to enforce
-  # that we are in an update() function. The arguments to update() are
-  # automatically unwrapped so the update() function would normally
-  # see regular variables, not MirroredVariables. However, the update
-  # function can still operate on wrapped MirroredVariables through
-  # object members, captured arguments, etc. This is more likely in an
+  # The arguments to update() are automatically unwrapped so the update()
+  # function would normally see regular variables, not MirroredVariables.
+  # However, the update function can still operate on wrapped MirroredVariables
+  # through object members, captured arguments, etc. This is more likely in an
   # update_non_slot() function (like OptimizerV2._finish), which can
   # update several non-slot variables in one call.
+  def _assign_func(self, *args, **kwargs):
+    f = kwargs.pop("f")
+    if distribute_lib.get_cross_tower_context():
+      update_device = distribute_lib.get_update_device()
+      # We are calling update on the mirrored variable in cross tower context.
+      if update_device is not None:
+        # We are calling an assign function on the mirrored variable in cross
+        # tower context.
+        v = self.get(device=update_device)
+        return f(v, *args, **kwargs)
+
+      return distribute_lib.get_distribution_strategy().update(
+          self, f, *args, **kwargs)
+    else:
+      # We are calling an assign function on the mirrored variable in tower
+      # context.
+      # We reduce the value we want to assign/add/sub. More details about how we
+      # handle the different use cases can be found in the _reduce method.
+      # We call the function on each of the mirrored variables with the reduced
+      # value.
+      if not self._aggregation_method:
+        raise ValueError("You must specify an aggregation method to update a "
+                         "MirroredVariable in Tower Context.")
+
+      def merge_fn(strategy, value):
+        return strategy.update(self,
+                               f,
+                               strategy.reduce(
+                                   method_string=self._aggregation_method,
+                                   value=value,
+                                   destinations=self))
+      return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
+                                                           **kwargs)
+
   def assign_sub(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign_sub(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign_add(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
 
   def assign(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign, *args, **kwargs)
 
   def _get_cross_tower(self):
     device = device_util.canonicalize(device_util.current())
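Taken together, the values.py changes route assign, assign_add and assign_sub on a MirroredVariable in tower context through merge_call: the per-tower values are reduced with the variable's aggregation method and the reduced result is then applied to every device copy. Below is a plain-Python sketch of that flow, limited to the "mean" method exercised by the new tests; it is illustrative only, not the TensorFlow implementation. It also shows why testAssignMirroredVarTowerContext ends at 0.5: towers 0 and 1 assign 0.0 and 1.0, whose mean is 0.5.

# Plain-Python sketch of the tower-context assign flow: reduce the per-tower
# values with the variable's aggregation method, then apply the same reduced
# value to every device copy. Only "mean" is modelled here.
def tower_context_assign(per_tower_values, aggregation_method, device_copies):
  if not aggregation_method:
    raise ValueError("You must specify an aggregation method to update a "
                     "MirroredVariable in Tower Context.")
  if aggregation_method != "mean":
    raise ValueError("Only 'mean' is modelled in this sketch.")
  reduced = sum(per_tower_values) / len(per_tower_values)
  return {device: reduced for device in device_copies}


# Two towers assign their tower_id (0.0 and 1.0); the mean, 0.5, is written
# to both copies, which is the value testAssignMirroredVarTowerContext checks.
print(tower_context_assign([0.0, 1.0], "mean",
                           {"/device:GPU:0": 1.0, "/device:CPU:0": 1.0}))
# {'/device:GPU:0': 0.5, '/device:CPU:0': 0.5}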