Use an incremental counter for the collective keys of variables;
make MultiWorkerMirroredStrategy reuse more code from MirroredStrategy in the variable creator; move the mirrored variable creator tests to a new file and add coverage for MultiWorkerMirroredStrategy; fix a bug in cross_device_ops where, in graph mode, some outputs of allreduce ops were sometimes not run. PiperOrigin-RevId: 251974203
This commit is contained in:
parent f1ffa0225a
commit f5003f6315
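Before the diff, a minimal sketch of the first change, assuming nothing beyond the commit summary and the CollectiveKeys defaults shown further below: variable instance keys now come from a plain incrementing counter per CollectiveKeys object rather than a table keyed by variable name.

# Hypothetical, simplified illustration of counter-based variable instance keys.
class _VariableKeyCounter(object):

  def __init__(self, variable_instance_key_start=1000000):
    self._next_key = variable_instance_key_start

  def get_variable_instance_key(self):
    # Hands out a fresh key on every call; presumably each worker creates its
    # variables in the same order so the keys line up across workers.
    key = self._next_key
    self._next_key += 1
    return key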
@ -115,7 +115,7 @@ class CollectiveAllReduceStrategyTestBase(
def setUp(self):
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, tucker): enable it to reuse collective keys in different
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different
# tests.
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
super(CollectiveAllReduceStrategyTestBase, self).setUp()
@ -133,11 +133,11 @@ class CollectiveAllReduceStrategyTestBase(
use_core_strategy=use_core_strategy)

collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
group_key_start=10 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_start=num_gpus * 100 +
op_instance_key_start=100 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
variable_instance_key_start=10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = (collective_keys)
@ -966,6 +966,32 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)

cuda_py_test(
name = "mirrored_variable_test",
srcs = ["mirrored_variable_test.py"],
additional_deps = [
":collective_all_reduce_strategy",
":combinations",
":strategy_combinations",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:layers",
"//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
tags = [
"guitar",
"multi_and_single_gpu",
],
xla_enable_strict_auto_jit = True,
)

distribute_py_test(
name = "metrics_v1_test",
srcs = ["metrics_v1_test.py"],
@ -36,7 +36,6 @@ from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
@ -79,6 +78,13 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
self,
communication=communication))

@classmethod
def _from_local_devices(cls, devices):
"""A convenience method to create an object with a list of devices."""
obj = cls()
obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access
return obj


@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
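A quick usage sketch of the new helper, mirroring how the new mirrored_variable_test wires up its two-CPU combination (the device names here are just examples):

# Build a local MultiWorkerMirroredStrategy object directly from device names.
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy._from_local_devices(
    ("/device:CPU:0", "/device:CPU:1"))  # pylint: disable=protected-access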
@ -117,7 +123,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
self._initialize_local(cluster_resolver)

def _initialize_local(self, cluster_resolver):
def _initialize_local(self, cluster_resolver, devices=None):
"""Initializes the object for local training."""
self._is_chief = True
self._num_workers = 1
@ -140,10 +146,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

if num_gpus:
local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
if devices:
local_devices = devices
else:
local_devices = ("/device:CPU:0",)
if num_gpus:
local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
else:
local_devices = ("/device:CPU:0",)
self._worker_device = device_util.canonicalize("/device:CPU:0")
self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

@ -272,100 +281,52 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
task_id, self._num_workers, local_devices,
self._communication)

def _create_variable(self, next_creator, *args, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
device_map = self._device_map
logical_device = 0 # TODO(josh11b): Get logical device from scope here.
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
def _get_variable_creator_initial_value(self,
replica_id=0,
device=None,
primary_var=None,
**kwargs):
if replica_id == 0: # First replica on each worker.
assert device is not None
assert primary_var is None

def initial_value_fn(): # pylint: disable=g-missing-docstring
# Only the first device participates in the broadcast of initial values.
group_key = self._collective_keys.get_group_key([device])
group_size = self._num_workers
collective_instance_key = (
self._collective_keys.get_variable_instance_key())

with ops.device(device):
initial_value = kwargs["initial_value"]
if callable(initial_value):
initial_value = initial_value()
assert not callable(initial_value)
initial_value = ops.convert_to_tensor(
initial_value, dtype=kwargs.get("dtype", None))

if self._num_workers > 1:
if self._is_chief:
bcast_send = collective_ops.broadcast_send(
initial_value, initial_value.shape, initial_value.dtype,
group_size, group_key, collective_instance_key)
with ops.control_dependencies([bcast_send]):
return array_ops.identity(initial_value)
else:
return collective_ops.broadcast_recv(initial_value.shape,
initial_value.dtype,
group_size, group_key,
collective_instance_key)
return initial_value

return initial_value_fn
else:
device_map = colocate_with.device_map
logical_device = colocate_with.logical_device

def _real_mirrored_creator(devices, *args, **kwargs):
"""Creates one MirroredVariable on the current worker."""
unique_var_name = ops.get_default_graph().unique_name(
kwargs["name"], mark_as_used=False).rstrip("/")
# pylint: disable=protected-access
collective_instance_key = self._collective_keys.get_instance_key(
key_id=unique_var_name)
# Only the first device participates in the broadcast of initial values.
group_key = self._collective_keys.get_group_key([devices[0]])
group_size = self._num_workers
if "initial_value" not in kwargs:
raise ValueError("Initial value must be specified.")
initial_value = kwargs["initial_value"]
if callable(initial_value):
initial_value_fn = initial_value
else:
initial_value_fn = lambda: initial_value

value_list = []
for i, d in enumerate(devices):
with ops.init_scope(), ops.device(d):
if i == 0:
# The initial value fn makes sure variables all initialized to
# same values. The first device of the chief worker will send their
# variable values to other workers.
def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring
with ops.device(device):
initial_value = initial_value_fn()
assert not callable(initial_value)
initial_value = ops.convert_to_tensor(
initial_value, dtype=kwargs.get("dtype", None))

assert index == 0, index
if self._num_workers > 1:
if self._is_chief:
bcast_send = collective_ops.broadcast_send(
initial_value, initial_value.shape, initial_value.dtype,
group_size, group_key, collective_instance_key)
with ops.control_dependencies([bcast_send]):
return array_ops.identity(initial_value)
else:
return collective_ops.broadcast_recv(
initial_value.shape, initial_value.dtype, group_size,
group_key, collective_instance_key)
return initial_value
else:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
# We append a / to variable names created on replicas with id > 0 to
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)

# Variables on non-first replica get initial values from the
# variables created on the first device of each worker.
def _overridden_initial_value_fn(device=d, index=i):
assert index > 0
with ops.device(device):
if context.executing_eagerly():
return array_ops.identity(value_list[0].value())
else:
return array_ops.identity(value_list[0].initial_value)

kwargs["initial_value"] = _overridden_initial_value_fn
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
# Don't record operations (e.g. other variable reads) during
# variable creation.
with tape.stop_recording():
v = next_creator(*args, **kwargs)

if i == 0:
actual_var_name = v.name.split(":")[0]
assert unique_var_name == actual_var_name, "%r vs %r" % (
unique_var_name, actual_var_name)
assert not isinstance(v, values.DistributedVariable)
value_list.append(v)
return value_list

# pylint: disable=protected-access
return mirrored_strategy._create_mirrored_variable(
self._container_strategy(), device_map, logical_device,
_real_mirrored_creator, *args, **kwargs)
return super(CollectiveAllReduceExtended,
self)._get_variable_creator_initial_value(
replica_id=replica_id,
device=device,
primary_var=primary_var,
**kwargs)

def _make_input_context(self):
if self._cluster_spec is None:
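To make the commit summary concrete: after this hunk, CollectiveAllReduceExtended no longer re-implements _create_variable; it only overrides the initial-value hook and lets the MirroredExtended creator (shown further below) do the per-device work. A toy, non-TensorFlow sketch of that shape, assuming nothing beyond what the diff shows:

# Toy model of the refactor, not the real classes: the subclass customizes only
# how the first replica's initial value is produced.
class _BaseCreator(object):

  def _initial_value(self, replica_id, primary):
    # Replica 0 keeps the user-supplied value; later replicas copy replica 0.
    return 1.0 if replica_id == 0 else primary

  def create(self, num_replicas):
    values = []
    for i in range(num_replicas):
      values.append(self._initial_value(i, values[0] if values else None))
    return values


class _CollectiveCreator(_BaseCreator):

  def _initial_value(self, replica_id, primary):
    if replica_id == 0:
      # In the real code this value is broadcast from the chief worker via
      # collective_ops.broadcast_send/broadcast_recv.
      return 1.0
    return super(_CollectiveCreator, self)._initial_value(replica_id, primary)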
@ -998,14 +998,15 @@ class CollectiveAllReduce(CrossDeviceOps):
return all_reduced
devices = device_map.logical_to_actual_devices(logical_device)
index = []
for d in devices:
if d in all_reduced.devices:
index.append(all_reduced.get(d))
else:
# TODO(josh11b): Once we add support for model parallelism, get the
# copy from the corresponding replica instead of the primary.
with ops.control_dependencies(all_reduced.values), ops.device(d):
index.append(array_ops.identity(all_reduced.primary))
with ops.control_dependencies(all_reduced.values):
for d in devices:
with ops.device(d):
if d in all_reduced.devices:
index.append(array_ops.identity(all_reduced.get(d)))
else:
# TODO(josh11b): Once we add support for model parallelism, get the
# copy from the corresponding replica instead of the primary.
index.append(array_ops.identity(all_reduced.primary))

return value_lib.Mirrored(device_map, index, logical_device)

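This hunk is the graph-mode fix from the commit message: previously, devices already present in all_reduced were returned as the raw allreduce outputs, so fetching only some per-device results could leave other allreduce ops unexecuted; now every per-device result is an identity op created under a control dependency on all of the reduced values. A self-contained sketch of the same pattern (toy tensors, not the real strategy code):

import tensorflow as tf

# Pattern illustration only: out_a and out_b are identity ops created under a
# control dependency on both source tensors, so running either output in graph
# mode also forces the ops that produced both sources to run.
g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="reduced_a")
  b = tf.constant(2.0, name="reduced_b")
  with tf.control_dependencies([a, b]):
    out_a = tf.identity(a)
    out_b = tf.identity(b)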
@ -446,11 +446,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
use_strategy_object=False,
local_mode=False):
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
CollectiveAllReduceTest.collective_key_base,
instance_key_start=num_gpus * 100 +
CollectiveAllReduceTest.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceTest.collective_key_base)
if local_mode:
if num_gpus:
@ -253,31 +253,28 @@ class CollectiveKeys(object):

def __init__(self,
group_key_start=1,
instance_key_start=100,
instance_key_with_id_start=10000):
op_instance_key_start=100,
variable_instance_key_start=1000000):
"""Initializes the object.

Args:
group_key_start: the starting integer of group key.
instance_key_start: the starting integer of instance key.
instance_key_with_id_start: the starting integer of instance key that is
recorded with an id.
op_instance_key_start: the starting integer of instance key for ops.
variable_instance_key_start: the starting integer of instance key for
variables.
"""
self._group_key = group_key_start
self._group_key_table = {}

# For instance keys with ids
self._instance_key_id_to_key_table = {}
self._instance_key_with_id_counter = instance_key_with_id_start

# For instance keys without ids
self._instance_key_start = instance_key_start
assert op_instance_key_start != variable_instance_key_start
self._op_instance_key_start = op_instance_key_start
self._variable_instance_key = variable_instance_key_start

def _get_thread_local_object(self):
# We make instance key without key ids thread local so that it will work
# with MirroredStrategy and distribute coordinator.
if not hasattr(_thread_local, 'instance_key'):
_thread_local.instance_key = self._instance_key_start
if not hasattr(_thread_local, 'op_instance_key'):
_thread_local.op_instance_key = self._op_instance_key_start
return _thread_local

def get_group_key(self, devices):
@ -304,25 +301,17 @@ class CollectiveKeys(object):
self._group_key_table[key_id] = new_key
return self._group_key_table[key_id]

def get_instance_key(self, key_id=None):
"""Returns a new instance key for use in defining a collective op.
def get_op_instance_key(self):
"""Returns a new instance key for use in defining a collective op."""
v = self._get_thread_local_object().op_instance_key
self._get_thread_local_object().op_instance_key += 1
return v

Args:
key_id: optional string. If set, key will be recorded and the same key
will be returned when the same key_id is provided. If not, an increasing
instance key will be returned.
"""
if key_id:
with _lock:
if key_id not in self._instance_key_id_to_key_table:
self._instance_key_with_id_counter += 1
self._instance_key_id_to_key_table[key_id] = (
self._instance_key_with_id_counter)
return self._instance_key_id_to_key_table[key_id]
else:
v = self._get_thread_local_object().instance_key
self._get_thread_local_object().instance_key += 1
return v
def get_variable_instance_key(self):
"""Returns a new instance key for use in creating a Variable."""
v = self._variable_instance_key
self._variable_instance_key += 1
return v


def build_collective_reduce(input_tensors,
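A small usage sketch of the reworked key bookkeeping, assuming only the defaults shown above (op keys start at 100 and are thread-local; variable keys start at 1000000 and come from a plain counter):

keys = cross_device_utils.CollectiveKeys()
keys.get_op_instance_key()        # 100, used by build_collective_reduce/gather
keys.get_op_instance_key()        # 101
keys.get_variable_instance_key()  # 1000000, used when creating variables
keys.get_variable_instance_key()  # 1000001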
@ -354,7 +343,7 @@ def build_collective_reduce(input_tensors,
devices = [t.device for t in input_tensors]
num_devices = len(devices)
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_instance_key()
instance_key = collective_keys.get_op_instance_key()
subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec

def collective_all_reduce():
@ -399,7 +388,7 @@ def build_collective_gather(input_tensors, num_workers, collective_keys):
devices = [t.device for t in input_tensors]
num_devices = len(devices)
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_instance_key()
instance_key = collective_keys.get_op_instance_key()

def collective_all_gather():
"""Call collective allgather."""
@ -532,6 +532,30 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# containing job names.
self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

def _get_variable_creator_initial_value(self,
replica_id=0,
device=None,
primary_var=None,
**kwargs):
"""Return the initial value for variables on a replica."""
if replica_id == 0:
return kwargs["initial_value"]
else:
assert primary_var is not None
assert device is not None
assert kwargs is not None

def initial_value_fn():
if context.executing_eagerly() or ops.inside_function():
init_value = primary_var.value()
return array_ops.identity(init_value)
else:
with ops.device(device):
init_value = primary_var.initial_value
return array_ops.identity(init_value)

return initial_value_fn

def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
colocate_with = kwargs.pop("colocate_with", None)
@ -549,6 +573,11 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
value_list = []
for i, d in enumerate(devices):
with ops.device(d):
kwargs["initial_value"] = self._get_variable_creator_initial_value(
replica_id=i,
device=d,
primary_var=value_list[0] if value_list else None,
**kwargs)
if i > 0:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
@ -556,16 +585,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
def initial_value_fn(device=d):
if context.executing_eagerly() or ops.inside_function():
init_value = value_list[0].value()
return array_ops.identity(init_value)
else:
with ops.device(device):
init_value = value_list[0].initial_value
return array_ops.identity(init_value)
kwargs["initial_value"] = initial_value_fn
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
# Don't record operations (e.g. other variable reads) during
# variable creation.
@ -749,8 +768,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
reduce_op, value, destinations=destinations)

def _batch_reduce_to(self, reduce_op, value_destination_pairs):
return self._get_cross_device_ops().batch_reduce(
reduce_op, value_destination_pairs)
return self._get_cross_device_ops().batch_reduce(reduce_op,
value_destination_pairs)

def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.
@ -48,12 +48,8 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras.engine import training as keras_training
|
||||
from tensorflow.python.keras.layers import core as keras_core
|
||||
from tensorflow.python.layers import core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import gradient_descent
|
||||
@ -370,516 +366,6 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
|
||||
distribution.extended.call_for_each_replica(model_fn)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
],
|
||||
mode=["graph", "eager"]))
|
||||
class MirroredStrategyVariableCreationTest(test.TestCase):
|
||||
|
||||
# TODO(priyag): Modify more tests to use this helper and check more
|
||||
# properties.
|
||||
def _test_mv_properties(self, var, name, strategy):
|
||||
self.assertIsInstance(var, values.MirroredVariable)
|
||||
self.assertEqual(name, var.name)
|
||||
self.assertIs(strategy, var.distribute_strategy)
|
||||
for d in var.devices:
|
||||
self.assertEqual(d, var.get(d).device)
|
||||
self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access
|
||||
|
||||
def testVariableInFuncGraph(self, distribution):
|
||||
def model_fn():
|
||||
v = variable_scope.variable(2.0, name="bar")
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
|
||||
v1 = variable_scope.variable(1.0, name="foo")
|
||||
v2 = distribution.extended.call_for_each_replica(model_fn)
|
||||
|
||||
self._test_mv_properties(v1, "foo:0", distribution)
|
||||
self._test_mv_properties(v2, "bar:0", distribution)
|
||||
|
||||
def testVariableWithTensorInitialValueInFunction(self, distribution):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest("`tf.function` is an eager-only feature")
|
||||
|
||||
v = [None]
|
||||
def model_fn():
|
||||
if v[0] is None:
|
||||
init_val = array_ops.zeros([])
|
||||
v[0] = variables.Variable(init_val)
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v[0]
|
||||
|
||||
@def_function.function(autograph=False)
|
||||
def make_v1():
|
||||
return distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn))
|
||||
|
||||
self.assertAllEqual([0, 0], make_v1())
|
||||
|
||||
def testSingleVariable(self, distribution):
|
||||
def model_fn():
|
||||
# This variable should be created only once across the threads because of
|
||||
# special variable_creator functions used by
|
||||
# `distribution.extended.call_for_each_replica`.
|
||||
v = variable_scope.variable(1.0, name="foo")
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self._test_mv_properties(result, "foo:0", distribution)
|
||||
|
||||
def testUnnamedVariable(self, distribution):
|
||||
def model_fn():
|
||||
v = variable_scope.variable(1.0)
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self._test_mv_properties(result, "Variable:0", distribution)
|
||||
|
||||
def testMultipleVariables(self, distribution):
|
||||
def model_fn():
|
||||
vs = []
|
||||
for i in range(5):
|
||||
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return vs
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
for i, v in enumerate(result):
|
||||
self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)
|
||||
|
||||
def testMultipleVariablesWithSameCanonicalName(self, distribution):
|
||||
def model_fn():
|
||||
vs = []
|
||||
vs.append(variable_scope.variable(1.0, name="foo/bar"))
|
||||
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
|
||||
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
|
||||
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return vs
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
for v in result:
|
||||
self.assertIsInstance(v, values.MirroredVariable)
|
||||
self.assertEqual(4, len(result))
|
||||
self.assertEqual("foo/bar:0", result[0].name)
|
||||
self.assertEqual("foo_1/bar:0", result[1].name)
|
||||
self.assertEqual("foo_1/bar_1:0", result[2].name)
|
||||
self.assertEqual("foo/bar_1:0", result[3].name)
|
||||
|
||||
def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
|
||||
def model_fn():
|
||||
replica_id = self.evaluate(_replica_id())
|
||||
v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self.assertIsInstance(result, values.MirroredVariable)
|
||||
# The resulting mirrored variable will use the name from the first device.
|
||||
self.assertEqual("foo_0:0", result.name)
|
||||
|
||||
def testWithLayers(self, distribution):
|
||||
def model_fn(features):
|
||||
with variable_scope.variable_scope("common"):
|
||||
layer1 = core.Dense(1)
|
||||
layer1(features)
|
||||
layer2 = core.Dense(1)
|
||||
layer2(features)
|
||||
# This will pause the current thread, and execute the other thread.
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
layer3 = core.Dense(1)
|
||||
layer3(features)
|
||||
return [(layer1.kernel, layer1.bias),
|
||||
(layer2.kernel, layer2.bias),
|
||||
(layer3.kernel, layer3.bias)]
|
||||
|
||||
iterator = distribution.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
|
||||
self.evaluate(iterator.initialize())
|
||||
features = iterator.get_next()
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(
|
||||
model_fn, args=(features,))
|
||||
suffixes = ["", "_1", "_2"]
|
||||
for (kernel, bias), suffix in zip(result, suffixes):
|
||||
self.assertIsInstance(kernel, values.MirroredVariable)
|
||||
self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name)
|
||||
self.assertIsInstance(bias, values.MirroredVariable)
|
||||
self.assertEqual("common/dense" + suffix + "/bias:0", bias.name)
|
||||
|
||||
def testWithVariableAndVariableScope(self, distribution):
|
||||
def model_fn():
|
||||
v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
|
||||
with variable_scope.variable_scope("common"):
|
||||
v1 = variable_scope.variable(1.0, name="var1")
|
||||
# This will pause the current thread, and execute the other thread.
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
v2 = variable_scope.variable(
|
||||
1.0,
|
||||
name="var2",
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=variable_scope.VariableAggregation.SUM)
|
||||
v3 = variable_scope.variable(
|
||||
1.0,
|
||||
name="var3",
|
||||
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
|
||||
aggregation=variable_scope.VariableAggregation.MEAN)
|
||||
|
||||
return v0, v1, v2, v3
|
||||
|
||||
with distribution.scope():
|
||||
v = variable_scope.variable(1.0, name="var-main0")
|
||||
self.assertEqual("var-main0:0", v.name)
|
||||
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self.assertEqual(4, len(result))
|
||||
v0, v1, v2, v3 = result
|
||||
self.assertIsInstance(v0, values.MirroredVariable)
|
||||
self.assertEqual("var0:0", v0.name)
|
||||
self.assertIsInstance(v1, values.MirroredVariable)
|
||||
self.assertEqual("common/var1:0", v1.name)
|
||||
self.assertIsInstance(v2, values.SyncOnReadVariable)
|
||||
self.assertEqual("common/var2:0", v2.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
|
||||
self.assertIsInstance(v3, values.MirroredVariable)
|
||||
self.assertEqual("common/var3:0", v3.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)
|
||||
|
||||
def testWithGetVariableAndVariableScope(self, distribution):
|
||||
def model_fn():
|
||||
v0 = variable_scope.get_variable("var0", [1])
|
||||
with variable_scope.variable_scope("common"):
|
||||
v1 = variable_scope.get_variable("var1", [1])
|
||||
# This will pause the current thread, and execute the other thread.
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
v2 = variable_scope.get_variable(
|
||||
"var2", [1],
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=variable_scope.VariableAggregation.SUM)
|
||||
v3 = variable_scope.get_variable(
|
||||
"var3", [1],
|
||||
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
|
||||
aggregation=variable_scope.VariableAggregation.MEAN)
|
||||
|
||||
return v0, v1, v2, v3
|
||||
|
||||
with distribution.scope():
|
||||
with variable_scope.variable_scope("main"):
|
||||
v = variable_scope.get_variable("var-main0", [1])
|
||||
self.assertEqual("main/var-main0:0", v.name)
|
||||
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self.assertEqual(4, len(result))
|
||||
v0, v1, v2, v3 = result
|
||||
self.assertIsInstance(v0, values.MirroredVariable)
|
||||
self.assertEqual("main/var0:0", v0.name)
|
||||
self.assertIsInstance(v1, values.MirroredVariable)
|
||||
self.assertEqual("main/common/var1:0", v1.name)
|
||||
self.assertIsInstance(v2, values.SyncOnReadVariable)
|
||||
self.assertEqual("main/common/var2:0", v2.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.SUM,
|
||||
v2.aggregation)
|
||||
self.assertIsInstance(v3, values.MirroredVariable)
|
||||
self.assertEqual("main/common/var3:0", v3.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.MEAN,
|
||||
v3.aggregation)
|
||||
|
||||
def testOnlyFirstReplicaUpdatesVariables(self, distribution):
|
||||
def create_fn():
|
||||
aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
|
||||
v0 = variable_scope.variable(
|
||||
2.0,
|
||||
name="on_read",
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=aggregation)
|
||||
v1 = variable_scope.variable(
|
||||
3.0,
|
||||
name="on_write",
|
||||
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
|
||||
aggregation=aggregation)
|
||||
return v0, v1
|
||||
|
||||
devices = ["/device:GPU:0", "/device:CPU:0"]
|
||||
with distribution.scope():
|
||||
v0, v1 = distribution.extended.call_for_each_replica(create_fn)
|
||||
self.evaluate(v0.initializer)
|
||||
self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
|
||||
self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
|
||||
self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
|
||||
self.evaluate(v1.initializer)
|
||||
self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
|
||||
self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
|
||||
self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))
|
||||
|
||||
def replica_id_plus_one():
|
||||
return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)
|
||||
|
||||
# Update using the assign_add member function.
|
||||
def update_member_fn():
|
||||
update0 = v0.assign_add(5.0 * replica_id_plus_one())
|
||||
update1 = v1.assign_add(7.0 * replica_id_plus_one())
|
||||
return update0, update1
|
||||
|
||||
update0a, update1a = distribution.extended.call_for_each_replica(
|
||||
update_member_fn)
|
||||
|
||||
# Update "sync on read" variable.
|
||||
self.evaluate(distribution.group(update0a))
|
||||
self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
|
||||
# Writes are not synchronized for "sync on read" variables,
|
||||
# so device[1] can end up with a different value.
|
||||
self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
|
||||
# Always reads from device 0.
|
||||
self.assertEqual(2.0 + 5.0, self.evaluate(
|
||||
distribution.extended.read_var(v0)))
|
||||
|
||||
# Update "sync on write" variable.
|
||||
self.evaluate(distribution.group(update1a))
|
||||
self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
|
||||
# Writes are synchronized for v1, only the argument to assign_add on
|
||||
# device[0] is used.
|
||||
self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
|
||||
self.assertEqual(3.0 + 7.0, self.evaluate(
|
||||
distribution.extended.read_var(v1)))
|
||||
|
||||
# Update using state_ops.assign_add global function.
|
||||
def update_state_ops_fn():
|
||||
update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
|
||||
update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
|
||||
return update0, update1
|
||||
|
||||
update0b, update1b = distribution.extended.call_for_each_replica(
|
||||
update_state_ops_fn)
|
||||
self.evaluate(distribution.group(update0b))
|
||||
|
||||
# Update "sync on read" variable.
|
||||
self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
|
||||
self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
|
||||
self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(
|
||||
distribution.extended.read_var(v0)))
|
||||
|
||||
# Update "sync on write" variable.
|
||||
self.evaluate(distribution.group(update1b))
|
||||
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
|
||||
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
|
||||
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(
|
||||
distribution.extended.read_var(v1)))
|
||||
|
||||
def testNoneSynchronizationWithGetVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "`NONE` variable synchronization mode is not "
|
||||
"supported with `Mirrored` distribution strategy. Please change "
|
||||
"the `synchronization` for variable: v"):
|
||||
variable_scope.get_variable(
|
||||
"v", [1],
|
||||
synchronization=variable_scope.VariableSynchronization.NONE)
|
||||
|
||||
def testNoneSynchronizationWithVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "`NONE` variable synchronization mode is not "
|
||||
"supported with `Mirrored` distribution strategy. Please change "
|
||||
"the `synchronization` for variable: v"):
|
||||
variable_scope.variable(
|
||||
1.0,
|
||||
name="v",
|
||||
synchronization=variable_scope.VariableSynchronization.NONE)
|
||||
|
||||
def testInvalidSynchronizationWithVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Invalid variable synchronization mode: Invalid for "
|
||||
"variable: v"):
|
||||
variable_scope.variable(1.0, name="v", synchronization="Invalid")
|
||||
|
||||
def testInvalidAggregationWithGetVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Invalid variable aggregation mode: invalid for "
|
||||
"variable: v"):
|
||||
variable_scope.get_variable(
|
||||
"v", [1],
|
||||
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
|
||||
aggregation="invalid")
|
||||
|
||||
def testInvalidAggregationWithVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Invalid variable aggregation mode: invalid for "
|
||||
"variable: v"):
|
||||
variable_scope.variable(
|
||||
1.0,
|
||||
name="v",
|
||||
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
|
||||
aggregation="invalid")
|
||||
|
||||
def testNonMatchingVariableCreation(self, distribution):
|
||||
self.skipTest("b/123075960")
|
||||
def model_fn(name):
|
||||
v = variable_scope.variable(1.0, name=name)
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with distribution.scope():
|
||||
device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
|
||||
names = values.DistributedValues(device_map, ("foo", "bar"))
|
||||
with self.assertRaises(RuntimeError):
|
||||
_ = distribution.extended.call_for_each_replica(model_fn, args=(names,))
|
||||
|
||||
def testSyncOnReadVariable(self, distribution):
|
||||
all_v_sum = {}
|
||||
all_v_mean = {}
|
||||
components_sum = {}
|
||||
components_mean = {}
|
||||
|
||||
def model_fn():
|
||||
replica_id = self.evaluate(_replica_id())
|
||||
v_sum = variable_scope.variable(
|
||||
1.0,
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=variable_scope.VariableAggregation.SUM)
|
||||
v_mean = variable_scope.variable(
|
||||
4.0,
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=variable_scope.VariableAggregation.MEAN)
|
||||
self.assertIsInstance(v_sum, values.SyncOnReadVariable)
|
||||
self.assertIsInstance(v_mean, values.SyncOnReadVariable)
|
||||
updates = [v_sum.assign_add(2.0 + replica_id),
|
||||
v_mean.assign(6.0 * replica_id)]
|
||||
all_v_sum[replica_id] = v_sum
|
||||
all_v_mean[replica_id] = v_mean
|
||||
c_sum = v_sum.get()
|
||||
c_mean = v_mean.get()
|
||||
components_sum[replica_id] = c_sum
|
||||
components_mean[replica_id] = c_mean
|
||||
self.assertIsNot(v_sum, c_sum)
|
||||
self.assertIsNot(v_mean, c_mean)
|
||||
return updates, v_sum, v_mean, c_sum, c_mean
|
||||
|
||||
with distribution.scope():
|
||||
# Create "sum" and "mean" versions of SyncOnReadVariables.
|
||||
ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
|
||||
distribution.extended.call_for_each_replica(model_fn))
|
||||
# Should see the same wrapping instance in all replicas.
|
||||
self.assertIs(all_v_sum[0], ret_v_sum)
|
||||
self.assertIs(all_v_mean[0], ret_v_mean)
|
||||
self.assertIs(all_v_sum[0], all_v_sum[1])
|
||||
self.assertIs(all_v_mean[0], all_v_mean[1])
|
||||
|
||||
# Regroup should recover the same wrapper.
|
||||
self.assertIs(ret_v_sum, regrouped_sum)
|
||||
self.assertIs(ret_v_mean, regrouped_mean)
|
||||
self.assertIsNot(components_sum[0], components_sum[1])
|
||||
self.assertIsNot(components_mean[0], components_mean[1])
|
||||
|
||||
# Apply updates
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate([y for x in ret_ops # pylint: disable=g-complex-comprehension
|
||||
for y in distribution.experimental_local_results(x)])
|
||||
expected_sum = 0.0
|
||||
expected_mean = 0.0
|
||||
for i, d in enumerate(distribution.extended.worker_devices):
|
||||
# Should see different values on different devices.
|
||||
v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
|
||||
v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
|
||||
expected = i + 3.0
|
||||
self.assertEqual(expected, v_sum_value)
|
||||
expected_sum += expected
|
||||
expected = i * 6.0
|
||||
self.assertEqual(expected, v_mean_value)
|
||||
expected_mean += expected
|
||||
expected_mean /= len(distribution.extended.worker_devices)
|
||||
|
||||
# Without get(device), should return the value you get by
|
||||
# applying the reduction across all replicas (whether you use
|
||||
# read_var(), get(), or nothing).
|
||||
self.assertEqual(expected_sum, self.evaluate(
|
||||
distribution.extended.read_var(ret_v_sum)))
|
||||
self.assertEqual(expected_mean, self.evaluate(
|
||||
distribution.extended.read_var(ret_v_mean)))
|
||||
self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
|
||||
self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
|
||||
self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
|
||||
self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
|
||||
|
||||
# TODO(priyag): Update this test to work in eager mode as well.
|
||||
def testDynamicRnnVariables(self, distribution):
|
||||
def model_fn():
|
||||
inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
|
||||
cell_fw = rnn_cell_impl.LSTMCell(300)
|
||||
cell_bw = rnn_cell_impl.LSTMCell(300)
|
||||
(outputs, _) = rnn.bidirectional_dynamic_rnn(
|
||||
cell_fw,
|
||||
cell_bw,
|
||||
inputs,
|
||||
dtype=dtypes.float32)
|
||||
return outputs
|
||||
|
||||
with context.graph_mode(), distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
# Two variables are created by the RNN layer.
|
||||
self.assertEqual(2, len(result))
|
||||
for v in result:
|
||||
self.assertIsInstance(v, values.DistributedValues)
|
||||
_, v1 = distribution.experimental_local_results(v)
|
||||
self.assertStartsWith(v1._op.name, "replica_1/")
|
||||
|
||||
def testSyncOnReadVariableUpdate(self, distribution):
|
||||
def model_fn():
|
||||
v_sum = variable_scope.variable(
|
||||
1.0,
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=variable_scope.VariableAggregation.SUM)
|
||||
self.assertIsInstance(v_sum, values.SyncOnReadVariable)
|
||||
return v_sum
|
||||
|
||||
def update(var, value):
|
||||
return var.assign(value)
|
||||
|
||||
with distribution.scope():
|
||||
ret_v_sum = distribution.extended.call_for_each_replica(model_fn)
|
||||
|
||||
# Initialize variables.
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# Assert that the aggregated value of the sync on read var is the sum
|
||||
# of the individual values before running the update ops.
|
||||
self.assertEqual(1.0, self.evaluate(ret_v_sum.get(
|
||||
distribution.extended.worker_devices[0]).read_value()))
|
||||
self.assertEqual(2.0, self.evaluate(ret_v_sum))
|
||||
|
||||
# Apply updates.
|
||||
update_ops = distribution.extended.update(
|
||||
ret_v_sum, update, args=(5.0,), group=False)
|
||||
self.evaluate(update_ops)
|
||||
# Assert that the aggregated value of the sync on read vars is the sum
|
||||
# of the individual values after running the update ops.
|
||||
self.assertEqual(5.0, self.evaluate(ret_v_sum.get(
|
||||
distribution.extended.worker_devices[0]).read_value()))
|
||||
self.assertEqual(10.0, self.evaluate(ret_v_sum))
|
||||
|
||||
def testVarDistributeStrategy(self, distribution):
|
||||
with distribution.scope():
|
||||
mirrored = variable_scope.variable(1.0)
|
||||
sync_on_read = variable_scope.variable(
|
||||
1.0,
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ)
|
||||
self.assertIs(distribution, mirrored.distribute_strategy)
|
||||
self.assertIs(distribution, sync_on_read.distribute_strategy)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
|
tensorflow/python/distribute/mirrored_variable_test.py (new file, 606 lines)
@ -0,0 +1,606 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test MirroredVariable in MirroredStrategy and MultiWorkerMirroredStrategy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
def _replica_id():
|
||||
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
|
||||
if not isinstance(replica_id, ops.Tensor):
|
||||
replica_id = constant_op.constant(replica_id)
|
||||
return replica_id
|
||||
|
||||
|
||||
def _mimic_two_cpus():
|
||||
cpus = config.list_physical_devices("CPU")
|
||||
|
||||
config.set_virtual_device_configuration(cpus[0], [
|
||||
context.VirtualDeviceConfiguration(),
|
||||
context.VirtualDeviceConfiguration(),
|
||||
])
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
combinations.NamedDistribution(
|
||||
"Collective2CPUs",
|
||||
# pylint: disable=g-long-lambda
|
||||
lambda: collective_all_reduce_strategy.
|
||||
CollectiveAllReduceStrategy._from_local_devices((
|
||||
"/device:CPU:0", "/device:CPU:1")),
|
||||
required_gpus=0)
|
||||
],
|
||||
mode=["graph", "eager"]))
|
||||
class MirroredVariableCreationTest(test.TestCase):
|
||||
"""Base class that tests mirrored variable creator.
|
||||
|
||||
Currently it assumes all strategy objects have two replicas.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
_mimic_two_cpus()
|
||||
|
||||
# TODO(priyag): Modify more tests to use this helper and check more
|
||||
# properties.
|
||||
def _test_mv_properties(self, var, name, strategy):
|
||||
self.assertIsInstance(var, values.MirroredVariable)
|
||||
self.assertEqual(name, var.name)
|
||||
self.assertIs(strategy, var.distribute_strategy)
|
||||
for d in var.devices:
|
||||
self.assertEqual(d, var.get(d).device)
|
||||
self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access
|
||||
|
||||
def testVariableInFuncGraph(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
v = variable_scope.variable(2.0, name="bar")
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
|
||||
v1 = variable_scope.variable(1.0, name="foo")
|
||||
v2 = distribution.extended.call_for_each_replica(model_fn)
|
||||
|
||||
self._test_mv_properties(v1, "foo:0", distribution)
|
||||
self._test_mv_properties(v2, "bar:0", distribution)
|
||||
|
||||
def testVariableWithTensorInitialValueInFunction(self, distribution):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest("`tf.function` is an eager-only feature")
|
||||
|
||||
v = [None]
|
||||
|
||||
def model_fn():
|
||||
if v[0] is None:
|
||||
init_val = array_ops.zeros([])
|
||||
v[0] = variables.Variable(init_val)
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v[0]
|
||||
|
||||
@def_function.function(autograph=False)
|
||||
def make_v1():
|
||||
return distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn))
|
||||
|
||||
self.assertAllEqual([0, 0], make_v1())
|
||||
|
||||
def testSingleVariable(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
# This variable should be created only once across the threads because of
|
||||
# special variable_creator functions used by
|
||||
# `distribution.extended.call_for_each_replica`.
|
||||
v = variable_scope.variable(1.0, name="foo")
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self._test_mv_properties(result, "foo:0", distribution)
|
||||
|
||||
def testUnnamedVariable(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
v = variable_scope.variable(1.0)
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self._test_mv_properties(result, "Variable:0", distribution)
|
||||
|
||||
def testMultipleVariables(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
vs = []
|
||||
for i in range(5):
|
||||
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return vs
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
for i, v in enumerate(result):
|
||||
self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)
|
||||
|
||||
def testMultipleVariablesWithSameCanonicalName(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
vs = []
|
||||
vs.append(variable_scope.variable(1.0, name="foo/bar"))
|
||||
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
|
||||
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
|
||||
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return vs
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
for v in result:
|
||||
self.assertIsInstance(v, values.MirroredVariable)
|
||||
self.assertEqual(4, len(result))
|
||||
self.assertEqual("foo/bar:0", result[0].name)
|
||||
self.assertEqual("foo_1/bar:0", result[1].name)
|
||||
self.assertEqual("foo_1/bar_1:0", result[2].name)
|
||||
self.assertEqual("foo/bar_1:0", result[3].name)
|
||||
|
||||
def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
replica_id = self.evaluate(_replica_id())
|
||||
v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self.assertIsInstance(result, values.MirroredVariable)
|
||||
# The resulting mirrored variable will use the name from the first device.
|
||||
self.assertEqual("foo_0:0", result.name)
|
||||
|
||||
def testWithLayers(self, distribution):
|
||||
|
||||
def model_fn(features):
|
||||
with variable_scope.variable_scope("common"):
|
||||
layer1 = core.Dense(1)
|
||||
layer1(features)
|
||||
layer2 = core.Dense(1)
|
||||
layer2(features)
|
||||
# This will pause the current thread, and execute the other thread.
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
layer3 = core.Dense(1)
|
||||
layer3(features)
|
||||
return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias),
|
||||
(layer3.kernel, layer3.bias)]
|
||||
|
||||
iterator = distribution.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
|
||||
self.evaluate(iterator.initialize())
|
||||
features = iterator.get_next()
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(
|
||||
model_fn, args=(features,))
|
||||
suffixes = ["", "_1", "_2"]
|
||||
for (kernel, bias), suffix in zip(result, suffixes):
|
||||
self.assertIsInstance(kernel, values.MirroredVariable)
|
||||
self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name)
|
||||
self.assertIsInstance(bias, values.MirroredVariable)
|
||||
self.assertEqual("common/dense" + suffix + "/bias:0", bias.name)
|
||||
|
||||
def testWithVariableAndVariableScope(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
|
||||
with variable_scope.variable_scope("common"):
|
||||
v1 = variable_scope.variable(1.0, name="var1")
|
||||
# This will pause the current thread, and execute the other thread.
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
v2 = variable_scope.variable(
|
||||
1.0,
|
||||
name="var2",
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=variable_scope.VariableAggregation.SUM)
|
||||
v3 = variable_scope.variable(
|
||||
1.0,
|
||||
name="var3",
|
||||
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
|
||||
aggregation=variable_scope.VariableAggregation.MEAN)
|
||||
|
||||
return v0, v1, v2, v3
|
||||
|
||||
with distribution.scope():
|
||||
v = variable_scope.variable(1.0, name="var-main0")
|
||||
self.assertEqual("var-main0:0", v.name)
|
||||
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self.assertEqual(4, len(result))
|
||||
v0, v1, v2, v3 = result
|
||||
self.assertIsInstance(v0, values.MirroredVariable)
|
||||
self.assertEqual("var0:0", v0.name)
|
||||
self.assertIsInstance(v1, values.MirroredVariable)
|
||||
self.assertEqual("common/var1:0", v1.name)
|
||||
self.assertIsInstance(v2, values.SyncOnReadVariable)
|
||||
self.assertEqual("common/var2:0", v2.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
|
||||
self.assertIsInstance(v3, values.MirroredVariable)
|
||||
self.assertEqual("common/var3:0", v3.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)
|
||||
|
||||
def testWithGetVariableAndVariableScope(self, distribution):
|
||||
|
||||
def model_fn():
|
||||
v0 = variable_scope.get_variable("var0", [1])
|
||||
with variable_scope.variable_scope("common"):
|
||||
v1 = variable_scope.get_variable("var1", [1])
|
||||
# This will pause the current thread, and execute the other thread.
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
v2 = variable_scope.get_variable(
|
||||
"var2", [1],
|
||||
synchronization=variable_scope.VariableSynchronization.ON_READ,
|
||||
aggregation=variable_scope.VariableAggregation.SUM)
|
||||
v3 = variable_scope.get_variable(
|
||||
"var3", [1],
|
||||
synchronization=variable_scope.VariableSynchronization.ON_WRITE,
|
||||
aggregation=variable_scope.VariableAggregation.MEAN)
|
||||
|
||||
return v0, v1, v2, v3
|
||||
|
||||
with distribution.scope():
|
||||
with variable_scope.variable_scope("main"):
|
||||
v = variable_scope.get_variable("var-main0", [1])
|
||||
self.assertEqual("main/var-main0:0", v.name)
|
||||
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self.assertEqual(4, len(result))
|
||||
v0, v1, v2, v3 = result
|
||||
self.assertIsInstance(v0, values.MirroredVariable)
|
||||
self.assertEqual("main/var0:0", v0.name)
|
||||
self.assertIsInstance(v1, values.MirroredVariable)
|
||||
self.assertEqual("main/common/var1:0", v1.name)
|
||||
self.assertIsInstance(v2, values.SyncOnReadVariable)
|
||||
self.assertEqual("main/common/var2:0", v2.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
|
||||
self.assertIsInstance(v3, values.MirroredVariable)
|
||||
self.assertEqual("main/common/var3:0", v3.name)
|
||||
self.assertEqual(variable_scope.VariableAggregation.MEAN,
|
||||
v3.aggregation)
|
||||
|
||||
  def testOnlyFirstReplicaUpdatesVariables(self, distribution):

    def create_fn():
      aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
      v0 = variable_scope.variable(
          2.0,
          name="on_read",
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=aggregation)
      v1 = variable_scope.variable(
          3.0,
          name="on_write",
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=aggregation)
      return v0, v1

    devices = distribution.extended.worker_devices
    with distribution.scope():
      v0, v1 = distribution.extended.call_for_each_replica(create_fn)
      self.evaluate(v0.initializer)
      self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
      self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
      self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
      self.evaluate(v1.initializer)
      self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
      self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
      self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))

      def replica_id_plus_one():
        return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)

      # Update using the assign_add member function.
      def update_member_fn():
        update0 = v0.assign_add(5.0 * replica_id_plus_one())
        update1 = v1.assign_add(7.0 * replica_id_plus_one())
        return update0, update1

      update0a, update1a = distribution.extended.call_for_each_replica(
          update_member_fn)

      # Update "sync on read" variable.
      self.evaluate(distribution.group(update0a))
      self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
      # Writes are not synchronized for "sync on read" variables,
      # so device[1] can end up with a different value.
      self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.get(devices[1])))
      # Always reads from device 0.
      self.assertEqual(2.0 + 5.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1a))
      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
      # Writes are synchronized for v1, only the argument to assign_add on
      # device[0] is used.
      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
      self.assertEqual(3.0 + 7.0,
                       self.evaluate(distribution.extended.read_var(v1)))

      # Update using state_ops.assign_add global function.
      def update_state_ops_fn():
        update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
        update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
        return update0, update1

      update0b, update1b = distribution.extended.call_for_each_replica(
          update_state_ops_fn)
      self.evaluate(distribution.group(update0b))

      # Update "sync on read" variable.
      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0,
                       self.evaluate(v0.get(devices[1])))
      self.assertEqual(2.0 + 5.0 + 11.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1b))
      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
      self.assertEqual(3.0 + 7.0 + 13.0,
                       self.evaluate(distribution.extended.read_var(v1)))

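  # A small arithmetic sketch (illustration only) of the expectations checked
  # above, assuming the same two-replica setup: with ONLY_FIRST_REPLICA the
  # ON_WRITE variable keeps only replica 0's update on every device, while the
  # ON_READ variable keeps one value per device and read_var() reports the
  # device-0 copy.
  def _sketchOnlyFirstReplicaArithmetic(self):
    # Per-replica assign_add arguments from update_member_fn above.
    on_read_updates = [5.0 * 1, 5.0 * 2]    # 5.0 * (replica_id + 1)
    on_write_updates = [7.0 * 1, 7.0 * 2]   # 7.0 * (replica_id + 1)
    # ON_WRITE + ONLY_FIRST_REPLICA: every device ends up with the update
    # computed on replica 0.
    self.assertEqual(3.0 + 7.0, 3.0 + on_write_updates[0])
    # ON_READ: each device applies its own update locally.
    per_device_on_read = [2.0 + u for u in on_read_updates]
    self.assertEqual([2.0 + 5.0, 2.0 + 2 * 5.0], per_device_on_read)
    # read_var() with ONLY_FIRST_REPLICA returns the device-0 copy.
    self.assertEqual(2.0 + 5.0, per_device_on_read[0])
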
  def testNoneSynchronizationWithGetVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with `Mirrored` distribution strategy. Please change "
          "the `synchronization` for variable: v"):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testNoneSynchronizationWithVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with `Mirrored` distribution strategy. Please change "
          "the `synchronization` for variable: v"):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testInvalidSynchronizationWithVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "Invalid variable synchronization mode: Invalid for "
          "variable: v"):
        variable_scope.variable(1.0, name="v", synchronization="Invalid")

  def testInvalidAggregationWithGetVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testInvalidAggregationWithVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

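  # For reference alongside the error-path tests above (illustration only):
  # the accepted members are read from the enums themselves rather than
  # hard-coded strings, which is why typos like "Invalid"/"invalid" are
  # rejected at creation time.
  def _sketchValidSynchronizationAndAggregationValues(self):
    sync_names = {s.name for s in variable_scope.VariableSynchronization}
    agg_names = {a.name for a in variable_scope.VariableAggregation}
    self.assertTrue({"AUTO", "NONE", "ON_WRITE", "ON_READ"} <= sync_names)
    self.assertTrue({"NONE", "SUM", "MEAN", "ONLY_FIRST_REPLICA"} <= agg_names)
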
  def testNonMatchingVariableCreation(self, distribution):
    self.skipTest("b/123075960")

    def model_fn(name):
      v = variable_scope.variable(1.0, name=name)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      device_map = values.ReplicaDeviceMap(distribution.extended.worker_devices)
      names = values.DistributedValues(device_map, ("foo", "bar"))
      with self.assertRaises(RuntimeError):
        _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))

  def testSyncOnReadVariable(self, distribution):
    all_v_sum = {}
    all_v_mean = {}
    components_sum = {}
    components_mean = {}

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      v_mean = variable_scope.variable(
          4.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.MEAN)
      self.assertIsInstance(v_sum, values.SyncOnReadVariable)
      self.assertIsInstance(v_mean, values.SyncOnReadVariable)
      updates = [
          v_sum.assign_add(2.0 + replica_id),
          v_mean.assign(6.0 * replica_id)
      ]
      all_v_sum[replica_id] = v_sum
      all_v_mean[replica_id] = v_mean
      c_sum = v_sum.get()
      c_mean = v_mean.get()
      components_sum[replica_id] = c_sum
      components_mean[replica_id] = c_mean
      self.assertIsNot(v_sum, c_sum)
      self.assertIsNot(v_mean, c_mean)
      return updates, v_sum, v_mean, c_sum, c_mean

    with distribution.scope():
      # Create "sum" and "mean" versions of SyncOnReadVariables.
      ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
          distribution.extended.call_for_each_replica(model_fn))
      # Should see the same wrapping instance in all replicas.
      self.assertIs(all_v_sum[0], ret_v_sum)
      self.assertIs(all_v_mean[0], ret_v_mean)
      self.assertIs(all_v_sum[0], all_v_sum[1])
      self.assertIs(all_v_mean[0], all_v_mean[1])

      # Regroup should recover the same wrapper.
      self.assertIs(ret_v_sum, regrouped_sum)
      self.assertIs(ret_v_mean, regrouped_mean)
      self.assertIsNot(components_sum[0], components_sum[1])
      self.assertIsNot(components_mean[0], components_mean[1])

      # Apply updates
      self.evaluate(variables.global_variables_initializer())
      self.evaluate([
          y for x in ret_ops  # pylint: disable=g-complex-comprehension
          for y in distribution.experimental_local_results(x)
      ])
      expected_sum = 0.0
      expected_mean = 0.0
      for i, d in enumerate(distribution.extended.worker_devices):
        # Should see different values on different devices.
        v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
        v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
        expected = i + 3.0
        self.assertEqual(expected, v_sum_value)
        expected_sum += expected
        expected = i * 6.0
        self.assertEqual(expected, v_mean_value)
        expected_mean += expected
      expected_mean /= len(distribution.extended.worker_devices)

      # Without get(device), should return the value you get by
      # applying the reduction across all replicas (whether you use
      # read_var(), get(), or nothing).
      self.assertEqual(expected_sum, self.evaluate(
          distribution.extended.read_var(ret_v_sum)))
      self.assertEqual(expected_mean, self.evaluate(
          distribution.extended.read_var(ret_v_mean)))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean))

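  # A small arithmetic sketch (illustration only) of the read-time aggregation
  # verified above, assuming two replicas and the same per-device values the
  # test produces.
  def _sketchSyncOnReadAggregation(self):
    # Replica i runs v_sum.assign_add(2.0 + i) on an initial value of 1.0 and
    # v_mean.assign(6.0 * i), so the per-device copies are:
    per_device_sum = [1.0 + 2.0 + 0, 1.0 + 2.0 + 1]   # [3.0, 4.0]
    per_device_mean = [6.0 * 0, 6.0 * 1]              # [0.0, 6.0]
    # SUM aggregation: reading the wrapper adds the copies together.
    self.assertEqual(7.0, sum(per_device_sum))
    # MEAN aggregation: reading the wrapper averages them.
    self.assertEqual(3.0, sum(per_device_mean) / len(per_device_mean))
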
  # TODO(priyag): Update this test to work in eager mode as well.
  def testDynamicRnnVariables(self, distribution):

    def model_fn():
      inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
      cell_fw = rnn_cell_impl.LSTMCell(300)
      cell_bw = rnn_cell_impl.LSTMCell(300)
      (outputs, _) = rnn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, inputs, dtype=dtypes.float32)
      return outputs

    with context.graph_mode(), distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      # Two variables are created by the RNN layer.
      self.assertEqual(2, len(result))
      for v in result:
        self.assertIsInstance(v, values.DistributedValues)
        _, v1 = distribution.experimental_local_results(v)
        self.assertStartsWith(v1._op.name, "replica_1/")

  def testSyncOnReadVariableUpdate(self, distribution):

    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      self.assertIsInstance(v_sum, values.SyncOnReadVariable)
      return v_sum

    def update(var, value):
      return var.assign(value)

    with distribution.scope():
      ret_v_sum = distribution.extended.call_for_each_replica(model_fn)

      # Initialize variables.
      self.evaluate(variables.global_variables_initializer())
      # Assert that the aggregated value of the sync on read var is the sum
      # of the individual values before running the update ops.
      self.assertEqual(
          1.0,
          self.evaluate(
              ret_v_sum.get(
                  distribution.extended.worker_devices[0]).read_value()))
      self.assertEqual(2.0, self.evaluate(ret_v_sum))

      # Apply updates.
      update_ops = distribution.extended.update(
          ret_v_sum, update, args=(5.0,), group=False)
      self.evaluate(update_ops)
      # Assert that the aggregated value of the sync on read vars is the sum
      # of the individual values after running the update ops.
      self.assertEqual(
          5.0,
          self.evaluate(
              ret_v_sum.get(
                  distribution.extended.worker_devices[0]).read_value()))
      self.assertEqual(10.0, self.evaluate(ret_v_sum))

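  # A usage sketch (illustration only) of the extended.update call exercised
  # above. The update fn and the 5.0 argument mirror the test; with
  # group=False the per-device update ops come back unwrapped so the caller
  # can run them explicitly, and once both device copies hold 5.0 the
  # SUM-aggregated read of the wrapper is 10.0.
  def _sketchExtendedUpdate(self, distribution, sync_on_read_var):
    update_ops = distribution.extended.update(
        sync_on_read_var,
        lambda var, value: var.assign(value),  # applied once per device copy
        args=(5.0,),
        group=False)
    return update_ops
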
  def testVarDistributeStrategy(self, distribution):
    with distribution.scope():
      mirrored = variable_scope.variable(1.0)
      sync_on_read = variable_scope.variable(
          1.0, synchronization=variable_scope.VariableSynchronization.ON_READ)
      self.assertIs(distribution, mirrored.distribute_strategy)
      self.assertIs(distribution, sync_on_read.distribute_strategy)


if __name__ == "__main__":
  test.main()