Use an incremental counter for the collective keys of variables;
make MultiWorkerMirroredStrategy reuse more of MirroredStrategy's code in the variable creator; move the mirrored variable creator tests to a new file and add coverage for MultiWorkerMirroredStrategy; fix a bug in cross_device_ops where, in graph mode, the outputs of allreduce ops were sometimes not all run.

PiperOrigin-RevId: 251974203
This commit is contained in:
parent: f1ffa0225a
commit: f5003f6315
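The first item above replaces the old name-keyed instance-key table with plain counters, split into op keys and variable keys. Below is a minimal standalone sketch of that keying scheme, written for illustration only; it mirrors the `CollectiveKeys` changes in this diff (thread-local op keys, a shared counter for variable keys) but is not the TensorFlow implementation itself.

```python
import threading

_thread_local = threading.local()


class CollectiveKeysSketch(object):
  """Illustrative only: op keys and variable keys come from separate counters."""

  def __init__(self, op_instance_key_start=100,
               variable_instance_key_start=1000000):
    # The two ranges must not collide, hence the widely separated defaults.
    assert op_instance_key_start != variable_instance_key_start
    self._op_instance_key_start = op_instance_key_start
    self._variable_instance_key = variable_instance_key_start

  def get_op_instance_key(self):
    """A fresh key per collective op; thread-local so replica threads agree."""
    if not hasattr(_thread_local, "op_instance_key"):
      _thread_local.op_instance_key = self._op_instance_key_start
    key = _thread_local.op_instance_key
    _thread_local.op_instance_key += 1
    return key

  def get_variable_instance_key(self):
    """A fresh key per variable, no longer derived from the variable name."""
    key = self._variable_instance_key
    self._variable_instance_key += 1
    return key


keys = CollectiveKeysSketch()
assert keys.get_variable_instance_key() == 1000000
assert keys.get_variable_instance_key() == 1000001  # plain incremental counter
```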
@@ -115,7 +115,7 @@ class CollectiveAllReduceStrategyTestBase(
   def setUp(self):
     # We use a different key_base for each test so that collective keys won't be
     # reused.
-    # TODO(yuefengz, tucker): enable it to reuse collective keys in different
+    # TODO(yuefengz, ayushd): enable it to reuse collective keys in different
     # tests.
     CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
     super(CollectiveAllReduceStrategyTestBase, self).setUp()

@@ -133,11 +133,11 @@ class CollectiveAllReduceStrategyTestBase(
         use_core_strategy=use_core_strategy)

     collective_keys = cross_device_utils.CollectiveKeys(
-        group_key_start=10 * num_gpus +
+        group_key_start=10 +
         CollectiveAllReduceStrategyTestBase.collective_key_base,
-        instance_key_start=num_gpus * 100 +
+        op_instance_key_start=100 +
         CollectiveAllReduceStrategyTestBase.collective_key_base,
-        instance_key_with_id_start=num_gpus * 10000 +
+        variable_instance_key_start=10000 +
         CollectiveAllReduceStrategyTestBase.collective_key_base)
     strategy.extended._collective_keys = collective_keys
     strategy.extended._cross_device_ops._collective_keys = (collective_keys)
@@ -966,6 +966,32 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )

+cuda_py_test(
+    name = "mirrored_variable_test",
+    srcs = ["mirrored_variable_test.py"],
+    additional_deps = [
+        ":collective_all_reduce_strategy",
+        ":combinations",
+        ":strategy_combinations",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:values",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:test",
+    ],
+    tags = [
+        "guitar",
+        "multi_and_single_gpu",
+    ],
+    xla_enable_strict_auto_jit = True,
+)
+
 distribute_py_test(
     name = "metrics_v1_test",
     srcs = ["metrics_v1_test.py"],
@@ -36,7 +36,6 @@ from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops

@@ -79,6 +78,13 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
             self,
             communication=communication))

+  @classmethod
+  def _from_local_devices(cls, devices):
+    """A convenience method to create an object with a list of devices."""
+    obj = cls()
+    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
+    return obj
+

 @tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])
 class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
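The new `_from_local_devices` classmethod is exercised later in this same commit by the added mirrored_variable_test.py, which builds a two-replica, CPU-only strategy for the test combinations. Roughly, that call looks like the following (a private, test-only API):

```python
# Mirrors the test combination added in this commit.
from tensorflow.python.distribute import collective_all_reduce_strategy

strategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy
    ._from_local_devices(("/device:CPU:0", "/device:CPU:1")))
```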
@@ -117,7 +123,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     else:
       self._initialize_local(cluster_resolver)

-  def _initialize_local(self, cluster_resolver):
+  def _initialize_local(self, cluster_resolver, devices=None):
     """Initializes the object for local training."""
     self._is_chief = True
     self._num_workers = 1

@@ -140,10 +146,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     else:
       num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

-    if num_gpus:
-      local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
+    if devices:
+      local_devices = devices
     else:
-      local_devices = ("/device:CPU:0",)
+      if num_gpus:
+        local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
+      else:
+        local_devices = ("/device:CPU:0",)
     self._worker_device = device_util.canonicalize("/device:CPU:0")
     self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

@@ -272,100 +281,52 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         task_id, self._num_workers, local_devices,
         self._communication)

-  def _create_variable(self, next_creator, *args, **kwargs):
-    colocate_with = kwargs.pop("colocate_with", None)
-    if colocate_with is None:
-      device_map = self._device_map
-      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
-    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
-      with ops.device(colocate_with.device):
-        return next_creator(*args, **kwargs)
+  def _get_variable_creator_initial_value(self,
+                                          replica_id=0,
+                                          device=None,
+                                          primary_var=None,
+                                          **kwargs):
+    if replica_id == 0:  # First replica on each worker.
+      assert device is not None
+      assert primary_var is None
+
+      def initial_value_fn():  # pylint: disable=g-missing-docstring
+        # Only the first device participates in the broadcast of initial values.
+        group_key = self._collective_keys.get_group_key([device])
+        group_size = self._num_workers
+        collective_instance_key = (
+            self._collective_keys.get_variable_instance_key())
+
+        with ops.device(device):
+          initial_value = kwargs["initial_value"]
+          if callable(initial_value):
+            initial_value = initial_value()
+          assert not callable(initial_value)
+          initial_value = ops.convert_to_tensor(
+              initial_value, dtype=kwargs.get("dtype", None))
+
+          if self._num_workers > 1:
+            if self._is_chief:
+              bcast_send = collective_ops.broadcast_send(
+                  initial_value, initial_value.shape, initial_value.dtype,
+                  group_size, group_key, collective_instance_key)
+              with ops.control_dependencies([bcast_send]):
+                return array_ops.identity(initial_value)
+            else:
+              return collective_ops.broadcast_recv(initial_value.shape,
+                                                   initial_value.dtype,
+                                                   group_size, group_key,
+                                                   collective_instance_key)
+          return initial_value
+
+      return initial_value_fn
     else:
-      device_map = colocate_with.device_map
-      logical_device = colocate_with.logical_device
-
-    def _real_mirrored_creator(devices, *args, **kwargs):
-      """Creates one MirroredVariable on the current worker."""
-      unique_var_name = ops.get_default_graph().unique_name(
-          kwargs["name"], mark_as_used=False).rstrip("/")
-      # pylint: disable=protected-access
-      collective_instance_key = self._collective_keys.get_instance_key(
-          key_id=unique_var_name)
-      # Only the first device participles in the broadcast of initial values.
-      group_key = self._collective_keys.get_group_key([devices[0]])
-      group_size = self._num_workers
-      if "initial_value" not in kwargs:
-        raise ValueError("Initial value must be specified.")
-      initial_value = kwargs["initial_value"]
-      if callable(initial_value):
-        initial_value_fn = initial_value
-      else:
-        initial_value_fn = lambda: initial_value
-
-      value_list = []
-      for i, d in enumerate(devices):
-        with ops.init_scope(), ops.device(d):
-          if i == 0:
-            # The initial value fn makes sure variables all initialized to
-            # same values. The first device of the chief worker will send their
-            # variable values to other workers.
-            def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
-              with ops.device(device):
-                initial_value = initial_value_fn()
-                assert not callable(initial_value)
-                initial_value = ops.convert_to_tensor(
-                    initial_value, dtype=kwargs.get("dtype", None))
-
-                assert index == 0, index
-                if self._num_workers > 1:
-                  if self._is_chief:
-                    bcast_send = collective_ops.broadcast_send(
-                        initial_value, initial_value.shape, initial_value.dtype,
-                        group_size, group_key, collective_instance_key)
-                    with ops.control_dependencies([bcast_send]):
-                      return array_ops.identity(initial_value)
-                  else:
-                    return collective_ops.broadcast_recv(
-                        initial_value.shape, initial_value.dtype, group_size,
-                        group_key, collective_instance_key)
-                return initial_value
-          else:
-            # Give replicas meaningful distinct names:
-            var0name = value_list[0].name.split(":")[0]
-            # We append a / to variable names created on replicas with id > 0 to
-            # ensure that we ignore the name scope and instead use the given
-            # name as the absolute name of the variable.
-            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
-
-            # Variables on non-first replica get initial values from the
-            # variables created on the first device of each worker.
-            def _overridden_initial_value_fn(device=d, index=i):
-              assert index > 0
-              with ops.device(device):
-                if context.executing_eagerly():
-                  return array_ops.identity(value_list[0].value())
-                else:
-                  return array_ops.identity(value_list[0].initial_value)
-
-          kwargs["initial_value"] = _overridden_initial_value_fn
-          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
-            # Don't record operations (e.g. other variable reads) during
-            # variable creation.
-            with tape.stop_recording():
-              v = next_creator(*args, **kwargs)
-
-          if i == 0:
-            actual_var_name = v.name.split(":")[0]
-            assert unique_var_name == actual_var_name, "%r vs %r" % (
-                unique_var_name, actual_var_name)
-          assert not isinstance(v, values.DistributedVariable)
-          value_list.append(v)
-      return value_list
-
-    # pylint: disable=protected-access
-    return mirrored_strategy._create_mirrored_variable(
-        self._container_strategy(), device_map, logical_device,
-        _real_mirrored_creator, *args, **kwargs)
+      return super(CollectiveAllReduceExtended,
+                   self)._get_variable_creator_initial_value(
+                       replica_id=replica_id,
+                       device=device,
+                       primary_var=primary_var,
+                       **kwargs)

   def _make_input_context(self):
     if self._cluster_spec is None:
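For the first replica on each worker, the new `initial_value_fn` has the chief broadcast its initial value and every other worker receive it, keyed by a fresh variable instance key. A simplified sketch of that send/receive split is below; it is illustrative only, with the collective ops and the user initializer replaced by injected callables since the real calls require a multi-worker setup.

```python
def make_initial_value_fn(is_chief, num_workers, broadcast_send, broadcast_recv,
                          compute_initial_value):
  """Illustrative only: how the first replica on each worker gets its value.

  broadcast_send/broadcast_recv stand in for collective_ops.broadcast_send and
  collective_ops.broadcast_recv; compute_initial_value stands in for the
  user-provided initializer.
  """
  def initial_value_fn():
    initial_value = compute_initial_value()
    if num_workers > 1:
      if is_chief:
        broadcast_send(initial_value)   # chief publishes its value...
        return initial_value
      return broadcast_recv()           # ...every other worker adopts it
    return initial_value                # single worker: nothing to synchronize
  return initial_value_fn
```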
@@ -998,14 +998,15 @@ class CollectiveAllReduce(CrossDeviceOps):
       return all_reduced
     devices = device_map.logical_to_actual_devices(logical_device)
     index = []
-    for d in devices:
-      if d in all_reduced.devices:
-        index.append(all_reduced.get(d))
-      else:
-        # TODO(josh11b): Once we add support for model parallelism, get the
-        # copy from the corresponding replica instead of the primary.
-        with ops.control_dependencies(all_reduced.values), ops.device(d):
-          index.append(array_ops.identity(all_reduced.primary))
+    with ops.control_dependencies(all_reduced.values):
+      for d in devices:
+        with ops.device(d):
+          if d in all_reduced.devices:
+            index.append(array_ops.identity(all_reduced.get(d)))
+          else:
+            # TODO(josh11b): Once we add support for model parallelism, get the
+            # copy from the corresponding replica instead of the primary.
+            index.append(array_ops.identity(all_reduced.primary))

     return value_lib.Mirrored(device_map, index, logical_device)

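This hunk is the cross_device_ops fix called out in the commit message: previously, components whose device already appeared in `all_reduced` were returned as-is, so in graph mode fetching one component did not force the other all-reduce outputs to run. After the fix, every returned component is an identity op created under `control_dependencies(all_reduced.values)`, tying them all together. A stripped-down TF1-style illustration of that dependency pattern follows; it is not the strategy code itself, and the constants merely stand in for the per-device all-reduce outputs.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-ins for the per-device all-reduce outputs.
reduced_0 = tf.constant(1.0, name="reduced_0")
reduced_1 = tf.constant(2.0, name="reduced_1")

# Each returned component depends on *all* reduced values, so evaluating any
# one of them runs every all-reduce output, even in graph mode.
with tf.control_dependencies([reduced_0, reduced_1]):
  out_0 = tf.identity(reduced_0)
  out_1 = tf.identity(reduced_1)

with tf.Session() as sess:
  sess.run(out_0)  # also forces reduced_1 to execute
```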
@@ -446,11 +446,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
                                  use_strategy_object=False,
                                  local_mode=False):
     collective_keys = cross_device_utils.CollectiveKeys(
-        group_key_start=10 * num_gpus +
-        CollectiveAllReduceTest.collective_key_base,
-        instance_key_start=num_gpus * 100 +
-        CollectiveAllReduceTest.collective_key_base,
-        instance_key_with_id_start=num_gpus * 10000 +
+        group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
+        op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
+        variable_instance_key_start=10000 +
         CollectiveAllReduceTest.collective_key_base)
     if local_mode:
       if num_gpus:
@@ -253,31 +253,28 @@ class CollectiveKeys(object):

   def __init__(self,
                group_key_start=1,
-               instance_key_start=100,
-               instance_key_with_id_start=10000):
+               op_instance_key_start=100,
+               variable_instance_key_start=1000000):
     """Initializes the object.

     Args:
       group_key_start: the starting integer of group key.
-      instance_key_start: the starting integer of instance key.
-      instance_key_with_id_start: the starting integer of instance key that is
-        recorded with an id.
+      op_instance_key_start: the starting integer of instance key for ops.
+      variable_instance_key_start: the starting integer of instance key for
+        variables.
     """
     self._group_key = group_key_start
     self._group_key_table = {}

-    # For instance keys with ids
-    self._instance_key_id_to_key_table = {}
-    self._instance_key_with_id_counter = instance_key_with_id_start
+    assert op_instance_key_start != variable_instance_key_start
+    self._op_instance_key_start = op_instance_key_start
+    self._variable_instance_key = variable_instance_key_start

-    # For instance keys without ids
-    self._instance_key_start = instance_key_start
-
   def _get_thread_local_object(self):
     # We make instance key without key ids thread local so that it will work
     # with MirroredStrategy and distribute coordinator.
-    if not hasattr(_thread_local, 'instance_key'):
-      _thread_local.instance_key = self._instance_key_start
+    if not hasattr(_thread_local, 'op_instance_key'):
+      _thread_local.op_instance_key = self._op_instance_key_start
     return _thread_local

   def get_group_key(self, devices):
@@ -304,25 +301,17 @@ class CollectiveKeys(object):
       self._group_key_table[key_id] = new_key
     return self._group_key_table[key_id]

-  def get_instance_key(self, key_id=None):
-    """Returns a new instance key for use in defining a collective op.
-
-    Args:
-      key_id: optional string. If set, key will be recorded and the same key
-        will be returned when the same key_id is provided. If not, an increasing
-        instance key will be returned.
-    """
-    if key_id:
-      with _lock:
-        if key_id not in self._instance_key_id_to_key_table:
-          self._instance_key_with_id_counter += 1
-          self._instance_key_id_to_key_table[key_id] = (
-              self._instance_key_with_id_counter)
-        return self._instance_key_id_to_key_table[key_id]
-    else:
-      v = self._get_thread_local_object().instance_key
-      self._get_thread_local_object().instance_key += 1
-      return v
+  def get_op_instance_key(self):
+    """Returns a new instance key for use in defining a collective op."""
+    v = self._get_thread_local_object().op_instance_key
+    self._get_thread_local_object().op_instance_key += 1
+    return v
+
+  def get_variable_instance_key(self):
+    """Returns a new instance key for use in creating a Variable."""
+    v = self._variable_instance_key
+    self._variable_instance_key += 1
+    return v


 def build_collective_reduce(input_tensors,
@@ -354,7 +343,7 @@ def build_collective_reduce(input_tensors,
   devices = [t.device for t in input_tensors]
   num_devices = len(devices)
   group_key = collective_keys.get_group_key(devices)
-  instance_key = collective_keys.get_instance_key()
+  instance_key = collective_keys.get_op_instance_key()
   subdiv_offsets = [0]  # TODO(tucker): maybe support non-default subdiv spec

   def collective_all_reduce():
@@ -399,7 +388,7 @@ def build_collective_gather(input_tensors, num_workers, collective_keys):
   devices = [t.device for t in input_tensors]
   num_devices = len(devices)
   group_key = collective_keys.get_group_key(devices)
-  instance_key = collective_keys.get_instance_key()
+  instance_key = collective_keys.get_op_instance_key()

   def collective_all_gather():
     """Call collective allgather."""
@@ -532,6 +532,30 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       # containing job names.
       self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

+  def _get_variable_creator_initial_value(self,
+                                          replica_id=0,
+                                          device=None,
+                                          primary_var=None,
+                                          **kwargs):
+    """Return the initial value for variables on a replica."""
+    if replica_id == 0:
+      return kwargs["initial_value"]
+    else:
+      assert primary_var is not None
+      assert device is not None
+      assert kwargs is not None
+
+      def initial_value_fn():
+        if context.executing_eagerly() or ops.inside_function():
+          init_value = primary_var.value()
+          return array_ops.identity(init_value)
+        else:
+          with ops.device(device):
+            init_value = primary_var.initial_value
+            return array_ops.identity(init_value)
+
+      return initial_value_fn
+
   def _create_variable(self, next_creator, *args, **kwargs):
     """Create a mirrored variable. See `DistributionStrategy.scope`."""
     colocate_with = kwargs.pop("colocate_with", None)
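With this hook in place, the variable creator no longer needs a strategy-specific copy: MirroredExtended supplies the default per-replica initial value (replica 0 keeps the user initializer, later replicas read from the primary variable), and CollectiveAllReduceExtended only overrides the replica-0 case to broadcast from the chief. A toy sketch of that override pattern follows, with all TensorFlow details stripped out; it is illustrative only and the class and method names are stand-ins.

```python
class MirroredSketch(object):
  """Illustrative only: default per-replica initial-value hook."""

  def _get_variable_creator_initial_value(self, replica_id, primary_value,
                                          initial_value):
    if replica_id == 0:
      return initial_value          # first replica keeps the user initializer
    return lambda: primary_value    # later replicas copy the primary variable


class CollectiveSketch(MirroredSketch):
  """Illustrative only: multi-worker subclass overrides replica 0 only."""

  def __init__(self, is_chief):
    self._is_chief = is_chief

  def _get_variable_creator_initial_value(self, replica_id, primary_value,
                                          initial_value):
    if replica_id == 0:
      # The real code broadcasts from the chief with collective ops here.
      return lambda: (initial_value if self._is_chief
                      else "value received from chief")
    return super(CollectiveSketch, self)._get_variable_creator_initial_value(
        replica_id, primary_value, initial_value)
```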
@@ -549,6 +573,11 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       value_list = []
       for i, d in enumerate(devices):
         with ops.device(d):
+          kwargs["initial_value"] = self._get_variable_creator_initial_value(
+              replica_id=i,
+              device=d,
+              primary_var=value_list[0] if value_list else None,
+              **kwargs)
           if i > 0:
             # Give replicas meaningful distinct names:
             var0name = value_list[0].name.split(":")[0]
@@ -556,16 +585,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
             # ensure that we ignore the name scope and instead use the given
             # name as the absolute name of the variable.
             kwargs["name"] = "%s/replica_%d/" % (var0name, i)
-            # Initialize replicas with the same value:
-            def initial_value_fn(device=d):
-              if context.executing_eagerly() or ops.inside_function():
-                init_value = value_list[0].value()
-                return array_ops.identity(init_value)
-              else:
-                with ops.device(device):
-                  init_value = value_list[0].initial_value
-                return array_ops.identity(init_value)
-            kwargs["initial_value"] = initial_value_fn
           with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
             # Don't record operations (e.g. other variable reads) during
             # variable creation.
@@ -749,8 +768,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
         reduce_op, value, destinations=destinations)

   def _batch_reduce_to(self, reduce_op, value_destination_pairs):
-    return self._get_cross_device_ops().batch_reduce(
-        reduce_op, value_destination_pairs)
+    return self._get_cross_device_ops().batch_reduce(reduce_op,
+                                                     value_destination_pairs)

   def _update(self, var, fn, args, kwargs, group):
     # TODO(josh11b): In eager mode, use one thread per device.
@@ -48,12 +48,8 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.keras.layers import core as keras_core
-from tensorflow.python.layers import core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import rnn
-from tensorflow.python.ops import rnn_cell_impl
-from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import gradient_descent
@@ -370,516 +366,6 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
       distribution.extended.call_for_each_replica(model_fn)


-@combinations.generate(
-    combinations.combine(
-        distribution=[
-            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-        ],
-        mode=["graph", "eager"]))
-class MirroredStrategyVariableCreationTest(test.TestCase):
-
-  # TODO(priyag): Modify more tests to use this helper and check more
-  # properties.
-  def _test_mv_properties(self, var, name, strategy):
-    self.assertIsInstance(var, values.MirroredVariable)
-    self.assertEqual(name, var.name)
-    self.assertIs(strategy, var.distribute_strategy)
-    for d in var.devices:
-      self.assertEqual(d, var.get(d).device)
-      self.assertIs(strategy, var.get(d)._distribute_strategy)  # pylint: disable=protected-access
-
-  def testVariableInFuncGraph(self, distribution):
-    def model_fn():
-      v = variable_scope.variable(2.0, name="bar")
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return v
-
-    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
-      v1 = variable_scope.variable(1.0, name="foo")
-      v2 = distribution.extended.call_for_each_replica(model_fn)
-
-    self._test_mv_properties(v1, "foo:0", distribution)
-    self._test_mv_properties(v2, "bar:0", distribution)
-
-  def testVariableWithTensorInitialValueInFunction(self, distribution):
-    if not context.executing_eagerly():
-      self.skipTest("`tf.function` is an eager-only feature")
-
-    v = [None]
-    def model_fn():
-      if v[0] is None:
-        init_val = array_ops.zeros([])
-        v[0] = variables.Variable(init_val)
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return v[0]
-
-    @def_function.function(autograph=False)
-    def make_v1():
-      return distribution.experimental_local_results(
-          distribution.extended.call_for_each_replica(model_fn))
-
-    self.assertAllEqual([0, 0], make_v1())
-
-  def testSingleVariable(self, distribution):
-    def model_fn():
-      # This variable should be created only once across the threads because of
-      # special variable_creator functions used by
-      # `distribution.extended.call_for_each_replica`.
-      v = variable_scope.variable(1.0, name="foo")
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return v
-
-    with distribution.scope():
-      result = distribution.extended.call_for_each_replica(model_fn)
-      self._test_mv_properties(result, "foo:0", distribution)
-
-  def testUnnamedVariable(self, distribution):
-    def model_fn():
-      v = variable_scope.variable(1.0)
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return v
-
-    with distribution.scope():
-      result = distribution.extended.call_for_each_replica(model_fn)
-      self._test_mv_properties(result, "Variable:0", distribution)
-
-  def testMultipleVariables(self, distribution):
-    def model_fn():
-      vs = []
-      for i in range(5):
-        vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return vs
-
-    with distribution.scope():
-      result = distribution.extended.call_for_each_replica(model_fn)
-      for i, v in enumerate(result):
-        self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)
-
-  def testMultipleVariablesWithSameCanonicalName(self, distribution):
-    def model_fn():
-      vs = []
-      vs.append(variable_scope.variable(1.0, name="foo/bar"))
-      vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
-      vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
-      vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return vs
-
-    with distribution.scope():
-      result = distribution.extended.call_for_each_replica(model_fn)
-      for v in result:
-        self.assertIsInstance(v, values.MirroredVariable)
-      self.assertEqual(4, len(result))
-      self.assertEqual("foo/bar:0", result[0].name)
-      self.assertEqual("foo_1/bar:0", result[1].name)
-      self.assertEqual("foo_1/bar_1:0", result[2].name)
-      self.assertEqual("foo/bar_1:0", result[3].name)
-
-  def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
-    def model_fn():
-      replica_id = self.evaluate(_replica_id())
-      v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return v
-
-    with distribution.scope():
-      result = distribution.extended.call_for_each_replica(model_fn)
-      self.assertIsInstance(result, values.MirroredVariable)
-      # The resulting mirrored variable will use the name from the first device.
-      self.assertEqual("foo_0:0", result.name)
-
-  def testWithLayers(self, distribution):
-    def model_fn(features):
-      with variable_scope.variable_scope("common"):
-        layer1 = core.Dense(1)
-        layer1(features)
-        layer2 = core.Dense(1)
-        layer2(features)
-        # This will pause the current thread, and execute the other thread.
-        ds_context.get_replica_context().merge_call(lambda _: _)
-        layer3 = core.Dense(1)
-        layer3(features)
-      return [(layer1.kernel, layer1.bias),
-              (layer2.kernel, layer2.bias),
-              (layer3.kernel, layer3.bias)]
-
-    iterator = distribution.make_input_fn_iterator(
-        lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
-    self.evaluate(iterator.initialize())
-    features = iterator.get_next()
-
-    with distribution.scope():
-      result = distribution.extended.call_for_each_replica(
-          model_fn, args=(features,))
-      suffixes = ["", "_1", "_2"]
-      for (kernel, bias), suffix in zip(result, suffixes):
-        self.assertIsInstance(kernel, values.MirroredVariable)
-        self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name)
-        self.assertIsInstance(bias, values.MirroredVariable)
-        self.assertEqual("common/dense" + suffix + "/bias:0", bias.name)
-
-  def testWithVariableAndVariableScope(self, distribution):
-    def model_fn():
-      v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
-      with variable_scope.variable_scope("common"):
-        v1 = variable_scope.variable(1.0, name="var1")
-        # This will pause the current thread, and execute the other thread.
-        ds_context.get_replica_context().merge_call(lambda _: _)
-        v2 = variable_scope.variable(
-            1.0,
-            name="var2",
-            synchronization=variable_scope.VariableSynchronization.ON_READ,
-            aggregation=variable_scope.VariableAggregation.SUM)
-        v3 = variable_scope.variable(
-            1.0,
-            name="var3",
-            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
-            aggregation=variable_scope.VariableAggregation.MEAN)
-
-      return v0, v1, v2, v3
-
-    with distribution.scope():
-      v = variable_scope.variable(1.0, name="var-main0")
-      self.assertEqual("var-main0:0", v.name)
-
-      result = distribution.extended.call_for_each_replica(model_fn)
-      self.assertEqual(4, len(result))
-      v0, v1, v2, v3 = result
-      self.assertIsInstance(v0, values.MirroredVariable)
-      self.assertEqual("var0:0", v0.name)
-      self.assertIsInstance(v1, values.MirroredVariable)
-      self.assertEqual("common/var1:0", v1.name)
-      self.assertIsInstance(v2, values.SyncOnReadVariable)
-      self.assertEqual("common/var2:0", v2.name)
-      self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
-      self.assertIsInstance(v3, values.MirroredVariable)
-      self.assertEqual("common/var3:0", v3.name)
-      self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)
-
-  def testWithGetVariableAndVariableScope(self, distribution):
-    def model_fn():
-      v0 = variable_scope.get_variable("var0", [1])
-      with variable_scope.variable_scope("common"):
-        v1 = variable_scope.get_variable("var1", [1])
-        # This will pause the current thread, and execute the other thread.
-        ds_context.get_replica_context().merge_call(lambda _: _)
-        v2 = variable_scope.get_variable(
-            "var2", [1],
-            synchronization=variable_scope.VariableSynchronization.ON_READ,
-            aggregation=variable_scope.VariableAggregation.SUM)
-        v3 = variable_scope.get_variable(
-            "var3", [1],
-            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
-            aggregation=variable_scope.VariableAggregation.MEAN)
-
-      return v0, v1, v2, v3
-
-    with distribution.scope():
-      with variable_scope.variable_scope("main"):
-        v = variable_scope.get_variable("var-main0", [1])
-        self.assertEqual("main/var-main0:0", v.name)
-
-        result = distribution.extended.call_for_each_replica(model_fn)
-        self.assertEqual(4, len(result))
-        v0, v1, v2, v3 = result
-        self.assertIsInstance(v0, values.MirroredVariable)
-        self.assertEqual("main/var0:0", v0.name)
-        self.assertIsInstance(v1, values.MirroredVariable)
-        self.assertEqual("main/common/var1:0", v1.name)
-        self.assertIsInstance(v2, values.SyncOnReadVariable)
-        self.assertEqual("main/common/var2:0", v2.name)
-        self.assertEqual(variable_scope.VariableAggregation.SUM,
-                         v2.aggregation)
-        self.assertIsInstance(v3, values.MirroredVariable)
-        self.assertEqual("main/common/var3:0", v3.name)
-        self.assertEqual(variable_scope.VariableAggregation.MEAN,
-                         v3.aggregation)
-
-  def testOnlyFirstReplicaUpdatesVariables(self, distribution):
-    def create_fn():
-      aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
-      v0 = variable_scope.variable(
-          2.0,
-          name="on_read",
-          synchronization=variable_scope.VariableSynchronization.ON_READ,
-          aggregation=aggregation)
-      v1 = variable_scope.variable(
-          3.0,
-          name="on_write",
-          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
-          aggregation=aggregation)
-      return v0, v1
-
-    devices = ["/device:GPU:0", "/device:CPU:0"]
-    with distribution.scope():
-      v0, v1 = distribution.extended.call_for_each_replica(create_fn)
-      self.evaluate(v0.initializer)
-      self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
-      self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
-      self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
-      self.evaluate(v1.initializer)
-      self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
-      self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
-      self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))
-
-      def replica_id_plus_one():
-        return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)
-
-      # Update using the assign_add member function.
-      def update_member_fn():
-        update0 = v0.assign_add(5.0 * replica_id_plus_one())
-        update1 = v1.assign_add(7.0 * replica_id_plus_one())
-        return update0, update1
-
-      update0a, update1a = distribution.extended.call_for_each_replica(
-          update_member_fn)
-
-      # Update "sync on read" variable.
-      self.evaluate(distribution.group(update0a))
-      self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
-      # Writes are not synchronized for "sync on read" variables,
-      # so device[1] can end up with a different value.
-      self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
-      # Always reads from device 0.
-      self.assertEqual(2.0 + 5.0, self.evaluate(
-          distribution.extended.read_var(v0)))
-
-      # Update "sync on write" variable.
-      self.evaluate(distribution.group(update1a))
-      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
-      # Writes are synchronized for v1, only the argument to assign_add on
-      # device[0] is used.
-      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
-      self.assertEqual(3.0 + 7.0, self.evaluate(
-          distribution.extended.read_var(v1)))
-
-      # Update using state_ops.assign_add global function.
-      def update_state_ops_fn():
-        update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
-        update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
-        return update0, update1
-
-      update0b, update1b = distribution.extended.call_for_each_replica(
-          update_state_ops_fn)
-      self.evaluate(distribution.group(update0b))
-
-      # Update "sync on read" variable.
-      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
-      self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
-      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(
-          distribution.extended.read_var(v0)))
-
-      # Update "sync on write" variable.
-      self.evaluate(distribution.group(update1b))
-      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
-      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
-      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(
-          distribution.extended.read_var(v1)))
-
-  def testNoneSynchronizationWithGetVariable(self, distribution):
-    with distribution.scope():
-      with self.assertRaisesRegexp(
-          ValueError, "`NONE` variable synchronization mode is not "
-          "supported with `Mirrored` distribution strategy. Please change "
-          "the `synchronization` for variable: v"):
-        variable_scope.get_variable(
-            "v", [1],
-            synchronization=variable_scope.VariableSynchronization.NONE)
-
-  def testNoneSynchronizationWithVariable(self, distribution):
-    with distribution.scope():
-      with self.assertRaisesRegexp(
-          ValueError, "`NONE` variable synchronization mode is not "
-          "supported with `Mirrored` distribution strategy. Please change "
-          "the `synchronization` for variable: v"):
-        variable_scope.variable(
-            1.0,
-            name="v",
-            synchronization=variable_scope.VariableSynchronization.NONE)
-
-  def testInvalidSynchronizationWithVariable(self, distribution):
-    with distribution.scope():
-      with self.assertRaisesRegexp(
-          ValueError, "Invalid variable synchronization mode: Invalid for "
-          "variable: v"):
-        variable_scope.variable(1.0, name="v", synchronization="Invalid")
-
-  def testInvalidAggregationWithGetVariable(self, distribution):
-    with distribution.scope():
-      with self.assertRaisesRegexp(
-          ValueError, "Invalid variable aggregation mode: invalid for "
-          "variable: v"):
-        variable_scope.get_variable(
-            "v", [1],
-            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
-            aggregation="invalid")
-
-  def testInvalidAggregationWithVariable(self, distribution):
-    with distribution.scope():
-      with self.assertRaisesRegexp(
-          ValueError, "Invalid variable aggregation mode: invalid for "
-          "variable: v"):
-        variable_scope.variable(
-            1.0,
-            name="v",
-            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
-            aggregation="invalid")
-
-  def testNonMatchingVariableCreation(self, distribution):
-    self.skipTest("b/123075960")
-    def model_fn(name):
-      v = variable_scope.variable(1.0, name=name)
-      ds_context.get_replica_context().merge_call(lambda _: _)
-      return v
-
-    with distribution.scope():
-      device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
-      names = values.DistributedValues(device_map, ("foo", "bar"))
-      with self.assertRaises(RuntimeError):
-        _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))
-
-  def testSyncOnReadVariable(self, distribution):
-    all_v_sum = {}
-    all_v_mean = {}
-    components_sum = {}
-    components_mean = {}
-
-    def model_fn():
-      replica_id = self.evaluate(_replica_id())
-      v_sum = variable_scope.variable(
-          1.0,
-          synchronization=variable_scope.VariableSynchronization.ON_READ,
-          aggregation=variable_scope.VariableAggregation.SUM)
-      v_mean = variable_scope.variable(
-          4.0,
-          synchronization=variable_scope.VariableSynchronization.ON_READ,
-          aggregation=variable_scope.VariableAggregation.MEAN)
-      self.assertIsInstance(v_sum, values.SyncOnReadVariable)
-      self.assertIsInstance(v_mean, values.SyncOnReadVariable)
-      updates = [v_sum.assign_add(2.0 + replica_id),
-                 v_mean.assign(6.0 * replica_id)]
-      all_v_sum[replica_id] = v_sum
-      all_v_mean[replica_id] = v_mean
-      c_sum = v_sum.get()
-      c_mean = v_mean.get()
-      components_sum[replica_id] = c_sum
-      components_mean[replica_id] = c_mean
-      self.assertIsNot(v_sum, c_sum)
-      self.assertIsNot(v_mean, c_mean)
-      return updates, v_sum, v_mean, c_sum, c_mean
-
-    with distribution.scope():
-      # Create "sum" and "mean" versions of SyncOnReadVariables.
-      ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
-          distribution.extended.call_for_each_replica(model_fn))
-      # Should see the same wrapping instance in all replicas.
-      self.assertIs(all_v_sum[0], ret_v_sum)
-      self.assertIs(all_v_mean[0], ret_v_mean)
-      self.assertIs(all_v_sum[0], all_v_sum[1])
-      self.assertIs(all_v_mean[0], all_v_mean[1])
-
-      # Regroup should recover the same wrapper.
-      self.assertIs(ret_v_sum, regrouped_sum)
-      self.assertIs(ret_v_mean, regrouped_mean)
-      self.assertIsNot(components_sum[0], components_sum[1])
-      self.assertIsNot(components_mean[0], components_mean[1])
-
-      # Apply updates
-      self.evaluate(variables.global_variables_initializer())
-      self.evaluate([y for x in ret_ops  # pylint: disable=g-complex-comprehension
-                     for y in distribution.experimental_local_results(x)])
-      expected_sum = 0.0
-      expected_mean = 0.0
-      for i, d in enumerate(distribution.extended.worker_devices):
-        # Should see different values on different devices.
-        v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
-        v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
-        expected = i + 3.0
-        self.assertEqual(expected, v_sum_value)
-        expected_sum += expected
-        expected = i * 6.0
-        self.assertEqual(expected, v_mean_value)
-        expected_mean += expected
-      expected_mean /= len(distribution.extended.worker_devices)
-
-      # Without get(device), should return the value you get by
-      # applying the reduction across all replicas (whether you use
-      # read_var(), get(), or nothing).
-      self.assertEqual(expected_sum, self.evaluate(
-          distribution.extended.read_var(ret_v_sum)))
-      self.assertEqual(expected_mean, self.evaluate(
-          distribution.extended.read_var(ret_v_mean)))
-      self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
-      self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
-      self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
-      self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
-
-  # TODO(priyag): Update this test to work in eager mode as well.
-  def testDynamicRnnVariables(self, distribution):
-    def model_fn():
-      inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
-      cell_fw = rnn_cell_impl.LSTMCell(300)
-      cell_bw = rnn_cell_impl.LSTMCell(300)
-      (outputs, _) = rnn.bidirectional_dynamic_rnn(
-          cell_fw,
-          cell_bw,
-          inputs,
-          dtype=dtypes.float32)
-      return outputs
-
-    with context.graph_mode(), distribution.scope():
-      result = distribution.extended.call_for_each_replica(model_fn)
-      # Two variables are created by the RNN layer.
-      self.assertEqual(2, len(result))
-      for v in result:
-        self.assertIsInstance(v, values.DistributedValues)
-        _, v1 = distribution.experimental_local_results(v)
-        self.assertStartsWith(v1._op.name, "replica_1/")
-
-  def testSyncOnReadVariableUpdate(self, distribution):
-    def model_fn():
-      v_sum = variable_scope.variable(
-          1.0,
-          synchronization=variable_scope.VariableSynchronization.ON_READ,
-          aggregation=variable_scope.VariableAggregation.SUM)
-      self.assertIsInstance(v_sum, values.SyncOnReadVariable)
-      return v_sum
-
-    def update(var, value):
-      return var.assign(value)
-
-    with distribution.scope():
-      ret_v_sum = distribution.extended.call_for_each_replica(model_fn)
-
-      # Initialize variables.
-      self.evaluate(variables.global_variables_initializer())
-      # Assert that the aggregated value of the sync on read var is the sum
-      # of the individual values before running the update ops.
-      self.assertEqual(1.0, self.evaluate(ret_v_sum.get(
-          distribution.extended.worker_devices[0]).read_value()))
-      self.assertEqual(2.0, self.evaluate(ret_v_sum))
-
-      # Apply updates.
-      update_ops = distribution.extended.update(
-          ret_v_sum, update, args=(5.0,), group=False)
-      self.evaluate(update_ops)
-      # Assert that the aggregated value of the sync on read vars is the sum
-      # of the individual values after running the update ops.
-      self.assertEqual(5.0, self.evaluate(ret_v_sum.get(
-          distribution.extended.worker_devices[0]).read_value()))
-      self.assertEqual(10.0, self.evaluate(ret_v_sum))
-
-  def testVarDistributeStrategy(self, distribution):
-    with distribution.scope():
-      mirrored = variable_scope.variable(1.0)
-      sync_on_read = variable_scope.variable(
-          1.0,
-          synchronization=variable_scope.VariableSynchronization.ON_READ)
-      self.assertIs(distribution, mirrored.distribute_strategy)
-      self.assertIs(distribution, sync_on_read.distribute_strategy)


 @combinations.generate(
     combinations.combine(
         distribution=[
606
tensorflow/python/distribute/mirrored_variable_test.py
Normal file
606
tensorflow/python/distribute/mirrored_variable_test.py
Normal file
@ -0,0 +1,606 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Test MirroredVariable in MirroredStrategy and MultiWorkerMirroredStrategy."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||||
|
from tensorflow.python.distribute import combinations
|
||||||
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
|
from tensorflow.python.distribute import values
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import config
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import func_graph
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.layers import core
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import rnn
|
||||||
|
from tensorflow.python.ops import rnn_cell_impl
|
||||||
|
from tensorflow.python.ops import state_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
|
|
||||||
|
def _replica_id():
|
||||||
|
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
|
||||||
|
if not isinstance(replica_id, ops.Tensor):
|
||||||
|
replica_id = constant_op.constant(replica_id)
|
||||||
|
return replica_id
|
||||||
|
|
||||||
|
|
||||||
|
def _mimic_two_cpus():
|
||||||
|
cpus = config.list_physical_devices("CPU")
|
||||||
|
|
||||||
|
config.set_virtual_device_configuration(cpus[0], [
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
context.VirtualDeviceConfiguration(),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=[
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
combinations.NamedDistribution(
|
||||||
|
"Collective2CPUs",
|
||||||
|
# pylint: disable=g-long-lambda
|
||||||
|
lambda: collective_all_reduce_strategy.
|
||||||
|
CollectiveAllReduceStrategy._from_local_devices((
|
||||||
|
"/device:CPU:0", "/device:CPU:1")),
|
||||||
|
required_gpus=0)
|
||||||
|
],
|
||||||
|
mode=["graph", "eager"]))
|
||||||
|
class MirroredVariableCreationTest(test.TestCase):
|
||||||
|
"""Base class that tests mirrored variable creator.
|
||||||
|
|
||||||
|
Currently it assumes all strategy objects have two replicas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
_mimic_two_cpus()
|
||||||
|
|
||||||
|
# TODO(priyag): Modify more tests to use this helper and check more
|
||||||
|
# properties.
|
||||||
|
def _test_mv_properties(self, var, name, strategy):
|
||||||
|
self.assertIsInstance(var, values.MirroredVariable)
|
||||||
|
self.assertEqual(name, var.name)
|
||||||
|
self.assertIs(strategy, var.distribute_strategy)
|
||||||
|
for d in var.devices:
|
||||||
|
self.assertEqual(d, var.get(d).device)
|
||||||
|
self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
  def testVariableInFuncGraph(self, distribution):

    def model_fn():
      v = variable_scope.variable(2.0, name="bar")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
      v1 = variable_scope.variable(1.0, name="foo")
      v2 = distribution.extended.call_for_each_replica(model_fn)

    self._test_mv_properties(v1, "foo:0", distribution)
    self._test_mv_properties(v2, "bar:0", distribution)

  def testVariableWithTensorInitialValueInFunction(self, distribution):
    if not context.executing_eagerly():
      self.skipTest("`tf.function` is an eager-only feature")

    v = [None]

    def model_fn():
      if v[0] is None:
        init_val = array_ops.zeros([])
        v[0] = variables.Variable(init_val)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v[0]

    @def_function.function(autograph=False)
    def make_v1():
      return distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn))

    self.assertAllEqual([0, 0], make_v1())

  def testSingleVariable(self, distribution):

    def model_fn():
      # This variable should be created only once across the threads because of
      # special variable_creator functions used by
      # `distribution.extended.call_for_each_replica`.
      v = variable_scope.variable(1.0, name="foo")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "foo:0", distribution)

  def testUnnamedVariable(self, distribution):

    def model_fn():
      v = variable_scope.variable(1.0)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self._test_mv_properties(result, "Variable:0", distribution)

  def testMultipleVariables(self, distribution):

    def model_fn():
      vs = []
      for i in range(5):
        vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for i, v in enumerate(result):
        self._test_mv_properties(v, "foo" + str(i) + ":0", distribution)

  def testMultipleVariablesWithSameCanonicalName(self, distribution):

    def model_fn():
      vs = []
      vs.append(variable_scope.variable(1.0, name="foo/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
      vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
      vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return vs

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      for v in result:
        self.assertIsInstance(v, values.MirroredVariable)
      self.assertEqual(4, len(result))
      self.assertEqual("foo/bar:0", result[0].name)
      self.assertEqual("foo_1/bar:0", result[1].name)
      self.assertEqual("foo_1/bar_1:0", result[2].name)
      self.assertEqual("foo/bar_1:0", result[3].name)

  def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertIsInstance(result, values.MirroredVariable)
      # The resulting mirrored variable will use the name from the first device.
      self.assertEqual("foo_0:0", result.name)

  def testWithLayers(self, distribution):

    def model_fn(features):
      with variable_scope.variable_scope("common"):
        layer1 = core.Dense(1)
        layer1(features)
        layer2 = core.Dense(1)
        layer2(features)
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        layer3 = core.Dense(1)
        layer3(features)
      return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias),
              (layer3.kernel, layer3.bias)]

    iterator = distribution.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
    self.evaluate(iterator.initialize())
    features = iterator.get_next()

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(
          model_fn, args=(features,))
      suffixes = ["", "_1", "_2"]
      for (kernel, bias), suffix in zip(result, suffixes):
        self.assertIsInstance(kernel, values.MirroredVariable)
        self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name)
        self.assertIsInstance(bias, values.MirroredVariable)
        self.assertEqual("common/dense" + suffix + "/bias:0", bias.name)

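  # Variables created with ON_READ synchronization become SyncOnReadVariable
  # instances, while ON_WRITE variables become MirroredVariable instances; the
  # two tests below check this for both `variable` and `get_variable`.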
  def testWithVariableAndVariableScope(self, distribution):

    def model_fn():
      v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.variable(1.0, name="var1")
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.variable(
            1.0,
            name="var2",
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.variable(
            1.0,
            name="var3",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      v = variable_scope.variable(1.0, name="var-main0")
      self.assertEqual("var-main0:0", v.name)

      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(4, len(result))
      v0, v1, v2, v3 = result
      self.assertIsInstance(v0, values.MirroredVariable)
      self.assertEqual("var0:0", v0.name)
      self.assertIsInstance(v1, values.MirroredVariable)
      self.assertEqual("common/var1:0", v1.name)
      self.assertIsInstance(v2, values.SyncOnReadVariable)
      self.assertEqual("common/var2:0", v2.name)
      self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
      self.assertIsInstance(v3, values.MirroredVariable)
      self.assertEqual("common/var3:0", v3.name)
      self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation)

  def testWithGetVariableAndVariableScope(self, distribution):

    def model_fn():
      v0 = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3

    with distribution.scope():
      with variable_scope.variable_scope("main"):
        v = variable_scope.get_variable("var-main0", [1])
        self.assertEqual("main/var-main0:0", v.name)

        result = distribution.extended.call_for_each_replica(model_fn)
        self.assertEqual(4, len(result))
        v0, v1, v2, v3 = result
        self.assertIsInstance(v0, values.MirroredVariable)
        self.assertEqual("main/var0:0", v0.name)
        self.assertIsInstance(v1, values.MirroredVariable)
        self.assertEqual("main/common/var1:0", v1.name)
        self.assertIsInstance(v2, values.SyncOnReadVariable)
        self.assertEqual("main/common/var2:0", v2.name)
        self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation)
        self.assertIsInstance(v3, values.MirroredVariable)
        self.assertEqual("main/common/var3:0", v3.name)
        self.assertEqual(variable_scope.VariableAggregation.MEAN,
                         v3.aggregation)

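  # With ONLY_FIRST_REPLICA aggregation, ON_WRITE variables apply only the
  # update computed on the first replica to every copy, while ON_READ
  # variables keep their independent per-replica writes.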
  def testOnlyFirstReplicaUpdatesVariables(self, distribution):

    def create_fn():
      aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
      v0 = variable_scope.variable(
          2.0,
          name="on_read",
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=aggregation)
      v1 = variable_scope.variable(
          3.0,
          name="on_write",
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=aggregation)
      return v0, v1

    devices = distribution.extended.worker_devices
    with distribution.scope():
      v0, v1 = distribution.extended.call_for_each_replica(create_fn)
      self.evaluate(v0.initializer)
      self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
      self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
      self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
      self.evaluate(v1.initializer)
      self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
      self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
      self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))

      def replica_id_plus_one():
        return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)

      # Update using the assign_add member function.
      def update_member_fn():
        update0 = v0.assign_add(5.0 * replica_id_plus_one())
        update1 = v1.assign_add(7.0 * replica_id_plus_one())
        return update0, update1

      update0a, update1a = distribution.extended.call_for_each_replica(
          update_member_fn)

      # Update "sync on read" variable.
      self.evaluate(distribution.group(update0a))
      self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
      # Writes are not synchronized for "sync on read" variables,
      # so device[1] can end up with a different value.
      self.assertEqual(2.0 + 2 * 5.0, self.evaluate(v0.get(devices[1])))
      # Always reads from device 0.
      self.assertEqual(2.0 + 5.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1a))
      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
      # Writes are synchronized for v1, only the argument to assign_add on
      # device[0] is used.
      self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
      self.assertEqual(3.0 + 7.0,
                       self.evaluate(distribution.extended.read_var(v1)))

      # Update using state_ops.assign_add global function.
      def update_state_ops_fn():
        update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one())
        update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
        return update0, update1

      update0b, update1b = distribution.extended.call_for_each_replica(
          update_state_ops_fn)
      self.evaluate(distribution.group(update0b))

      # Update "sync on read" variable.
      self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
      self.assertEqual(2.0 + 2 * 5.0 + 2 * 11.0,
                       self.evaluate(v0.get(devices[1])))
      self.assertEqual(2.0 + 5.0 + 11.0,
                       self.evaluate(distribution.extended.read_var(v0)))

      # Update "sync on write" variable.
      self.evaluate(distribution.group(update1b))
      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
      self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
      self.assertEqual(3.0 + 7.0 + 13.0,
                       self.evaluate(distribution.extended.read_var(v1)))

  def testNoneSynchronizationWithGetVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with `Mirrored` distribution strategy. Please change "
          "the `synchronization` for variable: v"):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testNoneSynchronizationWithVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "`NONE` variable synchronization mode is not "
          "supported with `Mirrored` distribution strategy. Please change "
          "the `synchronization` for variable: v"):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.NONE)

  def testInvalidSynchronizationWithVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "Invalid variable synchronization mode: Invalid for "
          "variable: v"):
        variable_scope.variable(1.0, name="v", synchronization="Invalid")

  def testInvalidAggregationWithGetVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.get_variable(
            "v", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testInvalidAggregationWithVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegexp(
          ValueError, "Invalid variable aggregation mode: invalid for "
          "variable: v"):
        variable_scope.variable(
            1.0,
            name="v",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation="invalid")

  def testNonMatchingVariableCreation(self, distribution):
    self.skipTest("b/123075960")

    def model_fn(name):
      v = variable_scope.variable(1.0, name=name)
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      device_map = values.ReplicaDeviceMap(distribution.extended.worker_devices)
      names = values.DistributedValues(device_map, ("foo", "bar"))
      with self.assertRaises(RuntimeError):
        _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))

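  # Reading a SyncOnReadVariable without specifying a device aggregates the
  # per-replica components according to the variable's aggregation mode
  # (SUM or MEAN below).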
  def testSyncOnReadVariable(self, distribution):
    all_v_sum = {}
    all_v_mean = {}
    components_sum = {}
    components_mean = {}

    def model_fn():
      replica_id = self.evaluate(_replica_id())
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      v_mean = variable_scope.variable(
          4.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.MEAN)
      self.assertIsInstance(v_sum, values.SyncOnReadVariable)
      self.assertIsInstance(v_mean, values.SyncOnReadVariable)
      updates = [
          v_sum.assign_add(2.0 + replica_id),
          v_mean.assign(6.0 * replica_id)
      ]
      all_v_sum[replica_id] = v_sum
      all_v_mean[replica_id] = v_mean
      c_sum = v_sum.get()
      c_mean = v_mean.get()
      components_sum[replica_id] = c_sum
      components_mean[replica_id] = c_mean
      self.assertIsNot(v_sum, c_sum)
      self.assertIsNot(v_mean, c_mean)
      return updates, v_sum, v_mean, c_sum, c_mean

    with distribution.scope():
      # Create "sum" and "mean" versions of SyncOnReadVariables.
      ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
          distribution.extended.call_for_each_replica(model_fn))
      # Should see the same wrapping instance in all replicas.
      self.assertIs(all_v_sum[0], ret_v_sum)
      self.assertIs(all_v_mean[0], ret_v_mean)
      self.assertIs(all_v_sum[0], all_v_sum[1])
      self.assertIs(all_v_mean[0], all_v_mean[1])

      # Regroup should recover the same wrapper.
      self.assertIs(ret_v_sum, regrouped_sum)
      self.assertIs(ret_v_mean, regrouped_mean)
      self.assertIsNot(components_sum[0], components_sum[1])
      self.assertIsNot(components_mean[0], components_mean[1])

      # Apply updates
      self.evaluate(variables.global_variables_initializer())
      self.evaluate([
          y for x in ret_ops  # pylint: disable=g-complex-comprehension
          for y in distribution.experimental_local_results(x)
      ])
      expected_sum = 0.0
      expected_mean = 0.0
      for i, d in enumerate(distribution.extended.worker_devices):
        # Should see different values on different devices.
        v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
        v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
        expected = i + 3.0
        self.assertEqual(expected, v_sum_value)
        expected_sum += expected
        expected = i * 6.0
        self.assertEqual(expected, v_mean_value)
        expected_mean += expected
      expected_mean /= len(distribution.extended.worker_devices)

      # Without get(device), should return the value you get by
      # applying the reduction across all replicas (whether you use
      # read_var(), get(), or nothing).
      self.assertEqual(expected_sum, self.evaluate(
          distribution.extended.read_var(ret_v_sum)))
      self.assertEqual(expected_mean, self.evaluate(
          distribution.extended.read_var(ret_v_mean)))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
      self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
      self.assertEqual(expected_mean, self.evaluate(ret_v_mean))

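  # Variables created from the second replica's thread are built under a
  # "replica_1/" name scope, which the test below checks on the raw ops.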
  # TODO(priyag): Update this test to work in eager mode as well.
  def testDynamicRnnVariables(self, distribution):

    def model_fn():
      inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
      cell_fw = rnn_cell_impl.LSTMCell(300)
      cell_bw = rnn_cell_impl.LSTMCell(300)
      (outputs, _) = rnn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, inputs, dtype=dtypes.float32)
      return outputs

    with context.graph_mode(), distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      # Two variables are created by the RNN layer.
      self.assertEqual(2, len(result))
      for v in result:
        self.assertIsInstance(v, values.DistributedValues)
        _, v1 = distribution.experimental_local_results(v)
        self.assertStartsWith(v1._op.name, "replica_1/")

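  # `extended.update` runs the update function once per component, so
  # assigning 5.0 leaves each replica's copy at 5.0 and the aggregated SUM
  # value at 10.0 for two replicas.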
  def testSyncOnReadVariableUpdate(self, distribution):

    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      self.assertIsInstance(v_sum, values.SyncOnReadVariable)
      return v_sum

    def update(var, value):
      return var.assign(value)

    with distribution.scope():
      ret_v_sum = distribution.extended.call_for_each_replica(model_fn)

      # Initialize variables.
      self.evaluate(variables.global_variables_initializer())
      # Assert that the aggregated value of the sync on read var is the sum
      # of the individual values before running the update ops.
      self.assertEqual(
          1.0,
          self.evaluate(
              ret_v_sum.get(
                  distribution.extended.worker_devices[0]).read_value()))
      self.assertEqual(2.0, self.evaluate(ret_v_sum))

      # Apply updates.
      update_ops = distribution.extended.update(
          ret_v_sum, update, args=(5.0,), group=False)
      self.evaluate(update_ops)
      # Assert that the aggregated value of the sync on read vars is the sum
      # of the individual values after running the update ops.
      self.assertEqual(
          5.0,
          self.evaluate(
              ret_v_sum.get(
                  distribution.extended.worker_devices[0]).read_value()))
      self.assertEqual(10.0, self.evaluate(ret_v_sum))

  def testVarDistributeStrategy(self, distribution):
    with distribution.scope():
      mirrored = variable_scope.variable(1.0)
      sync_on_read = variable_scope.variable(
          1.0, synchronization=variable_scope.VariableSynchronization.ON_READ)
      self.assertIs(distribution, mirrored.distribute_strategy)
      self.assertIs(distribution, sync_on_read.distribute_strategy)


if __name__ == "__main__":
  test.main()