Enforce the contract of and improve the documentation for
`tf.distribute.Strategy.extended.colocate_vars_with()`.

PiperOrigin-RevId: 226626715
parent: cd8c6c995e
commit: b27d50f234
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import six
-
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import values
@@ -61,16 +59,12 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
     if colocate_with is None:
       with ops.device(self._device):
         return next_creator(*args, **kwargs)
-    if isinstance(colocate_with, six.string_types):
-      with ops.device(colocate_with):
-        return next_creator(*args, **kwargs)
-    if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and
-        isinstance(colocate_with[0], six.string_types)):
-      with ops.device(colocate_with[0]):
-        return next_creator(*args, **kwargs)
     with ops.colocate_with(colocate_with):
       return next_creator(*args, **kwargs)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate(colocate_with_variable, self)
+
   def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch."""
     return values.DatasetIterator(dataset, self._input_workers)
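Note: with the device-string and single-element-list branches deleted, `colocate_vars_with` under `OneDeviceStrategy` only accepts variables created inside the strategy's own scope; a bare device string now fails the new validation instead of being silently honored. A minimal sketch of the enforced contract, assuming the `tf.distribute.OneDeviceStrategy` endpoint of this era (variable names and shapes are illustrative, not from the commit):

```
# Sketch only: strategy endpoint and shapes are illustrative assumptions.
import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with strategy.scope():
  v1 = tf.get_variable("v1", shape=[2])
  # OK: colocate with a variable created in this scope.
  with strategy.extended.colocate_vars_with(v1):
    v2 = tf.get_variable("v2", shape=[2])
  # No longer supported: a device string is not a strategy-created
  # variable, so the new validation raises ValueError.
  # with strategy.extended.colocate_vars_with("/gpu:0"): ...
```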
@@ -232,6 +232,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
         "ParameterServerStrategy with compute_devices = %r, "
         "variable_device = %r", compute_devices, self._variable_device)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate(colocate_with_variable, self)
+
   def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU."""
     return values.PerReplicaDataset(
@@ -225,39 +225,18 @@ class ParameterServerStrategyTestBase(
         self.assertEqual(var.device, '/job:ps/task:%d' % part_id)
         self.assertEqual(var.device, x_add[part_id].device)
 
-      # The colocate_vars_with can override the distribution's device.
-      with d.colocate_vars_with(x_add[0]):
-        y = variable_scope.get_variable(
-            'y',
-            initializer=constant_op.constant([20.0, 10.0]),
-            aggregation=variable_scope.VariableAggregation.SUM,
-            partitioner=partitioner)
-      y_add = y.assign_add(
-          [array_ops.identity(x_add[0]),
-           array_ops.identity(x_add[1])])
+      return x_add
 
-      for part_id, var in enumerate(y):
-        self.assertEqual(var.device, '/job:ps/task:0')
-        self.assertEqual(y_add[part_id].device, var.device)
-        self.assertEqual(var.device, x_add[0].device)
-
-      return x_add, y_add
-
-    x, y = d.call_for_each_replica(model_fn)
+    x = d.call_for_each_replica(model_fn)
 
     if context.num_gpus() >= 1:
       variables.global_variables_initializer().run()
-      x_val, y_val = sess.run([x, y])
+      x_val = sess.run(x)
     if num_gpus < 1:
       self.assertEqual(x_val, [13.0, 25.0])
-      self.assertEqual(y_val, [33.0, 35.0])
     else:
       x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus]
-      y_expect = [
-          20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus
-      ]
       self.assertEqual(x_val, x_expect)
-      self.assertEqual(y_val, y_expect)
 
   def _test_device_assignment_local(self,
                                     d,
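The dropped `y`/`y_add` path removed the colocation-override check from this test; the surviving expectations for `x` encode the arithmetic visible above: `x` starts at `[10., 20.]`, each of `num_gpus` replicas adds `[3., 5.]`, and SUM aggregation combines them. A quick check of those numbers (pure Python; initial values are inferred from the retained assertions, not stated in this hunk):

```
# Inferred from the assertions above: initializer [10., 20.], per-replica
# update [3., 5.], SUM aggregation across max(num_gpus, 1) replicas.
def x_expect(num_gpus):
  n = max(num_gpus, 1)  # a single CPU replica when no GPU is present
  return [10.0 + 3.0 * n, 20.0 + 5.0 * n]

assert x_expect(0) == [13.0, 25.0]  # matches the num_gpus < 1 branch
assert x_expect(2) == [16.0, 30.0]  # matches the x_expect formula for 2 GPUs
```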
@@ -224,6 +224,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     # Update Strategy state to make sure we can track device initialization.
     TPUExtended._initialized_devices.append(master)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate_tpu_variable(colocate_with_variable, self)
+
   def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator,
                                input_shapes, iterations):
     """Create an enqueue op for a single host identified using host_id.
@@ -1026,7 +1026,7 @@ class DistributionStrategyExtended(object):
     ```
     with strategy.scope():
       var1 = tf.get_variable(...)
-      with strategy.extended.colocate_vars_with(v1):
+      with strategy.extended.colocate_vars_with(var1):
         # var2 and var3 will be created on the same device(s) as var1
         var2 = tf.get_variable(...)
         var3 = tf.get_variable(...)
@@ -1034,8 +1034,9 @@ class DistributionStrategyExtended(object):
     def fn(v1, v2, v3):
       # operates on v1 from var1, v2 from var2, and v3 from var3
 
-    # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
-    strategy.extended.update(v1, fn, args=(v2, v3))
+    # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
+    # too.
+    strategy.extended.update(var1, fn, args=(var2, var3))
     ```
 
     Args:
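For readers who want to run the corrected docstring example, here is one possible expansion; the `MirroredStrategy` choice, shapes, and initializers fill in the docstring's `...` placeholders and are illustrative only:

```
# Illustrative expansion of the docstring example; not part of the commit.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  var1 = tf.get_variable("var1", shape=[4], initializer=tf.zeros_initializer())
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1.
    var2 = tf.get_variable("var2", shape=[4], initializer=tf.zeros_initializer())
    var3 = tf.get_variable("var3", shape=[4], initializer=tf.zeros_initializer())

def fn(v1, v2, v3):
  # Operates on per-device components v1, v2, v3 of var1, var2, var3.
  return v1.assign_add(v2 + v3)

# `fn` runs on every device `var1` is on; `var2` and `var3` are there too.
strategy.extended.update(var1, fn, args=(var2, var3))
```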
@@ -1053,8 +1054,13 @@ class DistributionStrategyExtended(object):
       return next_creator(*args, **kwargs)
 
     _require_strategy_scope_extended(self)
+    self._validate_colocate_with_variable(colocate_with_variable)
     return variable_scope.variable_creator_scope(create_colocated_variable)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
+    pass
+
   def _call_dataset_fn(self, dataset_fn):
     """Call the `dataset_fn` with `input_context` as argument."""
     result = dataset_fn()
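Structurally this is a template-method hook: `colocate_vars_with` always calls `_validate_colocate_with_variable`, a no-op in the base class that each strategy overrides with its own check. A stripped-down model of the pattern (class names here are invented stand-ins, not TensorFlow API):

```
# Invented stand-ins to show the hook pattern; not TensorFlow classes.
import contextlib

class ExtendedBase(object):
  def colocate_vars_with(self, colocate_with_variable):
    # Validate first so a bad argument fails fast, before any variable
    # creator scope is installed.
    self._validate_colocate_with_variable(colocate_with_variable)
    return self._colocated_creator_scope(colocate_with_variable)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    """Base class accepts anything; subclasses enforce their contract."""
    pass

  @contextlib.contextmanager
  def _colocated_creator_scope(self, colocate_with_variable):
    yield  # the real code wraps variable_scope.variable_creator_scope

class StrictExtended(ExtendedBase):
  def _validate_colocate_with_variable(self, v):
    if not hasattr(v, "distribute_strategy"):
      raise ValueError("`colocate_vars_with` must only be passed a variable "
                       "created in this tf.distribute.Strategy.scope()")
```

With `StrictExtended`, a call like `colocate_vars_with("/gpu:0")` raises before any scope is entered, mirroring how the per-strategy overrides in this commit reject non-strategy values.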
@@ -537,6 +537,9 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
         self._container_strategy(), device_map, logical_device,
         _real_mirrored_creator, *args, **kwargs)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate_distributed_variable(colocate_with_variable, self)
+
   def _distribute_dataset(self, dataset_fn):
     if self._local_mode:
       worker_index = 0
@@ -575,6 +575,38 @@ class DistributedVariable(DistributedDelegate):
 ops.register_dense_tensor_like_type(DistributedVariable)
 
 
+def _validate_colocate_extended(v, extended):
+  if v.distribute_strategy.extended is not extended:
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
+        (v, v.distribute_strategy,))
+
+
+def validate_colocate_distributed_variable(v, extended):
+  if not isinstance(v, DistributedVariable):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
+def validate_colocate_tpu_variable(v, extended):
+  if not isinstance(v, TPUMirroredVariable):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
+def validate_colocate(v, extended):
+  if not hasattr(v, "distribute_strategy"):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
 def _apply_aggregation(strategy, value, aggregation, destinations):
   if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
     return strategy.broadcast(strategy.unwrap(value)[0],
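The three public validators differ only in their type check (`DistributedVariable`, `TPUMirroredVariable`, or merely a `distribute_strategy` attribute) before delegating to the shared scope check. A sketch of what passes and what raises, using invented stand-in classes and the `validate_colocate` defined in the hunk above:

```
# Invented stand-ins for illustration; only validate_colocate is real here.
class _FakeExtended(object):
  pass

class _FakeStrategy(object):
  def __init__(self):
    self.extended = _FakeExtended()

class _FakeVar(object):
  def __init__(self, strategy):
    self.distribute_strategy = strategy

strategy_a, strategy_b = _FakeStrategy(), _FakeStrategy()
var_a = _FakeVar(strategy_a)

validate_colocate(var_a, strategy_a.extended)       # passes: same scope
# validate_colocate(object(), strategy_a.extended)  # ValueError: no strategy
# validate_colocate(var_a, strategy_b.extended)     # ValueError: other scope
```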