Enforce the contract of and improve the documentation for
`tf.distribute.Strategy.extended.colocate_vars_with()`.

PiperOrigin-RevId: 226626715
A. Unique TensorFlower (2018-12-22 13:01:32 -08:00), committed by TensorFlower Gardener
parent cd8c6c995e
commit b27d50f234
7 changed files with 56 additions and 36 deletions
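
In practical terms, the enforced contract means `colocate_vars_with` now validates its argument up front and raises `ValueError` unless it is given a variable created under the same strategy's scope. A minimal sketch of the behavior, written against the public `tf.distribute` API; the strategy choice and variable names here are illustrative assumptions, not part of this commit:

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  v1 = tf.Variable(1.0, name="v1")
  # Accepted: v1 is a DistributedVariable created in this strategy's scope.
  with strategy.extended.colocate_vars_with(v1):
    v2 = tf.Variable(2.0, name="v2")  # placed on the same device(s) as v1

plain = tf.Variable(3.0, name="plain")  # created outside any strategy scope
with strategy.scope():
  try:
    with strategy.extended.colocate_vars_with(plain):
      pass
  except ValueError:
    # Rejected after this commit: `plain` is not a variable created in
    # this tf.distribute.Strategy.scope().
    pass
```

The per-strategy validators that implement these checks are added in values.py at the bottom of this commit.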

File: one_device_strategy.py

@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import six
-
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import values
@@ -61,16 +59,12 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
     if colocate_with is None:
       with ops.device(self._device):
         return next_creator(*args, **kwargs)
-    if isinstance(colocate_with, six.string_types):
-      with ops.device(colocate_with):
-        return next_creator(*args, **kwargs)
-    if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and
-        isinstance(colocate_with[0], six.string_types)):
-      with ops.device(colocate_with[0]):
-        return next_creator(*args, **kwargs)
     with ops.colocate_with(colocate_with):
       return next_creator(*args, **kwargs)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate(colocate_with_variable, self)
+
   def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch."""
     return values.DatasetIterator(dataset, self._input_workers)
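
The deleted branches mean `OneDeviceStrategy` no longer treats a bare device string (or a single-element list of strings) as a colocation target; only variables are accepted, via `values.validate_colocate`. A hedged before/after sketch, assuming variables created under the strategy's scope carry the `distribute_strategy` attribute that `validate_colocate` looks for (device string and variable names are illustrative):

```
import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

with strategy.scope():
  v1 = tf.Variable(1.0)
  # Still supported: colocating with a variable from this strategy's scope.
  with strategy.extended.colocate_vars_with(v1):
    v2 = tf.Variable(2.0)

  # Previously this strategy also accepted "/cpu:0" (a device string) as the
  # colocation target; after this change, pin a device explicitly instead:
  with tf.device("/cpu:0"):
    v3 = tf.Variable(3.0)
```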

File: parameter_server_strategy.py

@@ -232,6 +232,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
         "ParameterServerStrategy with compute_devices = %r, "
         "variable_device = %r", compute_devices, self._variable_device)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate(colocate_with_variable, self)
+
   def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU."""
     return values.PerReplicaDataset(

File: parameter_server_strategy_test.py

@@ -225,39 +225,18 @@ class ParameterServerStrategyTestBase(
           self.assertEqual(var.device, '/job:ps/task:%d' % part_id)
           self.assertEqual(var.device, x_add[part_id].device)
 
-        # The colocate_vars_with can override the distribution's device.
-        with d.colocate_vars_with(x_add[0]):
-          y = variable_scope.get_variable(
-              'y',
-              initializer=constant_op.constant([20.0, 10.0]),
-              aggregation=variable_scope.VariableAggregation.SUM,
-              partitioner=partitioner)
-
-        y_add = y.assign_add(
-            [array_ops.identity(x_add[0]),
-             array_ops.identity(x_add[1])])
+        return x_add
 
-        for part_id, var in enumerate(y):
-          self.assertEqual(var.device, '/job:ps/task:0')
-          self.assertEqual(y_add[part_id].device, var.device)
-          self.assertEqual(var.device, x_add[0].device)
-        return x_add, y_add
-
-      x, y = d.call_for_each_replica(model_fn)
+      x = d.call_for_each_replica(model_fn)
 
       if context.num_gpus() >= 1:
         variables.global_variables_initializer().run()
-        x_val, y_val = sess.run([x, y])
+        x_val = sess.run(x)
         if num_gpus < 1:
           self.assertEqual(x_val, [13.0, 25.0])
-          self.assertEqual(y_val, [33.0, 35.0])
         else:
           x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus]
-          y_expect = [
-              20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus
-          ]
           self.assertEqual(x_val, x_expect)
-          self.assertEqual(y_val, y_expect)
 
   def _test_device_assignment_local(self,
                                     d,
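
The surviving expectations imply simple aggregation arithmetic: `x` starts at `[10.0, 20.0]` and each replica's `assign_add` contributes `[3.0, 5.0]`, SUM-aggregated across `num_gpus` replicas (a single CPU replica when there are none). A plain-Python check of the formula, no TensorFlow required:

```
def x_expect(num_gpus):
  # One [3.0, 5.0] increment per replica, SUM-aggregated across replicas.
  n = max(num_gpus, 1)  # a single (CPU) replica when num_gpus < 1
  return [10.0 + 3.0 * n, 20.0 + 5.0 * n]

assert x_expect(0) == [13.0, 25.0]  # the num_gpus < 1 branch above
assert x_expect(2) == [16.0, 30.0]  # 10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus
```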

File: tpu_strategy.py

@@ -224,6 +224,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
       # Update Strategy state to make sure we can track device initialization.
       TPUExtended._initialized_devices.append(master)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate_tpu_variable(colocate_with_variable, self)
+
   def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator,
                                input_shapes, iterations):
     """Create an enqueue op for a single host identified using host_id.

File: distribute_lib.py

@@ -1026,7 +1026,7 @@ class DistributionStrategyExtended(object):
     ```
     with strategy.scope():
       var1 = tf.get_variable(...)
-      with strategy.extended.colocate_vars_with(v1):
+      with strategy.extended.colocate_vars_with(var1):
         # var2 and var3 will be created on the same device(s) as var1
         var2 = tf.get_variable(...)
         var3 = tf.get_variable(...)
@@ -1034,8 +1034,9 @@
       def fn(v1, v2, v3):
         # operates on v1 from var1, v2 from var2, and v3 from var3
 
-      # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
-      strategy.extended.update(v1, fn, args=(v2, v3))
+      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
+      # too.
+      strategy.extended.update(var1, fn, args=(var2, var3))
     ```
 
     Args:
@@ -1053,8 +1054,13 @@
         return next_creator(*args, **kwargs)
 
     _require_strategy_scope_extended(self)
+    self._validate_colocate_with_variable(colocate_with_variable)
     return variable_scope.variable_creator_scope(create_colocated_variable)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
+    pass
+
   def _call_dataset_fn(self, dataset_fn):
     """Call the `dataset_fn` with `input_context` as argument."""
     result = dataset_fn()
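
The base-class change is a classic template-method hook: `colocate_vars_with` always routes its argument through `_validate_colocate_with_variable`, which is a no-op by default and overridden per strategy. A simplified sketch of the dispatch (class and method names abbreviated from the diff; not the full implementation):

```
class StrategyExtendedBase(object):
  """Stand-in for DistributionStrategyExtended; validation hook only."""

  def colocate_vars_with(self, colocate_with_variable):
    # Validate before installing the colocated variable-creator scope.
    self._validate_colocate_with_variable(colocate_with_variable)
    return self._creator_scope(colocate_with_variable)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    """Default contract check: accept anything (subclasses tighten this)."""
    pass

  def _creator_scope(self, colocate_with_variable):
    raise NotImplementedError("stand-in for the variable_creator_scope wiring")


class MirroredLikeExtended(StrategyExtendedBase):
  def _validate_colocate_with_variable(self, colocate_with_variable):
    # Mirrored-style strategies demand a variable they own, mirroring
    # values.validate_colocate_distributed_variable in the section below.
    if not hasattr(colocate_with_variable, "distribute_strategy"):
      raise ValueError("not a variable created in this strategy's scope")
```

Keeping the default a no-op preserves backward compatibility for any strategy that has not yet adopted a stricter check.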

File: mirrored_strategy.py

@@ -537,6 +537,9 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
         self._container_strategy(), device_map, logical_device,
         _real_mirrored_creator, *args, **kwargs)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate_distributed_variable(colocate_with_variable, self)
+
   def _distribute_dataset(self, dataset_fn):
     if self._local_mode:
       worker_index = 0

File: values.py

@@ -575,6 +575,38 @@ class DistributedVariable(DistributedDelegate):
 ops.register_dense_tensor_like_type(DistributedVariable)
 
 
+def _validate_colocate_extended(v, extended):
+  if v.distribute_strategy.extended is not extended:
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
+        (v, v.distribute_strategy,))
+
+
+def validate_colocate_distributed_variable(v, extended):
+  if not isinstance(v, DistributedVariable):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
+def validate_colocate_tpu_variable(v, extended):
+  if not isinstance(v, TPUMirroredVariable):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
+def validate_colocate(v, extended):
+  if not hasattr(v, "distribute_strategy"):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
 def _apply_aggregation(strategy, value, aggregation, destinations):
   if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
     return strategy.broadcast(strategy.unwrap(value)[0],
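
The three public validators differ only in their type precondition (`DistributedVariable` for mirrored strategies, `TPUMirroredVariable` for TPU, merely a `distribute_strategy` attribute for the permissive strategies); all of them funnel into the shared same-strategy identity check. A hedged sketch of the two failure modes these checks catch (the strategies and variables here are illustrative, not from this commit):

```
import tensorflow as tf

# Failure mode 1: wrong kind of object entirely.
plain = tf.Variable(1.0)  # no strategy involved
# validate_colocate_distributed_variable(plain, extended) raises ValueError:
# `plain` is not a DistributedVariable.

# Failure mode 2: right kind of variable, wrong strategy.
strategy_a = tf.distribute.MirroredStrategy()
strategy_b = tf.distribute.MirroredStrategy()
with strategy_a.scope():
  v = tf.Variable(1.0)  # a distributed variable owned by strategy_a
with strategy_b.scope():
  try:
    with strategy_b.extended.colocate_vars_with(v):
      pass
  except ValueError:
    # _validate_colocate_extended fires: v.distribute_strategy.extended
    # is strategy_a.extended, not strategy_b.extended.
    pass
```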