diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 588fa47c6ae..c3a5b953b38 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -115,7 +115,7 @@ class CollectiveAllReduceStrategyTestBase( def setUp(self): # We use a different key_base for each test so that collective keys won't be # reused. - # TODO(yuefengz, tucker): enable it to reuse collective keys in different + # TODO(yuefengz, ayushd): enable it to reuse collective keys in different # tests. CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() @@ -133,11 +133,11 @@ class CollectiveAllReduceStrategyTestBase( use_core_strategy=use_core_strategy) collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=10 * num_gpus + + group_key_start=10 + CollectiveAllReduceStrategyTestBase.collective_key_base, - instance_key_start=num_gpus * 100 + + op_instance_key_start=100 + CollectiveAllReduceStrategyTestBase.collective_key_base, - instance_key_with_id_start=num_gpus * 10000 + + variable_instance_key_start=10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) strategy.extended._collective_keys = collective_keys strategy.extended._cross_device_ops._collective_keys = (collective_keys) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 5f4f38bebab..84b57caf652 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -966,6 +966,32 @@ cuda_py_test( xla_enable_strict_auto_jit = True, ) +cuda_py_test( + name = "mirrored_variable_test", + srcs = ["mirrored_variable_test.py"], + additional_deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":strategy_combinations", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:layers", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:values", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "guitar", + "multi_and_single_gpu", + ], + xla_enable_strict_auto_jit = True, +) + distribute_py_test( name = "metrics_v1_test", srcs = ["metrics_v1_test.py"], diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index 2e750b0bfc4..f0e86f71c6a 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -36,7 +36,6 @@ from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.eager import context -from tensorflow.python.eager import tape from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops @@ -79,6 +78,13 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): self, communication=communication)) + @classmethod + def _from_local_devices(cls, devices): + """A convenience method to create an obejct with a 
list of devices.""" + obj = cls() + obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access + return obj + @tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1): @@ -117,7 +123,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: self._initialize_local(cluster_resolver) - def _initialize_local(self, cluster_resolver): + def _initialize_local(self, cluster_resolver, devices=None): """Initializes the object for local training.""" self._is_chief = True self._num_workers = 1 @@ -140,10 +146,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: num_gpus = cluster_resolver.num_accelerators().get("GPU", 0) - if num_gpus: - local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus)) + if devices: + local_devices = devices else: - local_devices = ("/device:CPU:0",) + if num_gpus: + local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus)) + else: + local_devices = ("/device:CPU:0",) self._worker_device = device_util.canonicalize("/device:CPU:0") self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) @@ -272,100 +281,52 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): task_id, self._num_workers, local_devices, self._communication) - def _create_variable(self, next_creator, *args, **kwargs): - colocate_with = kwargs.pop("colocate_with", None) - if colocate_with is None: - device_map = self._device_map - logical_device = 0 # TODO(josh11b): Get logical device from scope here. - elif isinstance(colocate_with, numpy_dataset.SingleDevice): - with ops.device(colocate_with.device): - return next_creator(*args, **kwargs) + def _get_variable_creator_initial_value(self, + replica_id=0, + device=None, + primary_var=None, + **kwargs): + if replica_id == 0: # First replica on each worker. + assert device is not None + assert primary_var is None + + def initial_value_fn(): # pylint: disable=g-missing-docstring + # Only the first device participates in the broadcast of initial values. 
+ group_key = self._collective_keys.get_group_key([device]) + group_size = self._num_workers + collective_instance_key = ( + self._collective_keys.get_variable_instance_key()) + + with ops.device(device): + initial_value = kwargs["initial_value"] + if callable(initial_value): + initial_value = initial_value() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor( + initial_value, dtype=kwargs.get("dtype", None)) + + if self._num_workers > 1: + if self._is_chief: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv(initial_value.shape, + initial_value.dtype, + group_size, group_key, + collective_instance_key) + return initial_value + + return initial_value_fn else: - device_map = colocate_with.device_map - logical_device = colocate_with.logical_device - - def _real_mirrored_creator(devices, *args, **kwargs): - """Creates one MirroredVariable on the current worker.""" - unique_var_name = ops.get_default_graph().unique_name( - kwargs["name"], mark_as_used=False).rstrip("/") - # pylint: disable=protected-access - collective_instance_key = self._collective_keys.get_instance_key( - key_id=unique_var_name) - # Only the first device participles in the broadcast of initial values. - group_key = self._collective_keys.get_group_key([devices[0]]) - group_size = self._num_workers - if "initial_value" not in kwargs: - raise ValueError("Initial value must be specified.") - initial_value = kwargs["initial_value"] - if callable(initial_value): - initial_value_fn = initial_value - else: - initial_value_fn = lambda: initial_value - - value_list = [] - for i, d in enumerate(devices): - with ops.init_scope(), ops.device(d): - if i == 0: - # The initial value fn makes sure variables all initialized to - # same values. The first device of the chief worker will send their - # variable values to other workers. - def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring - with ops.device(device): - initial_value = initial_value_fn() - assert not callable(initial_value) - initial_value = ops.convert_to_tensor( - initial_value, dtype=kwargs.get("dtype", None)) - - assert index == 0, index - if self._num_workers > 1: - if self._is_chief: - bcast_send = collective_ops.broadcast_send( - initial_value, initial_value.shape, initial_value.dtype, - group_size, group_key, collective_instance_key) - with ops.control_dependencies([bcast_send]): - return array_ops.identity(initial_value) - else: - return collective_ops.broadcast_recv( - initial_value.shape, initial_value.dtype, group_size, - group_key, collective_instance_key) - return initial_value - else: - # Give replicas meaningful distinct names: - var0name = value_list[0].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - - # Variables on non-first replica get initial values from the - # variables created on the first device of each worker. 
- def _overridden_initial_value_fn(device=d, index=i): - assert index > 0 - with ops.device(device): - if context.executing_eagerly(): - return array_ops.identity(value_list[0].value()) - else: - return array_ops.identity(value_list[0].initial_value) - - kwargs["initial_value"] = _overridden_initial_value_fn - with context.device_policy(context.DEVICE_PLACEMENT_SILENT): - # Don't record operations (e.g. other variable reads) during - # variable creation. - with tape.stop_recording(): - v = next_creator(*args, **kwargs) - - if i == 0: - actual_var_name = v.name.split(":")[0] - assert unique_var_name == actual_var_name, "%r vs %r" % ( - unique_var_name, actual_var_name) - assert not isinstance(v, values.DistributedVariable) - value_list.append(v) - return value_list - - # pylint: disable=protected-access - return mirrored_strategy._create_mirrored_variable( - self._container_strategy(), device_map, logical_device, - _real_mirrored_creator, *args, **kwargs) + return super(CollectiveAllReduceExtended, + self)._get_variable_creator_initial_value( + replica_id=replica_id, + device=device, + primary_var=primary_var, + **kwargs) def _make_input_context(self): if self._cluster_spec is None: diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index 0ef3382fe1a..e3b8fc0cada 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -998,14 +998,15 @@ class CollectiveAllReduce(CrossDeviceOps): return all_reduced devices = device_map.logical_to_actual_devices(logical_device) index = [] - for d in devices: - if d in all_reduced.devices: - index.append(all_reduced.get(d)) - else: - # TODO(josh11b): Once we add support for model parallelism, get the - # copy from the corresponding replica instead of the primary. - with ops.control_dependencies(all_reduced.values), ops.device(d): - index.append(array_ops.identity(all_reduced.primary)) + with ops.control_dependencies(all_reduced.values): + for d in devices: + with ops.device(d): + if d in all_reduced.devices: + index.append(array_ops.identity(all_reduced.get(d))) + else: + # TODO(josh11b): Once we add support for model parallelism, get the + # copy from the corresponding replica instead of the primary. 
+ index.append(array_ops.identity(all_reduced.primary)) return value_lib.Mirrored(device_map, index, logical_device) diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index e269ef4da3c..b19d4ff84ec 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -446,11 +446,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, use_strategy_object=False, local_mode=False): collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=10 * num_gpus + - CollectiveAllReduceTest.collective_key_base, - instance_key_start=num_gpus * 100 + - CollectiveAllReduceTest.collective_key_base, - instance_key_with_id_start=num_gpus * 10000 + + group_key_start=10 + CollectiveAllReduceTest.collective_key_base, + op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base, + variable_instance_key_start=10000 + CollectiveAllReduceTest.collective_key_base) if local_mode: if num_gpus: diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index 17091ec31f5..6058db356e2 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -253,31 +253,28 @@ class CollectiveKeys(object): def __init__(self, group_key_start=1, - instance_key_start=100, - instance_key_with_id_start=10000): + op_instance_key_start=100, + variable_instance_key_start=1000000): """Initializes the object. Args: group_key_start: the starting integer of group key. - instance_key_start: the starting integer of instance key. - instance_key_with_id_start: the starting integer of instance key that is - recorded with an id. + op_instance_key_start: the starting integer of instance key for ops. + variable_instance_key_start: the starting integer of instance key for + variables. """ self._group_key = group_key_start self._group_key_table = {} - # For instance keys with ids - self._instance_key_id_to_key_table = {} - self._instance_key_with_id_counter = instance_key_with_id_start - - # For instance keys without ids - self._instance_key_start = instance_key_start + assert op_instance_key_start != variable_instance_key_start + self._op_instance_key_start = op_instance_key_start + self._variable_instance_key = variable_instance_key_start def _get_thread_local_object(self): # We make instance key without key ids thread local so that it will work # with MirroredStrategy and distribute coordinator. - if not hasattr(_thread_local, 'instance_key'): - _thread_local.instance_key = self._instance_key_start + if not hasattr(_thread_local, 'op_instance_key'): + _thread_local.op_instance_key = self._op_instance_key_start return _thread_local def get_group_key(self, devices): @@ -304,25 +301,17 @@ class CollectiveKeys(object): self._group_key_table[key_id] = new_key return self._group_key_table[key_id] - def get_instance_key(self, key_id=None): - """Returns a new instance key for use in defining a collective op. + def get_op_instance_key(self): + """Returns a new instance key for use in defining a collective op.""" + v = self._get_thread_local_object().op_instance_key + self._get_thread_local_object().op_instance_key += 1 + return v - Args: - key_id: optional string. If set, key will be recorded and the same key - will be returned when the same key_id is provided. If not, an increasing - instance key will be returned. 
- """ - if key_id: - with _lock: - if key_id not in self._instance_key_id_to_key_table: - self._instance_key_with_id_counter += 1 - self._instance_key_id_to_key_table[key_id] = ( - self._instance_key_with_id_counter) - return self._instance_key_id_to_key_table[key_id] - else: - v = self._get_thread_local_object().instance_key - self._get_thread_local_object().instance_key += 1 - return v + def get_variable_instance_key(self): + """Returns a new instance key for use in creating a Variable.""" + v = self._variable_instance_key + self._variable_instance_key += 1 + return v def build_collective_reduce(input_tensors, @@ -354,7 +343,7 @@ def build_collective_reduce(input_tensors, devices = [t.device for t in input_tensors] num_devices = len(devices) group_key = collective_keys.get_group_key(devices) - instance_key = collective_keys.get_instance_key() + instance_key = collective_keys.get_op_instance_key() subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec def collective_all_reduce(): @@ -399,7 +388,7 @@ def build_collective_gather(input_tensors, num_workers, collective_keys): devices = [t.device for t in input_tensors] num_devices = len(devices) group_key = collective_keys.get_group_key(devices) - instance_key = collective_keys.get_instance_key() + instance_key = collective_keys.get_op_instance_key() def collective_all_gather(): """Call collective allgather.""" diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index f8edb8aea1d..2ead8f8c4ad 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -532,6 +532,30 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): # containing job names. self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce() + def _get_variable_creator_initial_value(self, + replica_id=0, + device=None, + primary_var=None, + **kwargs): + """Return the initial value for variables on a replica.""" + if replica_id == 0: + return kwargs["initial_value"] + else: + assert primary_var is not None + assert device is not None + assert kwargs is not None + + def initial_value_fn(): + if context.executing_eagerly() or ops.inside_function(): + init_value = primary_var.value() + return array_ops.identity(init_value) + else: + with ops.device(device): + init_value = primary_var.initial_value + return array_ops.identity(init_value) + + return initial_value_fn + def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" colocate_with = kwargs.pop("colocate_with", None) @@ -549,6 +573,11 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): value_list = [] for i, d in enumerate(devices): with ops.device(d): + kwargs["initial_value"] = self._get_variable_creator_initial_value( + replica_id=i, + device=d, + primary_var=value_list[0] if value_list else None, + **kwargs) if i > 0: # Give replicas meaningful distinct names: var0name = value_list[0].name.split(":")[0] @@ -556,16 +585,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): # ensure that we ignore the name scope and instead use the given # name as the absolute name of the variable. 
kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # Initialize replicas with the same value: - def initial_value_fn(device=d): - if context.executing_eagerly() or ops.inside_function(): - init_value = value_list[0].value() - return array_ops.identity(init_value) - else: - with ops.device(device): - init_value = value_list[0].initial_value - return array_ops.identity(init_value) - kwargs["initial_value"] = initial_value_fn with context.device_policy(context.DEVICE_PLACEMENT_SILENT): # Don't record operations (e.g. other variable reads) during # variable creation. @@ -749,8 +768,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): reduce_op, value, destinations=destinations) def _batch_reduce_to(self, reduce_op, value_destination_pairs): - return self._get_cross_device_ops().batch_reduce( - reduce_op, value_destination_pairs) + return self._get_cross_device_ops().batch_reduce(reduce_op, + value_destination_pairs) def _update(self, var, fn, args, kwargs, group): # TODO(josh11b): In eager mode, use one thread per device. diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 2b815dfdaf2..fdbce82ca44 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -48,12 +48,8 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core -from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell_impl -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import gradient_descent @@ -370,516 +366,6 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase): distribution.extended.call_for_each_replica(model_fn) -@combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - ], - mode=["graph", "eager"])) -class MirroredStrategyVariableCreationTest(test.TestCase): - - # TODO(priyag): Modify more tests to use this helper and check more - # properties. 
- def _test_mv_properties(self, var, name, strategy): - self.assertIsInstance(var, values.MirroredVariable) - self.assertEqual(name, var.name) - self.assertIs(strategy, var.distribute_strategy) - for d in var.devices: - self.assertEqual(d, var.get(d).device) - self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access - - def testVariableInFuncGraph(self, distribution): - def model_fn(): - v = variable_scope.variable(2.0, name="bar") - ds_context.get_replica_context().merge_call(lambda _: _) - return v - - with func_graph.FuncGraph("fg").as_default(), distribution.scope(): - v1 = variable_scope.variable(1.0, name="foo") - v2 = distribution.extended.call_for_each_replica(model_fn) - - self._test_mv_properties(v1, "foo:0", distribution) - self._test_mv_properties(v2, "bar:0", distribution) - - def testVariableWithTensorInitialValueInFunction(self, distribution): - if not context.executing_eagerly(): - self.skipTest("`tf.function` is an eager-only feature") - - v = [None] - def model_fn(): - if v[0] is None: - init_val = array_ops.zeros([]) - v[0] = variables.Variable(init_val) - ds_context.get_replica_context().merge_call(lambda _: _) - return v[0] - - @def_function.function(autograph=False) - def make_v1(): - return distribution.experimental_local_results( - distribution.extended.call_for_each_replica(model_fn)) - - self.assertAllEqual([0, 0], make_v1()) - - def testSingleVariable(self, distribution): - def model_fn(): - # This variable should be created only once across the threads because of - # special variable_creator functions used by - # `distribution.extended.call_for_each_replica`. - v = variable_scope.variable(1.0, name="foo") - ds_context.get_replica_context().merge_call(lambda _: _) - return v - - with distribution.scope(): - result = distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(result, "foo:0", distribution) - - def testUnnamedVariable(self, distribution): - def model_fn(): - v = variable_scope.variable(1.0) - ds_context.get_replica_context().merge_call(lambda _: _) - return v - - with distribution.scope(): - result = distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(result, "Variable:0", distribution) - - def testMultipleVariables(self, distribution): - def model_fn(): - vs = [] - for i in range(5): - vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - ds_context.get_replica_context().merge_call(lambda _: _) - return vs - - with distribution.scope(): - result = distribution.extended.call_for_each_replica(model_fn) - for i, v in enumerate(result): - self._test_mv_properties(v, "foo" + str(i) + ":0", distribution) - - def testMultipleVariablesWithSameCanonicalName(self, distribution): - def model_fn(): - vs = [] - vs.append(variable_scope.variable(1.0, name="foo/bar")) - vs.append(variable_scope.variable(1.0, name="foo_1/bar")) - vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) - vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - ds_context.get_replica_context().merge_call(lambda _: _) - return vs - - with distribution.scope(): - result = distribution.extended.call_for_each_replica(model_fn) - for v in result: - self.assertIsInstance(v, values.MirroredVariable) - self.assertEqual(4, len(result)) - self.assertEqual("foo/bar:0", result[0].name) - self.assertEqual("foo_1/bar:0", result[1].name) - self.assertEqual("foo_1/bar_1:0", result[2].name) - self.assertEqual("foo/bar_1:0", result[3].name) - - def testVariableWithSameCanonicalNameAcrossThreads(self, 
distribution): - def model_fn(): - replica_id = self.evaluate(_replica_id()) - v = variable_scope.variable(1.0, name="foo_" + str(replica_id)) - ds_context.get_replica_context().merge_call(lambda _: _) - return v - - with distribution.scope(): - result = distribution.extended.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - # The resulting mirrored variable will use the name from the first device. - self.assertEqual("foo_0:0", result.name) - - def testWithLayers(self, distribution): - def model_fn(features): - with variable_scope.variable_scope("common"): - layer1 = core.Dense(1) - layer1(features) - layer2 = core.Dense(1) - layer2(features) - # This will pause the current thread, and execute the other thread. - ds_context.get_replica_context().merge_call(lambda _: _) - layer3 = core.Dense(1) - layer3(features) - return [(layer1.kernel, layer1.bias), - (layer2.kernel, layer2.bias), - (layer3.kernel, layer3.bias)] - - iterator = distribution.make_input_fn_iterator( - lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) - self.evaluate(iterator.initialize()) - features = iterator.get_next() - - with distribution.scope(): - result = distribution.extended.call_for_each_replica( - model_fn, args=(features,)) - suffixes = ["", "_1", "_2"] - for (kernel, bias), suffix in zip(result, suffixes): - self.assertIsInstance(kernel, values.MirroredVariable) - self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name) - self.assertIsInstance(bias, values.MirroredVariable) - self.assertEqual("common/dense" + suffix + "/bias:0", bias.name) - - def testWithVariableAndVariableScope(self, distribution): - def model_fn(): - v0 = variable_scope.variable(1.0, name="var0", aggregation=None) - with variable_scope.variable_scope("common"): - v1 = variable_scope.variable(1.0, name="var1") - # This will pause the current thread, and execute the other thread. - ds_context.get_replica_context().merge_call(lambda _: _) - v2 = variable_scope.variable( - 1.0, - name="var2", - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - v3 = variable_scope.variable( - 1.0, - name="var3", - synchronization=variable_scope.VariableSynchronization.ON_WRITE, - aggregation=variable_scope.VariableAggregation.MEAN) - - return v0, v1, v2, v3 - - with distribution.scope(): - v = variable_scope.variable(1.0, name="var-main0") - self.assertEqual("var-main0:0", v.name) - - result = distribution.extended.call_for_each_replica(model_fn) - self.assertEqual(4, len(result)) - v0, v1, v2, v3 = result - self.assertIsInstance(v0, values.MirroredVariable) - self.assertEqual("var0:0", v0.name) - self.assertIsInstance(v1, values.MirroredVariable) - self.assertEqual("common/var1:0", v1.name) - self.assertIsInstance(v2, values.SyncOnReadVariable) - self.assertEqual("common/var2:0", v2.name) - self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) - self.assertIsInstance(v3, values.MirroredVariable) - self.assertEqual("common/var3:0", v3.name) - self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation) - - def testWithGetVariableAndVariableScope(self, distribution): - def model_fn(): - v0 = variable_scope.get_variable("var0", [1]) - with variable_scope.variable_scope("common"): - v1 = variable_scope.get_variable("var1", [1]) - # This will pause the current thread, and execute the other thread. 
- ds_context.get_replica_context().merge_call(lambda _: _) - v2 = variable_scope.get_variable( - "var2", [1], - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - v3 = variable_scope.get_variable( - "var3", [1], - synchronization=variable_scope.VariableSynchronization.ON_WRITE, - aggregation=variable_scope.VariableAggregation.MEAN) - - return v0, v1, v2, v3 - - with distribution.scope(): - with variable_scope.variable_scope("main"): - v = variable_scope.get_variable("var-main0", [1]) - self.assertEqual("main/var-main0:0", v.name) - - result = distribution.extended.call_for_each_replica(model_fn) - self.assertEqual(4, len(result)) - v0, v1, v2, v3 = result - self.assertIsInstance(v0, values.MirroredVariable) - self.assertEqual("main/var0:0", v0.name) - self.assertIsInstance(v1, values.MirroredVariable) - self.assertEqual("main/common/var1:0", v1.name) - self.assertIsInstance(v2, values.SyncOnReadVariable) - self.assertEqual("main/common/var2:0", v2.name) - self.assertEqual(variable_scope.VariableAggregation.SUM, - v2.aggregation) - self.assertIsInstance(v3, values.MirroredVariable) - self.assertEqual("main/common/var3:0", v3.name) - self.assertEqual(variable_scope.VariableAggregation.MEAN, - v3.aggregation) - - def testOnlyFirstReplicaUpdatesVariables(self, distribution): - def create_fn(): - aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA - v0 = variable_scope.variable( - 2.0, - name="on_read", - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=aggregation) - v1 = variable_scope.variable( - 3.0, - name="on_write", - synchronization=variable_scope.VariableSynchronization.ON_WRITE, - aggregation=aggregation) - return v0, v1 - - devices = ["/device:GPU:0", "/device:CPU:0"] - with distribution.scope(): - v0, v1 = distribution.extended.call_for_each_replica(create_fn) - self.evaluate(v0.initializer) - self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) - self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0))) - self.evaluate(v1.initializer) - self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) - self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1))) - - def replica_id_plus_one(): - return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) - - # Update using the assign_add member function. - def update_member_fn(): - update0 = v0.assign_add(5.0 * replica_id_plus_one()) - update1 = v1.assign_add(7.0 * replica_id_plus_one()) - return update0, update1 - - update0a, update1a = distribution.extended.call_for_each_replica( - update_member_fn) - - # Update "sync on read" variable. - self.evaluate(distribution.group(update0a)) - self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) - # Writes are not synchronized for "sync on read" variables, - # so device[1] can end up with a different value. - self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1]))) - # Always reads from device 0. - self.assertEqual(2.0 + 5.0, self.evaluate( - distribution.extended.read_var(v0))) - - # Update "sync on write" variable. - self.evaluate(distribution.group(update1a)) - self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) - # Writes are synchronized for v1, only the argument to assign_add on - # device[0] is used. 
- self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0, self.evaluate( - distribution.extended.read_var(v1))) - - # Update using state_ops.assign_add global function. - def update_state_ops_fn(): - update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one()) - update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one()) - return update0, update1 - - update0b, update1b = distribution.extended.call_for_each_replica( - update_state_ops_fn) - self.evaluate(distribution.group(update0b)) - - # Update "sync on read" variable. - self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) - self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate( - distribution.extended.read_var(v0))) - - # Update "sync on write" variable. - self.evaluate(distribution.group(update1b)) - self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) - self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate( - distribution.extended.read_var(v1))) - - def testNoneSynchronizationWithGetVariable(self, distribution): - with distribution.scope(): - with self.assertRaisesRegexp( - ValueError, "`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please change " - "the `synchronization` for variable: v"): - variable_scope.get_variable( - "v", [1], - synchronization=variable_scope.VariableSynchronization.NONE) - - def testNoneSynchronizationWithVariable(self, distribution): - with distribution.scope(): - with self.assertRaisesRegexp( - ValueError, "`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please change " - "the `synchronization` for variable: v"): - variable_scope.variable( - 1.0, - name="v", - synchronization=variable_scope.VariableSynchronization.NONE) - - def testInvalidSynchronizationWithVariable(self, distribution): - with distribution.scope(): - with self.assertRaisesRegexp( - ValueError, "Invalid variable synchronization mode: Invalid for " - "variable: v"): - variable_scope.variable(1.0, name="v", synchronization="Invalid") - - def testInvalidAggregationWithGetVariable(self, distribution): - with distribution.scope(): - with self.assertRaisesRegexp( - ValueError, "Invalid variable aggregation mode: invalid for " - "variable: v"): - variable_scope.get_variable( - "v", [1], - synchronization=variable_scope.VariableSynchronization.ON_WRITE, - aggregation="invalid") - - def testInvalidAggregationWithVariable(self, distribution): - with distribution.scope(): - with self.assertRaisesRegexp( - ValueError, "Invalid variable aggregation mode: invalid for " - "variable: v"): - variable_scope.variable( - 1.0, - name="v", - synchronization=variable_scope.VariableSynchronization.ON_WRITE, - aggregation="invalid") - - def testNonMatchingVariableCreation(self, distribution): - self.skipTest("b/123075960") - def model_fn(name): - v = variable_scope.variable(1.0, name=name) - ds_context.get_replica_context().merge_call(lambda _: _) - return v - - with distribution.scope(): - device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) - names = values.DistributedValues(device_map, ("foo", "bar")) - with self.assertRaises(RuntimeError): - _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) - - def testSyncOnReadVariable(self, distribution): - all_v_sum = {} - all_v_mean = {} - components_sum = {} - 
components_mean = {} - - def model_fn(): - replica_id = self.evaluate(_replica_id()) - v_sum = variable_scope.variable( - 1.0, - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - v_mean = variable_scope.variable( - 4.0, - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.MEAN) - self.assertIsInstance(v_sum, values.SyncOnReadVariable) - self.assertIsInstance(v_mean, values.SyncOnReadVariable) - updates = [v_sum.assign_add(2.0 + replica_id), - v_mean.assign(6.0 * replica_id)] - all_v_sum[replica_id] = v_sum - all_v_mean[replica_id] = v_mean - c_sum = v_sum.get() - c_mean = v_mean.get() - components_sum[replica_id] = c_sum - components_mean[replica_id] = c_mean - self.assertIsNot(v_sum, c_sum) - self.assertIsNot(v_mean, c_mean) - return updates, v_sum, v_mean, c_sum, c_mean - - with distribution.scope(): - # Create "sum" and "mean" versions of SyncOnReadVariables. - ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( - distribution.extended.call_for_each_replica(model_fn)) - # Should see the same wrapping instance in all replicas. - self.assertIs(all_v_sum[0], ret_v_sum) - self.assertIs(all_v_mean[0], ret_v_mean) - self.assertIs(all_v_sum[0], all_v_sum[1]) - self.assertIs(all_v_mean[0], all_v_mean[1]) - - # Regroup should recover the same wrapper. - self.assertIs(ret_v_sum, regrouped_sum) - self.assertIs(ret_v_mean, regrouped_mean) - self.assertIsNot(components_sum[0], components_sum[1]) - self.assertIsNot(components_mean[0], components_mean[1]) - - # Apply updates - self.evaluate(variables.global_variables_initializer()) - self.evaluate([y for x in ret_ops # pylint: disable=g-complex-comprehension - for y in distribution.experimental_local_results(x)]) - expected_sum = 0.0 - expected_mean = 0.0 - for i, d in enumerate(distribution.extended.worker_devices): - # Should see different values on different devices. - v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) - v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) - expected = i + 3.0 - self.assertEqual(expected, v_sum_value) - expected_sum += expected - expected = i * 6.0 - self.assertEqual(expected, v_mean_value) - expected_mean += expected - expected_mean /= len(distribution.extended.worker_devices) - - # Without get(device), should return the value you get by - # applying the reduction across all replicas (whether you use - # read_var(), get(), or nothing). - self.assertEqual(expected_sum, self.evaluate( - distribution.extended.read_var(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate( - distribution.extended.read_var(ret_v_mean))) - self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) - self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) - self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) - self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) - - # TODO(priyag): Update this test to work in eager mode as well. - def testDynamicRnnVariables(self, distribution): - def model_fn(): - inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) - cell_fw = rnn_cell_impl.LSTMCell(300) - cell_bw = rnn_cell_impl.LSTMCell(300) - (outputs, _) = rnn.bidirectional_dynamic_rnn( - cell_fw, - cell_bw, - inputs, - dtype=dtypes.float32) - return outputs - - with context.graph_mode(), distribution.scope(): - result = distribution.extended.call_for_each_replica(model_fn) - # Two variables are created by the RNN layer. 
- self.assertEqual(2, len(result)) - for v in result: - self.assertIsInstance(v, values.DistributedValues) - _, v1 = distribution.experimental_local_results(v) - self.assertStartsWith(v1._op.name, "replica_1/") - - def testSyncOnReadVariableUpdate(self, distribution): - def model_fn(): - v_sum = variable_scope.variable( - 1.0, - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - self.assertIsInstance(v_sum, values.SyncOnReadVariable) - return v_sum - - def update(var, value): - return var.assign(value) - - with distribution.scope(): - ret_v_sum = distribution.extended.call_for_each_replica(model_fn) - - # Initialize variables. - self.evaluate(variables.global_variables_initializer()) - # Assert that the aggregated value of the sync on read var is the sum - # of the individual values before running the update ops. - self.assertEqual(1.0, self.evaluate(ret_v_sum.get( - distribution.extended.worker_devices[0]).read_value())) - self.assertEqual(2.0, self.evaluate(ret_v_sum)) - - # Apply updates. - update_ops = distribution.extended.update( - ret_v_sum, update, args=(5.0,), group=False) - self.evaluate(update_ops) - # Assert that the aggregated value of the sync on read vars is the sum - # of the individual values after running the update ops. - self.assertEqual(5.0, self.evaluate(ret_v_sum.get( - distribution.extended.worker_devices[0]).read_value())) - self.assertEqual(10.0, self.evaluate(ret_v_sum)) - - def testVarDistributeStrategy(self, distribution): - with distribution.scope(): - mirrored = variable_scope.variable(1.0) - sync_on_read = variable_scope.variable( - 1.0, - synchronization=variable_scope.VariableSynchronization.ON_READ) - self.assertIs(distribution, mirrored.distribute_strategy) - self.assertIs(distribution, sync_on_read.distribute_strategy) - - @combinations.generate( combinations.combine( distribution=[ diff --git a/tensorflow/python/distribute/mirrored_variable_test.py b/tensorflow/python/distribute/mirrored_variable_test.py new file mode 100644 index 00000000000..1bf995b881a --- /dev/null +++ b/tensorflow/python/distribute/mirrored_variable_test.py @@ -0,0 +1,606 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test MirroredVariable in MirroredStrategy and MultiWorkerMirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import collective_all_reduce_strategy +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import func_graph +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + + +def _mimic_two_cpus(): + cpus = config.list_physical_devices("CPU") + + config.set_virtual_device_configuration(cpus[0], [ + context.VirtualDeviceConfiguration(), + context.VirtualDeviceConfiguration(), + ]) + + +@combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.NamedDistribution( + "Collective2CPUs", + # pylint: disable=g-long-lambda + lambda: collective_all_reduce_strategy. + CollectiveAllReduceStrategy._from_local_devices(( + "/device:CPU:0", "/device:CPU:1")), + required_gpus=0) + ], + mode=["graph", "eager"])) +class MirroredVariableCreationTest(test.TestCase): + """Base class that tests mirrored variable creator. + + Currently it assumes all strategy objects have two replicas. + """ + + @classmethod + def setUpClass(cls): + _mimic_two_cpus() + + # TODO(priyag): Modify more tests to use this helper and check more + # properties. 
+ def _test_mv_properties(self, var, name, strategy): + self.assertIsInstance(var, values.MirroredVariable) + self.assertEqual(name, var.name) + self.assertIs(strategy, var.distribute_strategy) + for d in var.devices: + self.assertEqual(d, var.get(d).device) + self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access + + def testVariableInFuncGraph(self, distribution): + + def model_fn(): + v = variable_scope.variable(2.0, name="bar") + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + v1 = variable_scope.variable(1.0, name="foo") + v2 = distribution.extended.call_for_each_replica(model_fn) + + self._test_mv_properties(v1, "foo:0", distribution) + self._test_mv_properties(v2, "bar:0", distribution) + + def testVariableWithTensorInitialValueInFunction(self, distribution): + if not context.executing_eagerly(): + self.skipTest("`tf.function` is an eager-only feature") + + v = [None] + + def model_fn(): + if v[0] is None: + init_val = array_ops.zeros([]) + v[0] = variables.Variable(init_val) + ds_context.get_replica_context().merge_call(lambda _: _) + return v[0] + + @def_function.function(autograph=False) + def make_v1(): + return distribution.experimental_local_results( + distribution.extended.call_for_each_replica(model_fn)) + + self.assertAllEqual([0, 0], make_v1()) + + def testSingleVariable(self, distribution): + + def model_fn(): + # This variable should be created only once across the threads because of + # special variable_creator functions used by + # `distribution.extended.call_for_each_replica`. + v = variable_scope.variable(1.0, name="foo") + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "foo:0", distribution) + + def testUnnamedVariable(self, distribution): + + def model_fn(): + v = variable_scope.variable(1.0) + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "Variable:0", distribution) + + def testMultipleVariables(self, distribution): + + def model_fn(): + vs = [] + for i in range(5): + vs.append(variable_scope.variable(1.0, name="foo" + str(i))) + ds_context.get_replica_context().merge_call(lambda _: _) + return vs + + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + for i, v in enumerate(result): + self._test_mv_properties(v, "foo" + str(i) + ":0", distribution) + + def testMultipleVariablesWithSameCanonicalName(self, distribution): + + def model_fn(): + vs = [] + vs.append(variable_scope.variable(1.0, name="foo/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) + vs.append(variable_scope.variable(1.0, name="foo/bar_1")) + ds_context.get_replica_context().merge_call(lambda _: _) + return vs + + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + for v in result: + self.assertIsInstance(v, values.MirroredVariable) + self.assertEqual(4, len(result)) + self.assertEqual("foo/bar:0", result[0].name) + self.assertEqual("foo_1/bar:0", result[1].name) + self.assertEqual("foo_1/bar_1:0", result[2].name) + self.assertEqual("foo/bar_1:0", result[3].name) + + def 
testVariableWithSameCanonicalNameAcrossThreads(self, distribution): + + def model_fn(): + replica_id = self.evaluate(_replica_id()) + v = variable_scope.variable(1.0, name="foo_" + str(replica_id)) + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertIsInstance(result, values.MirroredVariable) + # The resulting mirrored variable will use the name from the first device. + self.assertEqual("foo_0:0", result.name) + + def testWithLayers(self, distribution): + + def model_fn(features): + with variable_scope.variable_scope("common"): + layer1 = core.Dense(1) + layer1(features) + layer2 = core.Dense(1) + layer2(features) + # This will pause the current thread, and execute the other thread. + ds_context.get_replica_context().merge_call(lambda _: _) + layer3 = core.Dense(1) + layer3(features) + return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias), + (layer3.kernel, layer3.bias)] + + iterator = distribution.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) + self.evaluate(iterator.initialize()) + features = iterator.get_next() + + with distribution.scope(): + result = distribution.extended.call_for_each_replica( + model_fn, args=(features,)) + suffixes = ["", "_1", "_2"] + for (kernel, bias), suffix in zip(result, suffixes): + self.assertIsInstance(kernel, values.MirroredVariable) + self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertIsInstance(bias, values.MirroredVariable) + self.assertEqual("common/dense" + suffix + "/bias:0", bias.name) + + def testWithVariableAndVariableScope(self, distribution): + + def model_fn(): + v0 = variable_scope.variable(1.0, name="var0", aggregation=None) + with variable_scope.variable_scope("common"): + v1 = variable_scope.variable(1.0, name="var1") + # This will pause the current thread, and execute the other thread. + ds_context.get_replica_context().merge_call(lambda _: _) + v2 = variable_scope.variable( + 1.0, + name="var2", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.variable( + 1.0, + name="var3", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + + return v0, v1, v2, v3 + + with distribution.scope(): + v = variable_scope.variable(1.0, name="var-main0") + self.assertEqual("var-main0:0", v.name) + + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) + v0, v1, v2, v3 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEqual("var0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEqual("common/var1:0", v1.name) + self.assertIsInstance(v2, values.SyncOnReadVariable) + self.assertEqual("common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEqual("common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation) + + def testWithGetVariableAndVariableScope(self, distribution): + + def model_fn(): + v0 = variable_scope.get_variable("var0", [1]) + with variable_scope.variable_scope("common"): + v1 = variable_scope.get_variable("var1", [1]) + # This will pause the current thread, and execute the other thread. 
+ ds_context.get_replica_context().merge_call(lambda _: _) + v2 = variable_scope.get_variable( + "var2", [1], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.get_variable( + "var3", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + + return v0, v1, v2, v3 + + with distribution.scope(): + with variable_scope.variable_scope("main"): + v = variable_scope.get_variable("var-main0", [1]) + self.assertEqual("main/var-main0:0", v.name) + + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) + v0, v1, v2, v3 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEqual("main/var0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEqual("main/common/var1:0", v1.name) + self.assertIsInstance(v2, values.SyncOnReadVariable) + self.assertEqual("main/common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEqual("main/common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + + def testOnlyFirstReplicaUpdatesVariables(self, distribution): + + def create_fn(): + aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA + v0 = variable_scope.variable( + 2.0, + name="on_read", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=aggregation) + v1 = variable_scope.variable( + 3.0, + name="on_write", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=aggregation) + return v0, v1 + + devices = distribution.extended.worker_devices + with distribution.scope(): + v0, v1 = distribution.extended.call_for_each_replica(create_fn) + self.evaluate(v0.initializer) + self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) + self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) + self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0))) + self.evaluate(v1.initializer) + self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) + self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1))) + + def replica_id_plus_one(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) + + # Update using the assign_add member function. + def update_member_fn(): + update0 = v0.assign_add(5.0 * replica_id_plus_one()) + update1 = v1.assign_add(7.0 * replica_id_plus_one()) + return update0, update1 + + update0a, update1a = distribution.extended.call_for_each_replica( + update_member_fn) + + # Update "sync on read" variable. + self.evaluate(distribution.group(update0a)) + self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) + # Writes are not synchronized for "sync on read" variables, + # so device[1] can end up with a different value. + self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.get(devices[1]))) + # Always reads from device 0. + self.assertEqual(2.0 + 5.0, + self.evaluate(distribution.extended.read_var(v0))) + + # Update "sync on write" variable. + self.evaluate(distribution.group(update1a)) + self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) + # Writes are synchronized for v1, only the argument to assign_add on + # device[0] is used. 
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0 + 7.0, + self.evaluate(distribution.extended.read_var(v1))) + + # Update using state_ops.assign_add global function. + def update_state_ops_fn(): + update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one()) + update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one()) + return update0, update1 + + update0b, update1b = distribution.extended.call_for_each_replica( + update_state_ops_fn) + self.evaluate(distribution.group(update0b)) + + # Update "sync on read" variable. + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) + self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0, + self.evaluate(v0.get(devices[1]))) + self.assertEqual(2.0 + 5.0 + 11.0, + self.evaluate(distribution.extended.read_var(v0))) + + # Update "sync on write" variable. + self.evaluate(distribution.group(update1b)) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0 + 7.0 + 13.0, + self.evaluate(distribution.extended.read_var(v1))) + + def testNoneSynchronizationWithGetVariable(self, distribution): + with distribution.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.NONE) + + def testNoneSynchronizationWithVariable(self, distribution): + with distribution.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. 
Please change " + "the `synchronization` for variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.NONE) + + def testInvalidSynchronizationWithVariable(self, distribution): + with distribution.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable synchronization mode: Invalid for " + "variable: v"): + variable_scope.variable(1.0, name="v", synchronization="Invalid") + + def testInvalidAggregationWithGetVariable(self, distribution): + with distribution.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") + + def testInvalidAggregationWithVariable(self, distribution): + with distribution.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") + + def testNonMatchingVariableCreation(self, distribution): + self.skipTest("b/123075960") + + def model_fn(name): + v = variable_scope.variable(1.0, name=name) + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with distribution.scope(): + device_map = values.ReplicaDeviceMap(distribution.extended.worker_devices) + names = values.DistributedValues(device_map, ("foo", "bar")) + with self.assertRaises(RuntimeError): + _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) + + def testSyncOnReadVariable(self, distribution): + all_v_sum = {} + all_v_mean = {} + components_sum = {} + components_mean = {} + + def model_fn(): + replica_id = self.evaluate(_replica_id()) + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v_mean = variable_scope.variable( + 4.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) + self.assertIsInstance(v_sum, values.SyncOnReadVariable) + self.assertIsInstance(v_mean, values.SyncOnReadVariable) + updates = [ + v_sum.assign_add(2.0 + replica_id), + v_mean.assign(6.0 * replica_id) + ] + all_v_sum[replica_id] = v_sum + all_v_mean[replica_id] = v_mean + c_sum = v_sum.get() + c_mean = v_mean.get() + components_sum[replica_id] = c_sum + components_mean[replica_id] = c_mean + self.assertIsNot(v_sum, c_sum) + self.assertIsNot(v_mean, c_mean) + return updates, v_sum, v_mean, c_sum, c_mean + + with distribution.scope(): + # Create "sum" and "mean" versions of SyncOnReadVariables. + ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( + distribution.extended.call_for_each_replica(model_fn)) + # Should see the same wrapping instance in all replicas. + self.assertIs(all_v_sum[0], ret_v_sum) + self.assertIs(all_v_mean[0], ret_v_mean) + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Regroup should recover the same wrapper. 
+ self.assertIs(ret_v_sum, regrouped_sum) + self.assertIs(ret_v_mean, regrouped_mean) + self.assertIsNot(components_sum[0], components_sum[1]) + self.assertIsNot(components_mean[0], components_mean[1]) + + # Apply updates + self.evaluate(variables.global_variables_initializer()) + self.evaluate([ + y for x in ret_ops # pylint: disable=g-complex-comprehension + for y in distribution.experimental_local_results(x) + ]) + expected_sum = 0.0 + expected_mean = 0.0 + for i, d in enumerate(distribution.extended.worker_devices): + # Should see different values on different devices. + v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) + v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected + expected_mean /= len(distribution.extended.worker_devices) + + # Without get(device), should return the value you get by + # applying the reduction across all replicas (whether you use + # read_var(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate( + distribution.extended.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate( + distribution.extended.read_var(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + + # TODO(priyag): Update this test to work in eager mode as well. + def testDynamicRnnVariables(self, distribution): + + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, cell_bw, inputs, dtype=dtypes.float32) + return outputs + + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + # Two variables are created by the RNN layer. + self.assertEqual(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = distribution.experimental_local_results(v) + self.assertStartsWith(v1._op.name, "replica_1/") + + def testSyncOnReadVariableUpdate(self, distribution): + + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertIsInstance(v_sum, values.SyncOnReadVariable) + return v_sum + + def update(var, value): + return var.assign(value) + + with distribution.scope(): + ret_v_sum = distribution.extended.call_for_each_replica(model_fn) + + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the sync on read var is the sum + # of the individual values before running the update ops. + self.assertEqual( + 1.0, + self.evaluate( + ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + update_ops = distribution.extended.update( + ret_v_sum, update, args=(5.0,), group=False) + self.evaluate(update_ops) + # Assert that the aggregated value of the sync on read vars is the sum + # of the individual values after running the update ops. 
+ self.assertEqual( + 5.0, + self.evaluate( + ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(10.0, self.evaluate(ret_v_sum)) + + def testVarDistributeStrategy(self, distribution): + with distribution.scope(): + mirrored = variable_scope.variable(1.0) + sync_on_read = variable_scope.variable( + 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ) + self.assertIs(distribution, mirrored.distribute_strategy) + self.assertIs(distribution, sync_on_read.distribute_strategy) + + +if __name__ == "__main__": + test.main()
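
For reference, below is a minimal, self-contained sketch (not part of the patch; the class name CollectiveKeysSketch and the module-level assertions are made up for illustration, and group-key handling is omitted) of the key-allocation scheme the patched cross_device_utils.CollectiveKeys uses: op instance keys come from a thread-local counter seeded with op_instance_key_start, while variable instance keys come from a separate process-wide counter seeded with variable_instance_key_start, so the two kinds of keys are drawn from disjoint ranges.

import threading

_thread_local = threading.local()


class CollectiveKeysSketch(object):
  """Simplified stand-in for the patched cross_device_utils.CollectiveKeys."""

  def __init__(self,
               op_instance_key_start=100,
               variable_instance_key_start=1000000):
    # The patch asserts the two starting values differ so op keys and
    # variable keys stay in separate ranges.
    assert op_instance_key_start != variable_instance_key_start
    self._op_instance_key_start = op_instance_key_start
    self._variable_instance_key = variable_instance_key_start

  def get_op_instance_key(self):
    # Thread-local counter: each replica thread gets its own sequence,
    # mirroring get_op_instance_key() in the patch.
    if not hasattr(_thread_local, "op_instance_key"):
      _thread_local.op_instance_key = self._op_instance_key_start
    key = _thread_local.op_instance_key
    _thread_local.op_instance_key += 1
    return key

  def get_variable_instance_key(self):
    # Process-wide counter: one key per created variable.
    key = self._variable_instance_key
    self._variable_instance_key += 1
    return key


keys = CollectiveKeysSketch()
assert keys.get_op_instance_key() == 100            # op keys: 100, 101, ...
assert keys.get_op_instance_key() == 101
assert keys.get_variable_instance_key() == 1000000  # variable keys: 1000000, ...

In the patched code, get_op_instance_key() feeds build_collective_reduce and build_collective_gather, while get_variable_instance_key() supplies the instance key for the broadcast_send/broadcast_recv pair in CollectiveAllReduceExtended._get_variable_creator_initial_value, replacing the old per-variable-name key_id lookup in get_instance_key.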