Use an incremental counter for collective keys of variables;

Make MultiWorkerMirroredStrategy reuse more code from MirroredStrategy in the variable creator;
Move the mirrored variable creator test to a new file and add coverage for MultiWorkerMirroredStrategy;
Fix a bug in cross_device_ops where, in graph mode, the outputs of allreduce ops were sometimes not all run.

PiperOrigin-RevId: 251974203
Yuefeng Zhou 2019-06-06 19:08:23 -07:00 committed by TensorFlower Gardener
parent f1ffa0225a
commit f5003f6315
9 changed files with 760 additions and 674 deletions
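The first bullet of the commit message replaces name-keyed variable instance keys with a plain incrementing counter. Below is a minimal, standalone sketch of that counter pattern; it is illustrative only (the class name and starting values are made up) and is not TensorFlow's CollectiveKeys implementation, whose actual diff appears further down.

# Illustrative toy only: a simplified stand-in for the counter-based key
# bookkeeping described in the commit message, not the real CollectiveKeys.
class _ToyCollectiveKeys(object):
  """Hands out op and variable instance keys from two disjoint counters."""

  def __init__(self, op_instance_key_start=100,
               variable_instance_key_start=1000000):
    # Keep the ranges disjoint so op keys and variable keys never collide.
    assert op_instance_key_start != variable_instance_key_start
    self._op_instance_key = op_instance_key_start
    self._variable_instance_key = variable_instance_key_start

  def get_op_instance_key(self):
    key = self._op_instance_key
    self._op_instance_key += 1
    return key

  def get_variable_instance_key(self):
    # Incremental counter: each variable gets a fresh key instead of being
    # looked up by its (possibly non-unique) name as before.
    key = self._variable_instance_key
    self._variable_instance_key += 1
    return key


if __name__ == "__main__":
  keys = _ToyCollectiveKeys()
  print(keys.get_variable_instance_key())  # 1000000
  print(keys.get_variable_instance_key())  # 1000001
  print(keys.get_op_instance_key())        # 100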


@@ -115,7 +115,7 @@ class CollectiveAllReduceStrategyTestBase(
def setUp(self):
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, tucker): enable it to reuse collective keys in different
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different
# tests.
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
super(CollectiveAllReduceStrategyTestBase, self).setUp()
@@ -133,11 +133,11 @@ class CollectiveAllReduceStrategyTestBase(
use_core_strategy=use_core_strategy)
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
group_key_start=10 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_start=num_gpus * 100 +
op_instance_key_start=100 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
variable_instance_key_start=10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = (collective_keys)
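The hunk above switches the test base to the renamed constructor arguments and to fixed offsets rather than num_gpus-scaled ones. A hedged sketch of how a test might build per-test disjoint key ranges, assuming the post-commit CollectiveKeys signature shown in this hunk:

from tensorflow.python.distribute import cross_device_utils

# Each test bumps this base (e.g. by 100000) so collective keys are not
# reused across tests, matching the setUp logic above.
collective_key_base = 100000

collective_keys = cross_device_utils.CollectiveKeys(
    group_key_start=10 + collective_key_base,
    op_instance_key_start=100 + collective_key_base,
    variable_instance_key_start=10000 + collective_key_base)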


@@ -966,6 +966,32 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "mirrored_variable_test",
srcs = ["mirrored_variable_test.py"],
additional_deps = [
":collective_all_reduce_strategy",
":combinations",
":strategy_combinations",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:layers",
"//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
tags = [
"guitar",
"multi_and_single_gpu",
],
xla_enable_strict_auto_jit = True,
)
distribute_py_test(
name = "metrics_v1_test",
srcs = ["metrics_v1_test.py"],


@@ -36,7 +36,6 @@ from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
@@ -79,6 +78,13 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
self,
communication=communication))
@classmethod
def _from_local_devices(cls, devices):
"""A convenience method to create an obejct with a list of devices."""
obj = cls()
obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access
return obj
@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
@@ -117,7 +123,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
self._initialize_local(cluster_resolver)
def _initialize_local(self, cluster_resolver):
def _initialize_local(self, cluster_resolver, devices=None):
"""Initializes the object for local training."""
self._is_chief = True
self._num_workers = 1
@@ -140,6 +146,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
if devices:
local_devices = devices
else:
if num_gpus:
local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
else:
@@ -272,51 +281,30 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
task_id, self._num_workers, local_devices,
self._communication)
def _create_variable(self, next_creator, *args, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
device_map = self._device_map
logical_device = 0 # TODO(josh11b): Get logical device from scope here.
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
else:
device_map = colocate_with.device_map
logical_device = colocate_with.logical_device
def _get_variable_creator_initial_value(self,
replica_id=0,
device=None,
primary_var=None,
**kwargs):
if replica_id == 0: # First replica on each worker.
assert device is not None
assert primary_var is None
def _real_mirrored_creator(devices, *args, **kwargs):
"""Creates one MirroredVariable on the current worker."""
unique_var_name = ops.get_default_graph().unique_name(
kwargs["name"], mark_as_used=False).rstrip("/")
# pylint: disable=protected-access
collective_instance_key = self._collective_keys.get_instance_key(
key_id=unique_var_name)
# Only the first device participles in the broadcast of initial values.
group_key = self._collective_keys.get_group_key([devices[0]])
def initial_value_fn(): # pylint: disable=g-missing-docstring
# Only the first device participates in the broadcast of initial values.
group_key = self._collective_keys.get_group_key([device])
group_size = self._num_workers
if "initial_value" not in kwargs:
raise ValueError("Initial value must be specified.")
collective_instance_key = (
self._collective_keys.get_variable_instance_key())
with ops.device(device):
initial_value = kwargs["initial_value"]
if callable(initial_value):
initial_value_fn = initial_value
else:
initial_value_fn = lambda: initial_value
value_list = []
for i, d in enumerate(devices):
with ops.init_scope(), ops.device(d):
if i == 0:
# The initial value fn makes sure variables all initialized to
# same values. The first device of the chief worker will send their
# variable values to other workers.
def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring
with ops.device(device):
initial_value = initial_value_fn()
initial_value = initial_value()
assert not callable(initial_value)
initial_value = ops.convert_to_tensor(
initial_value, dtype=kwargs.get("dtype", None))
assert index == 0, index
if self._num_workers > 1:
if self._is_chief:
bcast_send = collective_ops.broadcast_send(
@@ -325,47 +313,20 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
with ops.control_dependencies([bcast_send]):
return array_ops.identity(initial_value)
else:
return collective_ops.broadcast_recv(
initial_value.shape, initial_value.dtype, group_size,
group_key, collective_instance_key)
return collective_ops.broadcast_recv(initial_value.shape,
initial_value.dtype,
group_size, group_key,
collective_instance_key)
return initial_value
return initial_value_fn
else:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
# We append a / to variable names created on replicas with id > 0 to
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Variables on non-first replica get initial values from the
# variables created on the first device of each worker.
def _overridden_initial_value_fn(device=d, index=i):
assert index > 0
with ops.device(device):
if context.executing_eagerly():
return array_ops.identity(value_list[0].value())
else:
return array_ops.identity(value_list[0].initial_value)
kwargs["initial_value"] = _overridden_initial_value_fn
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
# Don't record operations (e.g. other variable reads) during
# variable creation.
with tape.stop_recording():
v = next_creator(*args, **kwargs)
if i == 0:
actual_var_name = v.name.split(":")[0]
assert unique_var_name == actual_var_name, "%r vs %r" % (
unique_var_name, actual_var_name)
assert not isinstance(v, values.DistributedVariable)
value_list.append(v)
return value_list
# pylint: disable=protected-access
return mirrored_strategy._create_mirrored_variable(
self._container_strategy(), device_map, logical_device,
_real_mirrored_creator, *args, **kwargs)
return super(CollectiveAllReduceExtended,
self)._get_variable_creator_initial_value(
replica_id=replica_id,
device=device,
primary_var=primary_var,
**kwargs)
def _make_input_context(self):
if self._cluster_spec is None:
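The hunk above removes MultiWorkerMirroredStrategy's bespoke variable creator and keeps only the initial-value override: on each worker, replica 0 joins the cross-worker broadcast (the chief sends, the other workers receive), while replicas with id > 0 fall through to MirroredExtended and copy the primary variable. A plain-Python sketch of that decision, with hypothetical callables standing in for the real broadcast and copy ops:

def choose_initial_value(replica_id, is_chief, num_workers,
                         local_value_fn, recv_from_chief_fn, copy_primary_fn):
  """Pick how a replica obtains its initial value (illustrative only)."""
  if replica_id == 0:
    # First replica on each worker participates in the cross-worker broadcast.
    if num_workers > 1 and not is_chief:
      return recv_from_chief_fn  # broadcast_recv in the real code
    return local_value_fn        # the chief also broadcast_sends this value
  # Replicas with id > 0 defer to MirroredExtended and copy the primary.
  return copy_primary_fn


# Example: a non-chief worker with two local replicas.
picked = [
    choose_initial_value(r, is_chief=False, num_workers=2,
                         local_value_fn=lambda: 0.0,
                         recv_from_chief_fn=lambda: "recv",
                         copy_primary_fn=lambda: "copy")()
    for r in (0, 1)
]
print(picked)  # ['recv', 'copy']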


@@ -998,13 +998,14 @@ class CollectiveAllReduce(CrossDeviceOps):
return all_reduced
devices = device_map.logical_to_actual_devices(logical_device)
index = []
with ops.control_dependencies(all_reduced.values):
for d in devices:
with ops.device(d):
if d in all_reduced.devices:
index.append(all_reduced.get(d))
index.append(array_ops.identity(all_reduced.get(d)))
else:
# TODO(josh11b): Once we add support for model parallelism, get the
# copy from the corresponding replica instead of the primary.
with ops.control_dependencies(all_reduced.values), ops.device(d):
index.append(array_ops.identity(all_reduced.primary))
return value_lib.Mirrored(device_map, index, logical_device)
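The fix above wraps every per-device output in an identity placed under a control dependency on all all-reduced values, so that in graph mode fetching any one device's result forces the whole group of collective ops to run. A minimal sketch of that pattern using the TF 1.x compatibility API, with plain constants standing in for allreduce outputs:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-ins for the per-device all-reduced tensors in the hunk above.
outputs = [tf.constant(1.0), tf.constant(2.0)]

with tf.control_dependencies(outputs):
  # Each wrapped value now depends on every output, so fetching one of them
  # runs them all.
  wrapped = [tf.identity(t) for t in outputs]

with tf.Session() as sess:
  print(sess.run(wrapped[0]))  # 1.0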


@@ -446,11 +446,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
use_strategy_object=False,
local_mode=False):
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
CollectiveAllReduceTest.collective_key_base,
instance_key_start=num_gpus * 100 +
CollectiveAllReduceTest.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceTest.collective_key_base)
if local_mode:
if num_gpus:


@@ -253,31 +253,28 @@ class CollectiveKeys(object):
def __init__(self,
group_key_start=1,
instance_key_start=100,
instance_key_with_id_start=10000):
op_instance_key_start=100,
variable_instance_key_start=1000000):
"""Initializes the object.
Args:
group_key_start: the starting integer of group key.
instance_key_start: the starting integer of instance key.
instance_key_with_id_start: the starting integer of instance key that is
recorded with an id.
op_instance_key_start: the starting integer of instance key for ops.
variable_instance_key_start: the starting integer of instance key for
variables.
"""
self._group_key = group_key_start
self._group_key_table = {}
# For instance keys with ids
self._instance_key_id_to_key_table = {}
self._instance_key_with_id_counter = instance_key_with_id_start
# For instance keys without ids
self._instance_key_start = instance_key_start
assert op_instance_key_start != variable_instance_key_start
self._op_instance_key_start = op_instance_key_start
self._variable_instance_key = variable_instance_key_start
def _get_thread_local_object(self):
# We make instance key without key ids thread local so that it will work
# with MirroredStrategy and distribute coordinator.
if not hasattr(_thread_local, 'instance_key'):
_thread_local.instance_key = self._instance_key_start
if not hasattr(_thread_local, 'op_instance_key'):
_thread_local.op_instance_key = self._op_instance_key_start
return _thread_local
def get_group_key(self, devices):
@@ -304,24 +301,16 @@ class CollectiveKeys(object):
self._group_key_table[key_id] = new_key
return self._group_key_table[key_id]
def get_instance_key(self, key_id=None):
"""Returns a new instance key for use in defining a collective op.
def get_op_instance_key(self):
"""Returns a new instance key for use in defining a collective op."""
v = self._get_thread_local_object().op_instance_key
self._get_thread_local_object().op_instance_key += 1
return v
Args:
key_id: optional string. If set, key will be recorded and the same key
will be returned when the same key_id is provided. If not, an increasing
instance key will be returned.
"""
if key_id:
with _lock:
if key_id not in self._instance_key_id_to_key_table:
self._instance_key_with_id_counter += 1
self._instance_key_id_to_key_table[key_id] = (
self._instance_key_with_id_counter)
return self._instance_key_id_to_key_table[key_id]
else:
v = self._get_thread_local_object().instance_key
self._get_thread_local_object().instance_key += 1
def get_variable_instance_key(self):
"""Returns a new instance key for use in creating a Variable."""
v = self._variable_instance_key
self._variable_instance_key += 1
return v
@@ -354,7 +343,7 @@ def build_collective_reduce(input_tensors,
devices = [t.device for t in input_tensors]
num_devices = len(devices)
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_instance_key()
instance_key = collective_keys.get_op_instance_key()
subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
def collective_all_reduce():
@@ -399,7 +388,7 @@ def build_collective_gather(input_tensors, num_workers, collective_keys):
devices = [t.device for t in input_tensors]
num_devices = len(devices)
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_instance_key()
instance_key = collective_keys.get_op_instance_key()
def collective_all_gather():
"""Call collective allgather."""


@@ -532,6 +532,30 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# containing job names.
self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
def _get_variable_creator_initial_value(self,
replica_id=0,
device=None,
primary_var=None,
**kwargs):
"""Return the initial value for variables on a replica."""
if replica_id == 0:
return kwargs["initial_value"]
else:
assert primary_var is not None
assert device is not None
assert kwargs is not None
def initial_value_fn():
if context.executing_eagerly() or ops.inside_function():
init_value = primary_var.value()
return array_ops.identity(init_value)
else:
with ops.device(device):
init_value = primary_var.initial_value
return array_ops.identity(init_value)
return initial_value_fn
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
colocate_with = kwargs.pop("colocate_with", None)
@@ -549,6 +573,11 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
value_list = []
for i, d in enumerate(devices):
with ops.device(d):
kwargs["initial_value"] = self._get_variable_creator_initial_value(
replica_id=i,
device=d,
primary_var=value_list[0] if value_list else None,
**kwargs)
if i > 0:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
@@ -556,16 +585,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
def initial_value_fn(device=d):
if context.executing_eagerly() or ops.inside_function():
init_value = value_list[0].value()
return array_ops.identity(init_value)
else:
with ops.device(device):
init_value = value_list[0].initial_value
return array_ops.identity(init_value)
kwargs["initial_value"] = initial_value_fn
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
# Don't record operations (e.g. other variable reads) during
# variable creation.
@@ -749,8 +768,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
reduce_op, value, destinations=destinations)
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
return self._get_cross_device_ops().batch_reduce(
reduce_op, value_destination_pairs)
return self._get_cross_device_ops().batch_reduce(reduce_op,
value_destination_pairs)
def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.
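MirroredExtended now centralizes the per-replica initial value in _get_variable_creator_initial_value: replica 0 keeps the user-provided initial value, and later replicas copy the primary variable, which is what MultiWorkerMirroredStrategy overrides above for the cross-worker case. A hedged usage sketch against the public TF 2.x API; with only one local device there is a single copy, so the mirroring is visible only on a multi-device machine:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Replica 0 runs the initializer; other replicas copy its value through the
  # initial-value logic above, so every copy starts identical.
  v = tf.Variable(tf.random.normal([2]))

# One element per local replica, all holding the same initial value.
print(strategy.experimental_local_results(v))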


@@ -48,12 +48,8 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import gradient_descent
@@ -370,516 +366,6 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
distribution.extended.call_for_each_replica(model_fn)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredStrategyVariableCreationTest(test.TestCase):
# TODO(priyag): Modify more tests to use this helper and check more
# properties.
def _test_mv_properties(self, var, name, strategy):
self.assertIsInstance(var, values.MirroredVariable)
self.assertEqual(name, var.name)
self.assertIs(strategy, var.distribute_strategy)
for d in var.devices:
self.assertEqual(d, var.get(d).device)
self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access
def testVariableInFuncGraph(self, distribution):
def model_fn():
v = variable_scope.variable(2.0, name="bar")
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
v1 = variable_scope.variable(1.0, name="foo")
v2 = distribution.extended.call_for_each_replica(model_fn)
self._test_mv_properties(v1, "foo:0", distribution)
self._test_mv_properties(v2, "bar:0", distribution)
def testVariableWithTensorInitialValueInFunction(self, distribution):
if not context.executing_eagerly():
self.skipTest("`tf.function` is an eager-only feature")
v = [None]
def model_fn():
if v[0] is None:
init_val = array_ops.zeros([])
v[0] = variables.Variable(init_val)
ds_context.get_replica_context().merge_call(lambda _: _)
return v[0]
@def_function.function(autograph=False)
def make_v1():
return distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn))
self.assertAllEqual([0, 0], make_v1())
def testSingleVariable(self, distribution):
def model_fn():
# This variable should be created only once across the threads because of
# special variable_creator functions used by
# `distribution.extended.call_for_each_replica`.
v = variable_scope.variable(1.0, name="foo")
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self._test_mv_properties(result, "foo:0", distribution)
def testUnnamedVariable(self, distribution):
def model_fn():
v = variable_scope.variable(1.0)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self._test_mv_properties(result, "Variable:0", distribution)
def testMultipleVariables(self, distribution):
def model_fn():
vs = []
for i in range(5):
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
ds_context.get_replica_context().merge_call(lambda _: _)
return vs
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
for i, v in enumerate(result):
self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)
def testMultipleVariablesWithSameCanonicalName(self, distribution):
def model_fn():
vs = []
vs.append(variable_scope.variable(1.0, name="foo/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
ds_context.get_replica_context().merge_call(lambda _: _)
return vs
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
for v in result:
self.assertIsInstance(v, values.MirroredVariable)
self.assertEqual(4, len(result))
self.assertEqual("foo/bar:0", result[0].name)
self.assertEqual("foo_1/bar:0", result[1].name)
self.assertEqual("foo_1/bar_1:0", result[2].name)
self.assertEqual("foo/bar_1:0", result[3].name)
def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
def model_fn():
replica_id = self.evaluate(_replica_id())
v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(result, values.MirroredVariable)
# The resulting mirrored variable will use the name from the first device.
self.assertEqual("foo_0:0", result.name)
def testWithLayers(self, distribution):
def model_fn(features):
with variable_scope.variable_scope("common"):
layer1 = core.Dense(1)
layer1(features)
layer2 = core.Dense(1)
layer2(features)
# This will pause the current thread, and execute the other thread.
ds_context.get_replica_context().merge_call(lambda _: _)
layer3 = core.Dense(1)
layer3(features)
return [(layer1.kernel, layer1.bias),
(layer2.kernel, layer2.bias),
(layer3.kernel, layer3.bias)]
iterator = distribution.make_input_fn_iterator(
lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
self.evaluate(iterator.initialize())
features = iterator.get_next()
with distribution.scope():
result = distribution.extended.call_for_each_replica(
model_fn, args=(features,))
suffixes = ["", "_1", "_2"]
for (kernel, bias), suffix in zip(result, suffixes):
self.assertIsInstance(kernel, values.MirroredVariable)
self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name)
self.assertIsInstance(bias, values.MirroredVariable)
self.assertEqual("common/dense" + suffix + "/bias:0", bias.name)
def testWithVariableAndVariableScope(self, distribution):
def model_fn():
v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
with variable_scope.variable_scope("common"):
v1 = variable_scope.variable(1.0, name="var1")
# This will pause the current thread, and execute the other thread.
ds_context.get_replica_context().merge_call(lambda _: _)
v2 = variable_scope.variable(
1.0,
name="var2",
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
v3 = variable_scope.variable(
1.0,
name="var3",
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation=variable_scope.VariableAggregation.MEAN)
return v0, v1, v2, v3
with distribution.scope():
v = variable_scope.variable(1.0, name="var-main0")
self.assertEqual("var-main0:0", v.name)
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(4, len(result))
v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
self.assertEqual("var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
self.assertEqual("common/var1:0", v1.name)
self.assertIsInstance(v2, values.SyncOnReadVariable)
self.assertEqual("common/var2:0", v2.name)
self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
self.assertIsInstance(v3, values.MirroredVariable)
self.assertEqual("common/var3:0", v3.name)
self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)
def testWithGetVariableAndVariableScope(self, distribution):
def model_fn():
v0 = variable_scope.get_variable("var0", [1])
with variable_scope.variable_scope("common"):
v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
ds_context.get_replica_context().merge_call(lambda _: _)
v2 = variable_scope.get_variable(
"var2", [1],
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
v3 = variable_scope.get_variable(
"var3", [1],
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation=variable_scope.VariableAggregation.MEAN)
return v0, v1, v2, v3
with distribution.scope():
with variable_scope.variable_scope("main"):
v = variable_scope.get_variable("var-main0", [1])
self.assertEqual("main/var-main0:0", v.name)
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(4, len(result))
v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
self.assertEqual("main/var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
self.assertEqual("main/common/var1:0", v1.name)
self.assertIsInstance(v2, values.SyncOnReadVariable)
self.assertEqual("main/common/var2:0", v2.name)
self.assertEqual(variable_scope.VariableAggregation.SUM,
v2.aggregation)
self.assertIsInstance(v3, values.MirroredVariable)
self.assertEqual("main/common/var3:0", v3.name)
self.assertEqual(variable_scope.VariableAggregation.MEAN,
v3.aggregation)
def testOnlyFirstReplicaUpdatesVariables(self, distribution):
def create_fn():
aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
v0 = variable_scope.variable(
2.0,
name="on_read",
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=aggregation)
v1 = variable_scope.variable(
3.0,
name="on_write",
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation=aggregation)
return v0, v1
devices = ["/device:GPU:0", "/device:CPU:0"]
with distribution.scope():
v0, v1 = distribution.extended.call_for_each_replica(create_fn)
self.evaluate(v0.initializer)
self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
self.evaluate(v1.initializer)
self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))
def replica_id_plus_one():
return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)
# Update using the assign_add member function.
def update_member_fn():
update0 = v0.assign_add(5.0 * replica_id_plus_one())
update1 = v1.assign_add(7.0 * replica_id_plus_one())
return update0, update1
update0a, update1a = distribution.extended.call_for_each_replica(
update_member_fn)
# Update "sync on read" variable.
self.evaluate(distribution.group(update0a))
self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
# Writes are not synchronized for "sync on read" variables,
# so device[1] can end up with a different value.
self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
# Always reads from device 0.
self.assertEqual(2.0 + 5.0, self.evaluate(
distribution.extended.read_var(v0)))
# Update "sync on write" variable.
self.evaluate(distribution.group(update1a))
self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
# Writes are synchronized for v1, only the argument to assign_add on
# device[0] is used.
self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
self.assertEqual(3.0 + 7.0, self.evaluate(
distribution.extended.read_var(v1)))
# Update using state_ops.assign_add global function.
def update_state_ops_fn():
update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
return update0, update1
update0b, update1b = distribution.extended.call_for_each_replica(
update_state_ops_fn)
self.evaluate(distribution.group(update0b))
# Update "sync on read" variable.
self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(
distribution.extended.read_var(v0)))
# Update "sync on write" variable.
self.evaluate(distribution.group(update1b))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(
distribution.extended.read_var(v1)))
def testNoneSynchronizationWithGetVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "`NONE` variable synchronization mode is not "
"supported with `Mirrored` distribution strategy. Please change "
"the `synchronization` for variable: v"):
variable_scope.get_variable(
"v", [1],
synchronization=variable_scope.VariableSynchronization.NONE)
def testNoneSynchronizationWithVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "`NONE` variable synchronization mode is not "
"supported with `Mirrored` distribution strategy. Please change "
"the `synchronization` for variable: v"):
variable_scope.variable(
1.0,
name="v",
synchronization=variable_scope.VariableSynchronization.NONE)
def testInvalidSynchronizationWithVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "Invalid variable synchronization mode: Invalid for "
"variable: v"):
variable_scope.variable(1.0, name="v", synchronization="Invalid")
def testInvalidAggregationWithGetVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "Invalid variable aggregation mode: invalid for "
"variable: v"):
variable_scope.get_variable(
"v", [1],
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation="invalid")
def testInvalidAggregationWithVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "Invalid variable aggregation mode: invalid for "
"variable: v"):
variable_scope.variable(
1.0,
name="v",
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation="invalid")
def testNonMatchingVariableCreation(self, distribution):
self.skipTest("b/123075960")
def model_fn(name):
v = variable_scope.variable(1.0, name=name)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
names = values.DistributedValues(device_map, ("foo", "bar"))
with self.assertRaises(RuntimeError):
_ = distribution.extended.call_for_each_replica(model_fn, args=(names,))
def testSyncOnReadVariable(self, distribution):
all_v_sum = {}
all_v_mean = {}
components_sum = {}
components_mean = {}
def model_fn():
replica_id = self.evaluate(_replica_id())
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
v_mean = variable_scope.variable(
4.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.MEAN)
self.assertIsInstance(v_sum, values.SyncOnReadVariable)
self.assertIsInstance(v_mean, values.SyncOnReadVariable)
updates = [v_sum.assign_add(2.0 + replica_id),
v_mean.assign(6.0 * replica_id)]
all_v_sum[replica_id] = v_sum
all_v_mean[replica_id] = v_mean
c_sum = v_sum.get()
c_mean = v_mean.get()
components_sum[replica_id] = c_sum
components_mean[replica_id] = c_mean
self.assertIsNot(v_sum, c_sum)
self.assertIsNot(v_mean, c_mean)
return updates, v_sum, v_mean, c_sum, c_mean
with distribution.scope():
# Create "sum" and "mean" versions of SyncOnReadVariables.
ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
distribution.extended.call_for_each_replica(model_fn))
# Should see the same wrapping instance in all replicas.
self.assertIs(all_v_sum[0], ret_v_sum)
self.assertIs(all_v_mean[0], ret_v_mean)
self.assertIs(all_v_sum[0], all_v_sum[1])
self.assertIs(all_v_mean[0], all_v_mean[1])
# Regroup should recover the same wrapper.
self.assertIs(ret_v_sum, regrouped_sum)
self.assertIs(ret_v_mean, regrouped_mean)
self.assertIsNot(components_sum[0], components_sum[1])
self.assertIsNot(components_mean[0], components_mean[1])
# Apply updates
self.evaluate(variables.global_variables_initializer())
self.evaluate([y for x in ret_ops # pylint: disable=g-complex-comprehension
for y in distribution.experimental_local_results(x)])
expected_sum = 0.0
expected_mean = 0.0
for i, d in enumerate(distribution.extended.worker_devices):
# Should see different values on different devices.
v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
expected = i + 3.0
self.assertEqual(expected, v_sum_value)
expected_sum += expected
expected = i * 6.0
self.assertEqual(expected, v_mean_value)
expected_mean += expected
expected_mean /= len(distribution.extended.worker_devices)
# Without get(device), should return the value you get by
# applying the reduction across all replicas (whether you use
# read_var(), get(), or nothing).
self.assertEqual(expected_sum, self.evaluate(
distribution.extended.read_var(ret_v_sum)))
self.assertEqual(expected_mean, self.evaluate(
distribution.extended.read_var(ret_v_mean)))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
# TODO(priyag): Update this test to work in eager mode as well.
def testDynamicRnnVariables(self, distribution):
def model_fn():
inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
cell_fw = rnn_cell_impl.LSTMCell(300)
cell_bw = rnn_cell_impl.LSTMCell(300)
(outputs, _) = rnn.bidirectional_dynamic_rnn(
cell_fw,
cell_bw,
inputs,
dtype=dtypes.float32)
return outputs
with context.graph_mode(), distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
# Two variables are created by the RNN layer.
self.assertEqual(2, len(result))
for v in result:
self.assertIsInstance(v, values.DistributedValues)
_, v1 = distribution.experimental_local_results(v)
self.assertStartsWith(v1._op.name, "replica_1/")
def testSyncOnReadVariableUpdate(self, distribution):
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
self.assertIsInstance(v_sum, values.SyncOnReadVariable)
return v_sum
def update(var, value):
return var.assign(value)
with distribution.scope():
ret_v_sum = distribution.extended.call_for_each_replica(model_fn)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
# Assert that the aggregated value of the sync on read var is the sum
# of the individual values before running the update ops.
self.assertEqual(1.0, self.evaluate(ret_v_sum.get(
distribution.extended.worker_devices[0]).read_value()))
self.assertEqual(2.0, self.evaluate(ret_v_sum))
# Apply updates.
update_ops = distribution.extended.update(
ret_v_sum, update, args=(5.0,), group=False)
self.evaluate(update_ops)
# Assert that the aggregated value of the sync on read vars is the sum
# of the individual values after running the update ops.
self.assertEqual(5.0, self.evaluate(ret_v_sum.get(
distribution.extended.worker_devices[0]).read_value()))
self.assertEqual(10.0, self.evaluate(ret_v_sum))
def testVarDistributeStrategy(self, distribution):
with distribution.scope():
mirrored = variable_scope.variable(1.0)
sync_on_read = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ)
self.assertIs(distribution, mirrored.distribute_strategy)
self.assertIs(distribution, sync_on_read.distribute_strategy)
@combinations.generate(
combinations.combine(
distribution=[


@@ -0,0 +1,606 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test MirroredVariable in MirroredStrategy and MultiWorkerMirroredStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
def _replica_id():
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if not isinstance(replica_id, ops.Tensor):
replica_id = constant_op.constant(replica_id)
return replica_id
def _mimic_two_cpus():
cpus = config.list_physical_devices("CPU")
config.set_virtual_device_configuration(cpus[0], [
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration(),
])
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
combinations.NamedDistribution(
"Collective2CPUs",
# pylint: disable=g-long-lambda
lambda: collective_all_reduce_strategy.
CollectiveAllReduceStrategy._from_local_devices((
"/device:CPU:0", "/device:CPU:1")),
required_gpus=0)
],
mode=["graph", "eager"]))
class MirroredVariableCreationTest(test.TestCase):
"""Base class that tests mirrored variable creator.
Currently it assumes all strategy objects have two replicas.
"""
@classmethod
def setUpClass(cls):
_mimic_two_cpus()
# TODO(priyag): Modify more tests to use this helper and check more
# properties.
def _test_mv_properties(self, var, name, strategy):
self.assertIsInstance(var, values.MirroredVariable)
self.assertEqual(name, var.name)
self.assertIs(strategy, var.distribute_strategy)
for d in var.devices:
self.assertEqual(d, var.get(d).device)
self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access
def testVariableInFuncGraph(self, distribution):
def model_fn():
v = variable_scope.variable(2.0, name="bar")
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
v1 = variable_scope.variable(1.0, name="foo")
v2 = distribution.extended.call_for_each_replica(model_fn)
self._test_mv_properties(v1, "foo:0", distribution)
self._test_mv_properties(v2, "bar:0", distribution)
def testVariableWithTensorInitialValueInFunction(self, distribution):
if not context.executing_eagerly():
self.skipTest("`tf.function` is an eager-only feature")
v = [None]
def model_fn():
if v[0] is None:
init_val = array_ops.zeros([])
v[0] = variables.Variable(init_val)
ds_context.get_replica_context().merge_call(lambda _: _)
return v[0]
@def_function.function(autograph=False)
def make_v1():
return distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn))
self.assertAllEqual([0, 0], make_v1())
def testSingleVariable(self, distribution):
def model_fn():
# This variable should be created only once across the threads because of
# special variable_creator functions used by
# `distribution.extended.call_for_each_replica`.
v = variable_scope.variable(1.0, name="foo")
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self._test_mv_properties(result, "foo:0", distribution)
def testUnnamedVariable(self, distribution):
def model_fn():
v = variable_scope.variable(1.0)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self._test_mv_properties(result, "Variable:0", distribution)
def testMultipleVariables(self, distribution):
def model_fn():
vs = []
for i in range(5):
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
ds_context.get_replica_context().merge_call(lambda _: _)
return vs
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
for i, v in enumerate(result):
self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)
def testMultipleVariablesWithSameCanonicalName(self, distribution):
def model_fn():
vs = []
vs.append(variable_scope.variable(1.0, name="foo/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
ds_context.get_replica_context().merge_call(lambda _: _)
return vs
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
for v in result:
self.assertIsInstance(v, values.MirroredVariable)
self.assertEqual(4, len(result))
self.assertEqual("foo/bar:0", result[0].name)
self.assertEqual("foo_1/bar:0", result[1].name)
self.assertEqual("foo_1/bar_1:0", result[2].name)
self.assertEqual("foo/bar_1:0", result[3].name)
def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
def model_fn():
replica_id = self.evaluate(_replica_id())
v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(result, values.MirroredVariable)
# The resulting mirrored variable will use the name from the first device.
self.assertEqual("foo_0:0", result.name)
def testWithLayers(self, distribution):
def model_fn(features):
with variable_scope.variable_scope("common"):
layer1 = core.Dense(1)
layer1(features)
layer2 = core.Dense(1)
layer2(features)
# This will pause the current thread, and execute the other thread.
ds_context.get_replica_context().merge_call(lambda _: _)
layer3 = core.Dense(1)
layer3(features)
return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias),
(layer3.kernel, layer3.bias)]
iterator = distribution.make_input_fn_iterator(
lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
self.evaluate(iterator.initialize())
features = iterator.get_next()
with distribution.scope():
result = distribution.extended.call_for_each_replica(
model_fn, args=(features,))
suffixes = ["", "_1", "_2"]
for (kernel, bias), suffix in zip(result, suffixes):
self.assertIsInstance(kernel, values.MirroredVariable)
self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name)
self.assertIsInstance(bias, values.MirroredVariable)
self.assertEqual("common/dense" + suffix + "/bias:0", bias.name)
def testWithVariableAndVariableScope(self, distribution):
def model_fn():
v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
with variable_scope.variable_scope("common"):
v1 = variable_scope.variable(1.0, name="var1")
# This will pause the current thread, and execute the other thread.
ds_context.get_replica_context().merge_call(lambda _: _)
v2 = variable_scope.variable(
1.0,
name="var2",
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
v3 = variable_scope.variable(
1.0,
name="var3",
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation=variable_scope.VariableAggregation.MEAN)
return v0, v1, v2, v3
with distribution.scope():
v = variable_scope.variable(1.0, name="var-main0")
self.assertEqual("var-main0:0", v.name)
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(4, len(result))
v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
self.assertEqual("var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
self.assertEqual("common/var1:0", v1.name)
self.assertIsInstance(v2, values.SyncOnReadVariable)
self.assertEqual("common/var2:0", v2.name)
self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
self.assertIsInstance(v3, values.MirroredVariable)
self.assertEqual("common/var3:0", v3.name)
self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)
def testWithGetVariableAndVariableScope(self, distribution):
def model_fn():
v0 = variable_scope.get_variable("var0", [1])
with variable_scope.variable_scope("common"):
v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
ds_context.get_replica_context().merge_call(lambda _: _)
v2 = variable_scope.get_variable(
"var2", [1],
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
v3 = variable_scope.get_variable(
"var3", [1],
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation=variable_scope.VariableAggregation.MEAN)
return v0, v1, v2, v3
with distribution.scope():
with variable_scope.variable_scope("main"):
v = variable_scope.get_variable("var-main0", [1])
self.assertEqual("main/var-main0:0", v.name)
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(4, len(result))
v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
self.assertEqual("main/var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
self.assertEqual("main/common/var1:0", v1.name)
self.assertIsInstance(v2, values.SyncOnReadVariable)
self.assertEqual("main/common/var2:0", v2.name)
self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
self.assertIsInstance(v3, values.MirroredVariable)
self.assertEqual("main/common/var3:0", v3.name)
self.assertEqual(variable_scope.VariableAggregation.MEAN,
v3.aggregation)
def testOnlyFirstReplicaUpdatesVariables(self, distribution):
def create_fn():
aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
v0 = variable_scope.variable(
2.0,
name="on_read",
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=aggregation)
v1 = variable_scope.variable(
3.0,
name="on_write",
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation=aggregation)
return v0, v1
devices = distribution.extended.worker_devices
with distribution.scope():
v0, v1 = distribution.extended.call_for_each_replica(create_fn)
self.evaluate(v0.initializer)
self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
self.evaluate(v1.initializer)
self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))
def replica_id_plus_one():
return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)
# Update using the assign_add member function.
def update_member_fn():
update0 = v0.assign_add(5.0 * replica_id_plus_one())
update1 = v1.assign_add(7.0 * replica_id_plus_one())
return update0, update1
update0a, update1a = distribution.extended.call_for_each_replica(
update_member_fn)
# Update "sync on read" variable.
self.evaluate(distribution.group(update0a))
self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
# Writes are not synchronized for "sync on read" variables,
# so device[1] can end up with a different value.
self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.get(devices[1])))
# Always reads from device 0.
self.assertEqual(2.0 + 5.0,
self.evaluate(distribution.extended.read_var(v0)))
# Update "sync on write" variable.
self.evaluate(distribution.group(update1a))
self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
# Writes are synchronized for v1, only the argument to assign_add on
# device[0] is used.
self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
self.assertEqual(3.0 + 7.0,
self.evaluate(distribution.extended.read_var(v1)))
# Update using state_ops.assign_add global function.
def update_state_ops_fn():
update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
return update0, update1
update0b, update1b = distribution.extended.call_for_each_replica(
update_state_ops_fn)
self.evaluate(distribution.group(update0b))
# Update "sync on read" variable.
self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0,
self.evaluate(v0.get(devices[1])))
self.assertEqual(2.0 + 5.0 + 11.0,
self.evaluate(distribution.extended.read_var(v0)))
# Update "sync on write" variable.
self.evaluate(distribution.group(update1b))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
self.assertEqual(3.0 + 7.0 + 13.0,
self.evaluate(distribution.extended.read_var(v1)))
def testNoneSynchronizationWithGetVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "`NONE` variable synchronization mode is not "
"supported with `Mirrored` distribution strategy. Please change "
"the `synchronization` for variable: v"):
variable_scope.get_variable(
"v", [1],
synchronization=variable_scope.VariableSynchronization.NONE)
def testNoneSynchronizationWithVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "`NONE` variable synchronization mode is not "
"supported with `Mirrored` distribution strategy. Please change "
"the `synchronization` for variable: v"):
variable_scope.variable(
1.0,
name="v",
synchronization=variable_scope.VariableSynchronization.NONE)
def testInvalidSynchronizationWithVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "Invalid variable synchronization mode: Invalid for "
"variable: v"):
variable_scope.variable(1.0, name="v", synchronization="Invalid")
def testInvalidAggregationWithGetVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "Invalid variable aggregation mode: invalid for "
"variable: v"):
variable_scope.get_variable(
"v", [1],
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation="invalid")
def testInvalidAggregationWithVariable(self, distribution):
with distribution.scope():
with self.assertRaisesRegexp(
ValueError, "Invalid variable aggregation mode: invalid for "
"variable: v"):
variable_scope.variable(
1.0,
name="v",
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
aggregation="invalid")
def testNonMatchingVariableCreation(self, distribution):
self.skipTest("b/123075960")
def model_fn(name):
v = variable_scope.variable(1.0, name=name)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
device_map = values.ReplicaDeviceMap(distribution.extended.worker_devices)
names = values.DistributedValues(device_map, ("foo", "bar"))
with self.assertRaises(RuntimeError):
_ = distribution.extended.call_for_each_replica(model_fn, args=(names,))
def testSyncOnReadVariable(self, distribution):
all_v_sum = {}
all_v_mean = {}
components_sum = {}
components_mean = {}
def model_fn():
replica_id = self.evaluate(_replica_id())
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
v_mean = variable_scope.variable(
4.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.MEAN)
self.assertIsInstance(v_sum, values.SyncOnReadVariable)
self.assertIsInstance(v_mean, values.SyncOnReadVariable)
updates = [
v_sum.assign_add(2.0 + replica_id),
v_mean.assign(6.0 * replica_id)
]
all_v_sum[replica_id] = v_sum
all_v_mean[replica_id] = v_mean
c_sum = v_sum.get()
c_mean = v_mean.get()
components_sum[replica_id] = c_sum
components_mean[replica_id] = c_mean
self.assertIsNot(v_sum, c_sum)
self.assertIsNot(v_mean, c_mean)
return updates, v_sum, v_mean, c_sum, c_mean
with distribution.scope():
# Create "sum" and "mean" versions of SyncOnReadVariables.
ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
distribution.extended.call_for_each_replica(model_fn))
# Should see the same wrapping instance in all replicas.
self.assertIs(all_v_sum[0], ret_v_sum)
self.assertIs(all_v_mean[0], ret_v_mean)
self.assertIs(all_v_sum[0], all_v_sum[1])
self.assertIs(all_v_mean[0], all_v_mean[1])
# Regroup should recover the same wrapper.
self.assertIs(ret_v_sum, regrouped_sum)
self.assertIs(ret_v_mean, regrouped_mean)
self.assertIsNot(components_sum[0], components_sum[1])
self.assertIsNot(components_mean[0], components_mean[1])
# Apply updates
self.evaluate(variables.global_variables_initializer())
self.evaluate([
y for x in ret_ops # pylint: disable=g-complex-comprehension
for y in distribution.experimental_local_results(x)
])
expected_sum = 0.0
expected_mean = 0.0
for i, d in enumerate(distribution.extended.worker_devices):
# Should see different values on different devices.
v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
expected = i + 3.0
self.assertEqual(expected, v_sum_value)
expected_sum += expected
expected = i * 6.0
self.assertEqual(expected, v_mean_value)
expected_mean += expected
expected_mean /= len(distribution.extended.worker_devices)
# Without get(device), should return the value you get by
# applying the reduction across all replicas (whether you use
# read_var(), get(), or nothing).
self.assertEqual(expected_sum, self.evaluate(
distribution.extended.read_var(ret_v_sum)))
self.assertEqual(expected_mean, self.evaluate(
distribution.extended.read_var(ret_v_mean)))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
# TODO(priyag): Update this test to work in eager mode as well.
def testDynamicRnnVariables(self, distribution):
def model_fn():
inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
cell_fw = rnn_cell_impl.LSTMCell(300)
cell_bw = rnn_cell_impl.LSTMCell(300)
(outputs, _) = rnn.bidirectional_dynamic_rnn(
cell_fw, cell_bw, inputs, dtype=dtypes.float32)
return outputs
with context.graph_mode(), distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
# Two variables are created by the RNN layer.
self.assertEqual(2, len(result))
for v in result:
self.assertIsInstance(v, values.DistributedValues)
_, v1 = distribution.experimental_local_results(v)
self.assertStartsWith(v1._op.name, "replica_1/")
def testSyncOnReadVariableUpdate(self, distribution):
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
self.assertIsInstance(v_sum, values.SyncOnReadVariable)
return v_sum
def update(var, value):
return var.assign(value)
with distribution.scope():
ret_v_sum = distribution.extended.call_for_each_replica(model_fn)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
# Assert that the aggregated value of the sync on read var is the sum
# of the individual values before running the update ops.
self.assertEqual(
1.0,
self.evaluate(
ret_v_sum.get(
distribution.extended.worker_devices[0]).read_value()))
self.assertEqual(2.0, self.evaluate(ret_v_sum))
# Apply updates.
update_ops = distribution.extended.update(
ret_v_sum, update, args=(5.0,), group=False)
self.evaluate(update_ops)
# Assert that the aggregated value of the sync on read vars is the sum
# of the individual values after running the update ops.
self.assertEqual(
5.0,
self.evaluate(
ret_v_sum.get(
distribution.extended.worker_devices[0]).read_value()))
self.assertEqual(10.0, self.evaluate(ret_v_sum))
def testVarDistributeStrategy(self, distribution):
with distribution.scope():
mirrored = variable_scope.variable(1.0)
sync_on_read = variable_scope.variable(
1.0, synchronization=variable_scope.VariableSynchronization.ON_READ)
self.assertIs(distribution, mirrored.distribute_strategy)
self.assertIs(distribution, sync_on_read.distribute_strategy)
if __name__ == "__main__":
test.main()