Enable assign, assign_add and assign_sub to be called on Mirrored Variables in cross tower and tower context.

PiperOrigin-RevId: 202162272
Anjali Sridhar 2018-06-26 11:25:21 -07:00 committed by TensorFlower Gardener
parent d10213099d
commit bfda539bef
5 changed files with 313 additions and 40 deletions
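
A minimal usage sketch of what this change enables, modeled on the new tests below. It assumes the contrib-era API at this revision with two local devices; the aggregation method is set through the private attribute because the public setter referenced in the TODOs below does not exist yet.

    from tensorflow.contrib.distribute.python import mirrored_strategy
    from tensorflow.python.ops import variable_scope

    dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0", "/device:CPU:0"])
    with dist.scope():
      # Creating the variable once per tower yields a single MirroredVariable.
      mirrored_var = dist.call_for_each_tower(
          lambda: variable_scope.variable(1.0, name="foo"),
          run_concurrently=False)

      # Cross tower context: the assign is applied to every component variable.
      assign_op = mirrored_var.assign(6.0)

      # Tower context: per-tower values are first reduced with the variable's
      # aggregation method, then the reduced value is assigned everywhere.
      mirrored_var._aggregation_method = "mean"
      update_ops = dist.unwrap(dist.call_for_each_tower(
          lambda: mirrored_var.assign_add(1.0), run_concurrently=False))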

View File

@@ -32,7 +32,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import device_util
 
 
-def _validate_destinations(destinations):
+def validate_destinations(destinations):
   if not isinstance(destinations,
                     (value_lib.DistributedValues, six.string_types, list)):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
@@ -55,7 +55,7 @@ def _validate_value_destination_pairs(value_destination_pairs):
 
 # TODO(yuefengz): consider calling this function in the caller of CrossTowerOps.
-def _get_devices_from(destinations):
+def get_devices_from(destinations):
   if isinstance(destinations, value_lib.DistributedValues):
     return list(destinations.devices)
   elif isinstance(destinations, six.string_types):
@@ -65,7 +65,7 @@ def _get_devices_from(destinations):
 
 
 def _devices_match(left, right):
-  return set(_get_devices_from(left)) == set(_get_devices_from(right))
+  return set(get_devices_from(left)) == set(get_devices_from(right))
 
 
 def _all_devices_match(value_destination_pairs):
@@ -80,7 +80,7 @@ def _all_devices_match(value_destination_pairs):
 
 def _simple_broadcast(value, destinations):
   index = {}
-  devices = _get_devices_from(destinations)
+  devices = get_devices_from(destinations)
   for d in devices:
     index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
         value, d)
@@ -146,7 +146,7 @@ class CrossTowerOps(object):
     if not isinstance(per_device_value, value_lib.PerDevice):
       raise ValueError("`per_device_value` must be a `PerDevice` object.")
     if destinations is not None:
-      _validate_destinations(destinations)
+      validate_destinations(destinations)
     return self._reduce(method_string, per_device_value, destinations)
 
   def batch_reduce(self, method_string, value_destination_pairs):
@@ -173,7 +173,7 @@ class CrossTowerOps(object):
           "tuples of PerDevice objects and destinations")
     for _, d in value_destination_pairs:
       if d is not None:
-        _validate_destinations(d)
+        validate_destinations(d)
     return self._batch_reduce(method_string, value_destination_pairs)
@@ -187,7 +187,7 @@ class CrossTowerOps(object):
     Returns:
       a Mirrored object.
     """
-    _validate_destinations(destinations)
+    validate_destinations(destinations)
     return self._broadcast(tensor, destinations)
 
   def _reduce(self, method_string, per_device_value, destinations):
@@ -221,7 +221,7 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
     super(ReductionToOneDeviceCrossTowerOps, self).__init__()
 
   def _reduce(self, method_string, per_device_value, destinations):
-    devices = _get_devices_from(destinations or per_device_value)
+    devices = get_devices_from(destinations or per_device_value)
     reduce_to_device = self.reduce_to_device or devices[0]
     reduced = _simple_reduce(per_device_value, reduce_to_device,
                              self.accumulation_fn, method_string)
@@ -501,7 +501,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
         logging.WARN,
         "Efficient allreduce is not supported for IndexedSlices.", 10)
 
-    devices = _get_devices_from(destinations or per_device_value)
+    devices = get_devices_from(destinations or per_device_value)
     reduce_to_device = devices[0]
     reduced = _simple_reduce(per_device_value, reduce_to_device,
                              math_ops.add_n, method_string)
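
Aside (not part of the diff): the renames drop the leading underscore so that MirroredStrategy, below, can call these helpers from outside the module. A short sketch of their contract as restated from the code above; destinations may be a DistributedValues, a single device string, or a list of device strings. The list branch is truncated in this view, so its exact behavior is inferred:

    from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib

    cross_tower_ops_lib.validate_destinations("/device:CPU:0")  # passes silently
    # A device string yields a one-element list with the resolved device name.
    cross_tower_ops_lib.get_devices_from("/device:CPU:0")
    # A list of device strings comes back as device names, in order (inferred).
    cross_tower_ops_lib.get_devices_from(["/device:GPU:0", "/device:CPU:0"])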

View File

@ -36,7 +36,7 @@ from tensorflow.python.training import device_util
def _make_per_device(values, devices): def _make_per_device(values, devices):
devices = cross_tower_ops_lib._get_devices_from(devices) devices = cross_tower_ops_lib.get_devices_from(devices)
assert len(values) == len(devices) assert len(values) == len(devices)
index = {} index = {}
for d, v in zip(devices, values): for d, v in zip(devices, values):
@ -53,7 +53,7 @@ def _fake_mirrored(value, devices):
All components of the returned Mirrored have the same objects, which is not All components of the returned Mirrored have the same objects, which is not
true in reality. true in reality.
""" """
devices = cross_tower_ops_lib._get_devices_from(devices) devices = cross_tower_ops_lib.get_devices_from(devices)
return value_lib.Mirrored( return value_lib.Mirrored(
{d: v for d, v in zip(devices, [value] * len(devices))}) {d: v for d, v in zip(devices, [value] * len(devices))})

View File

@@ -309,9 +309,29 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     return self._cross_tower_ops
 
   def _reduce(self, method_string, value, destinations):
-    if len(self._devices) == 1 and not isinstance(value, values.PerDevice):
-      value = values.PerDevice({self._devices[0]: value})
-    assert isinstance(value, values.PerDevice)
+    assert not isinstance(value, values.Mirrored)
+    if not isinstance(value, values.PerDevice):
+      if value == 0:
+        return 0
+      if method_string == "mean":
+        return self._broadcast(value, destinations)
+
+      cross_tower_ops_lib.validate_destinations(destinations)
+      if len(self._devices) == 1:
+        if destinations:
+          # TODO(anjalisridhar): Moves these methods to a device utility file?
+          devices = cross_tower_ops_lib.get_devices_from(destinations)
+          if len(devices) == 1:
+            with ops.device(devices[0]):
+              return array_ops.identity(value)
+          else:
+            value_updates = {}
+            for d in devices:
+              with ops.device(d):
+                value_updates[d] = array_ops.identity(value)
+            return values.Mirrored(value_updates)
+      raise ValueError("A non PerDevice value cannot be reduced with the given "
+                       "method_string.")
     return self._get_cross_tower_ops().reduce(
         method_string, value, destinations=destinations)
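
The rewritten _reduce gives plain, non-PerDevice values a defined meaning: zero passes through unchanged, a "mean" over towers that all hold the same value is just a broadcast, and on a single-device strategy the value is copied onto each requested destination with array_ops.identity and wrapped in a Mirrored; any other combination raises. A standalone sketch of that copy step (it restates the branch above rather than calling into the strategy; the helper name is illustrative):

    import tensorflow as tf

    def copy_to_destinations(value, devices):
      # One destination: a single identity op pinned to that device.
      if len(devices) == 1:
        with tf.device(devices[0]):
          return tf.identity(value)
      # Several destinations: one identity per device. The real code wraps
      # this dict in values.Mirrored.
      value_updates = {}
      for d in devices:
        with tf.device(d):
          value_updates[d] = tf.identity(value)
      return value_updates

    copies = copy_to_destinations(1.0, ["/device:CPU:0", "/device:GPU:0"])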

View File

@@ -32,12 +32,14 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.layers import core
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import distribute as distribute_lib
 
 GPU_TEST = "test_gpu" in sys.argv[0]
@@ -118,6 +120,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
     expected = sum(range(len(dist.worker_devices)))
     self.assertEqual(expected, self.evaluate(unwrapped[0]))
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testReduceToMultipleDestinations(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+
+    devices = ["/device:GPU:0"]
+    if GPU_TEST:
+      self.assertGreater(context.num_gpus(), 0)
+    print(self.id().split(".")[-1], "devices:", ", ".join(devices))
+
+    dist = mirrored_strategy.MirroredStrategy(devices)
+    with dist.scope():
+      reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0",
+                                                      "/device:GPU:0"])
+      unwrapped = dist.unwrap(reduced)
+      self.assertEqual(2, len(unwrapped))
+      self.assertEqual(1.0, self.evaluate(unwrapped[0]))
+
 
 class MirroredStrategyVariableCreationTest(test.TestCase):
@@ -581,5 +601,201 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     self.assertEquals(10.0, self.evaluate(ret_v_sum))
 
 
+class MirroredVariableUpdateTest(test.TestCase):
+  # The following tests check assign, assign_add and assign_sub on Mirrored
+  # variables in tower and cross tower context.
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  def _skip_eager_if_gpus_less_than(self, num_gpus):
+    if context.num_gpus() < num_gpus and context.executing_eagerly():
+      self.skipTest("Enough GPUs not available for this test in eager mode.")
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContextWithoutAggregationType(self):
+    # Test that we always have an aggregation type set on the mirrored variable
+    # if we assign to it in tower mode.
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      v = variable_scope.variable(1.0, name="foo")
+      return v
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+
+      def model_fn():
+        return mirrored_var.assign(5.0)
+
+      with self.assertRaisesRegexp(
+          ValueError, "You must specify an aggregation method to update a "
+          "MirroredVariable in Tower Context."):
+        self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContextWithSum(self):
+    # Test that we don't reduce a non-per-device value with the "sum"
+    # aggregation type.
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      v = variable_scope.variable(1.0, name="foo")
+      return v
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "sum"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+
+      def model_fn():
+        return mirrored_var.assign(5.0)
+
+      with self.assertRaisesRegexp(
+          ValueError, "A non PerDevice value cannot be reduced with the given "
+          "method_string."):
+        self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
+      self.assertEquals(6.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(0.5, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignAddMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+      self.assertEquals(7.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignAddMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign_add(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(1.5, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignSubMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(5.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
+      self.assertEquals(3.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignSubMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(5.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign_sub(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(4.5, self.evaluate(mirrored_var))
+
+
 if __name__ == "__main__":
   test.main()
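
For reference, the tower-context expectations above follow from "mean" aggregation over the two towers' tower_id values (0 and 1): assign stores mean(0, 1) = 0.5, assign_add stores 1.0 + mean(0, 1) = 1.5, and assign_sub stores 5.0 - mean(0, 1) = 4.5.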

View File

@@ -23,7 +23,6 @@ from __future__ import print_function
 
 import collections
 import weakref
-
 import six
 
 from tensorflow.contrib.distribute.python import input_ops
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import saver
@@ -251,21 +251,6 @@ class DistributedVariable(DistributedDelegate):
 
 ops.register_dense_tensor_like_type(DistributedVariable)
 
-
-class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
-  """Class for defining how to restore a MirroredVariable."""
-
-  def __init__(self, mirrored_variable, primary_variable, name):
-    self._mirrored_variable = mirrored_variable
-    super(_MirroredSaveable, self).__init__(primary_variable, "", name)
-
-  def restore(self, restored_tensors, restored_shapes):
-    """Restore the same value into all variables."""
-    tensor, = restored_tensors
-    return control_flow_ops.group([
-        _assign_on_device(d, v, tensor)
-        for d, v in six.iteritems(self._mirrored_variable._index)])  # pylint: disable=protected-access
-
 
 def _get_update_device():
   """Validate we are in update/update_non_slot() and return current device.
@@ -286,30 +271,82 @@ def _get_update_device():
   return device
 
 
+class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
+  """Class for defining how to restore a MirroredVariable."""
+
+  def __init__(self, mirrored_variable, primary_variable, name):
+    self._mirrored_variable = mirrored_variable
+    super(_MirroredSaveable, self).__init__(primary_variable, "", name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Restore the same value into all variables."""
+    tensor, = restored_tensors
+    return control_flow_ops.group([
+        _assign_on_device(d, v, tensor)
+        for d, v in six.iteritems(self._mirrored_variable._index)])  # pylint: disable=protected-access
+
+
 class MirroredVariable(DistributedVariable, Mirrored,
                        checkpointable.CheckpointableBase):
   """Holds a map from device to variables whose values are kept in sync."""
 
-  def __init__(self, index, primary_var):
+  def __init__(self, index, primary_var, aggregation_method=None):
+    # Use a weakref to make it easy to map from the contained values
+    # to the container without introducing a reference cycle.
+    for v in six.itervalues(index):
+      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
     self._primary_var = primary_var
+    self._aggregation_method = aggregation_method
     super(MirroredVariable, self).__init__(index)
 
-  # We use _get_update_device() for the assign* methods to enforce
-  # that we are in an update() function. The arguments to update() are
-  # automatically unwrapped so the update() function would normally
-  # see regular variables, not MirroredVariables. However, the update
-  # function can still operate on wrapped MirroredVariables through
-  # object members, captured arguments, etc. This is more likely in an
+  # The arguments to update() are automatically unwrapped so the update()
+  # function would normally see regular variables, not MirroredVariables.
+  # However, the update function can still operate on wrapped MirroredVariables
+  # through object members, captured arguments, etc. This is more likely in an
   # update_non_slot() function (like OptimizerV2._finish), which can
   # update several non-slot variables in one call.
+  def _assign_func(self, *args, **kwargs):
+    f = kwargs.pop("f")
+    if distribute_lib.get_cross_tower_context():
+      update_device = distribute_lib.get_update_device()
+      # We are calling update on the mirrored variable in cross tower context.
+      if update_device is not None:
+        # We are calling an assign function on the mirrored variable in cross
+        # tower context.
+        v = self.get(device=update_device)
+        return f(v, *args, **kwargs)
+      return distribute_lib.get_distribution_strategy().update(
+          self, f, *args, **kwargs)
+    else:
+      # We are calling an assign function on the mirrored variable in tower
+      # context.
+      # We reduce the value we want to assign/add/sub. More details about how we
+      # handle the different use cases can be found in the _reduce method.
+      # We call the function on each of the mirrored variables with the reduced
+      # value.
+      if not self._aggregation_method:
+        raise ValueError("You must specify an aggregation method to update a "
+                         "MirroredVariable in Tower Context.")
+
+      def merge_fn(strategy, value):
+        return strategy.update(self,
+                               f,
+                               strategy.reduce(
                                   method_string=self._aggregation_method,
+                                   value=value,
+                                   destinations=self))
+      return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
+                                                           **kwargs)
 
   def assign_sub(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign_sub(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign_add(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
 
   def assign(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign, *args, **kwargs)
 
   def _get_cross_tower(self):
     device = device_util.canonicalize(device_util.current())
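
Taken together, _assign_func dispatches on context: inside an update() or update_non_slot() call it applies the assign directly to the component variable on the update device; in plain cross tower context it routes through the distribution strategy's update(), which runs the assign on every device; and in tower context it merge_calls into cross tower context, reduces the per-tower values with the variable's aggregation method (exercising the new MirroredStrategy._reduce behavior above), then updates every component with the reduced result.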