diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 5986bc4661f..700751d68c5 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import six
-
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import values
@@ -61,16 +59,12 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
     if colocate_with is None:
       with ops.device(self._device):
         return next_creator(*args, **kwargs)
-    if isinstance(colocate_with, six.string_types):
-      with ops.device(colocate_with):
-        return next_creator(*args, **kwargs)
-    if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and
-        isinstance(colocate_with[0], six.string_types)):
-      with ops.device(colocate_with[0]):
-        return next_creator(*args, **kwargs)
     with ops.colocate_with(colocate_with):
       return next_creator(*args, **kwargs)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate(colocate_with_variable, self)
+
   def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch."""
     return values.DatasetIterator(dataset, self._input_workers)
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 2fd0c4d6ea6..a6e924b509f 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -232,6 +232,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
         "ParameterServerStrategy with compute_devices = %r, "
         "variable_device = %r", compute_devices, self._variable_device)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate(colocate_with_variable, self)
+
   def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset to each local GPU."""
     return values.PerReplicaDataset(
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 44214f82265..ce7065f2205 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -225,39 +225,18 @@ class ParameterServerStrategyTestBase(
         self.assertEqual(var.device, '/job:ps/task:%d' % part_id)
         self.assertEqual(var.device, x_add[part_id].device)
 
-      # The colocate_vars_with can override the distribution's device.
-      with d.colocate_vars_with(x_add[0]):
-        y = variable_scope.get_variable(
-            'y',
-            initializer=constant_op.constant([20.0, 10.0]),
-            aggregation=variable_scope.VariableAggregation.SUM,
-            partitioner=partitioner)
-      y_add = y.assign_add(
-          [array_ops.identity(x_add[0]),
-           array_ops.identity(x_add[1])])
+      return x_add
 
-      for part_id, var in enumerate(y):
-        self.assertEqual(var.device, '/job:ps/task:0')
-        self.assertEqual(y_add[part_id].device, var.device)
-        self.assertEqual(var.device, x_add[0].device)
-
-      return x_add, y_add
-
-    x, y = d.call_for_each_replica(model_fn)
+    x = d.call_for_each_replica(model_fn)
 
     if context.num_gpus() >= 1:
       variables.global_variables_initializer().run()
-      x_val, y_val = sess.run([x, y])
+      x_val = sess.run(x)
       if num_gpus < 1:
         self.assertEqual(x_val, [13.0, 25.0])
-        self.assertEqual(y_val, [33.0, 35.0])
       else:
         x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus]
-        y_expect = [
-            20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus
-        ]
         self.assertEqual(x_val, x_expect)
-        self.assertEqual(y_val, y_expect)
 
   def _test_device_assignment_local(self,
                                     d,
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1139f494d7b..9e465f30c1b 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -224,6 +224,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     # Update Strategy state to make sure we can track device initialization.
     TPUExtended._initialized_devices.append(master)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate_tpu_variable(colocate_with_variable, self)
+
   def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator,
                                input_shapes, iterations):
     """Create an enqueue op for a single host identified using host_id.
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 7d4fd578432..76cbdd53d9d 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -1026,7 +1026,7 @@ class DistributionStrategyExtended(object):
     ```
     with strategy.scope():
       var1 = tf.get_variable(...)
-      with strategy.extended.colocate_vars_with(v1):
+      with strategy.extended.colocate_vars_with(var1):
         # var2 and var3 will be created on the same device(s) as var1
         var2 = tf.get_variable(...)
         var3 = tf.get_variable(...)
@@ -1034,8 +1034,9 @@ class DistributionStrategyExtended(object):
       def fn(v1, v2, v3):
         # operates on v1 from var1, v2 from var2, and v3 from var3
 
-      # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
-      strategy.extended.update(v1, fn, args=(v2, v3))
+      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
+      # too.
+      strategy.extended.update(var1, fn, args=(var2, var3))
     ```
 
     Args:
@@ -1053,8 +1054,13 @@
         return next_creator(*args, **kwargs)
 
     _require_strategy_scope_extended(self)
+    self._validate_colocate_with_variable(colocate_with_variable)
    return variable_scope.variable_creator_scope(create_colocated_variable)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
+    pass
+
   def _call_dataset_fn(self, dataset_fn):
     """Call the `dataset_fn` with `input_context` as argument."""
     result = dataset_fn()
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 71030d750ba..601eafbb5ea 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -537,6 +537,9 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
        self._container_strategy(), device_map, logical_device,
        _real_mirrored_creator, *args, **kwargs)
 
+  def _validate_colocate_with_variable(self, colocate_with_variable):
+    values.validate_colocate_distributed_variable(colocate_with_variable, self)
+
   def _distribute_dataset(self, dataset_fn):
     if self._local_mode:
       worker_index = 0
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index df0e007692e..ef1d2a94992 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -575,6 +575,38 @@ class DistributedVariable(DistributedDelegate):
 ops.register_dense_tensor_like_type(DistributedVariable)
 
 
+def _validate_colocate_extended(v, extended):
+  if v.distribute_strategy.extended is not extended:
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
+        (v, v.distribute_strategy,))
+
+
+def validate_colocate_distributed_variable(v, extended):
+  if not isinstance(v, DistributedVariable):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
+def validate_colocate_tpu_variable(v, extended):
+  if not isinstance(v, TPUMirroredVariable):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
+def validate_colocate(v, extended):
+  if not hasattr(v, "distribute_strategy"):
+    raise ValueError(
+        "`colocate_vars_with` must only be passed a variable created in this "
+        "tf.distribute.Strategy.scope(), not: %r" % (v,))
+  _validate_colocate_extended(v, extended)
+
+
 def _apply_aggregation(strategy, value, aggregation, destinations):
   if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
     return strategy.broadcast(strategy.unwrap(value)[0],
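
For illustration only, a minimal sketch (not part of the patch) of what the new `_validate_colocate_with_variable` hook rejects: passing `colocate_vars_with` a variable that was not created under the same strategy's scope now raises a `ValueError`. The strategy constructors and variable names below are assumptions made for the sketch.

    # Hypothetical usage sketch; assumes tf.distribute.MirroredStrategy is
    # available in this build. Names below are not from the patch.
    import tensorflow as tf

    strategy_a = tf.distribute.MirroredStrategy()
    strategy_b = tf.distribute.MirroredStrategy()

    with strategy_a.scope():
      v1 = tf.get_variable('v1', initializer=1.0)  # owned by strategy_a

    with strategy_b.scope():
      # v1 belongs to strategy_a, so strategy_b's validation hook
      # (values.validate_colocate_distributed_variable) raises ValueError here.
      with strategy_b.extended.colocate_vars_with(v1):
        v2 = tf.get_variable('v2', initializer=2.0)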