diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 370c96ac4fb..f7432162cba 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -587,7 +587,12 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended): return ctx def _broadcast_to(self, tensor, destinations): - if isinstance(tensor, (float, int)): # Fast path for Python constants. + # This is both a fast path for Python constants, and a way to delay + # converting Python values to a tensor until we know what type it + # should be converted to. Otherwise we have trouble with: + # global_step.assign_add(1) + # since the `1` gets broadcast as an int32 but global_step is int64. + if isinstance(tensor, (float, int)): return tensor # TODO(josh11b): In eager mode, use one thread per device, or async mode. return self._get_cross_device_ops().broadcast( diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 76fdc6f7626..2cc56565d36 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -162,6 +162,12 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): iterator = d.make_input_fn_iterator(input_fn) self._test_input_fn_iterator(iterator, d.worker_devices, expected_values) + @test_util.run_in_graph_and_eager_modes + def testGlobalStepUpdate(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_global_step_update(self._get_distribution_strategy()) + class MirroredStrategyVariableCreationTest(test.TestCase): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 4419d4afe13..6fc81a1e57a 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -257,6 +257,13 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): input_fn, worker_device_pairs, [input_context]) def _broadcast_to(self, tensor, destinations): + # This is both a fast path for Python constants, and a way to delay + # converting Python values to a tensor until we know what type it + # should be converted to. Otherwise we have trouble with: + # global_step.assign_add(1) + # since the `1` gets broadcast as an int32 but global_step is int64. + if isinstance(tensor, (float, int)): + return tensor if not cross_device_ops_lib.check_destinations(destinations): destinations = self._compute_devices return self._cross_device_ops.broadcast(tensor, destinations) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 5d5e65600c5..b4c098aa57c 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -651,6 +651,11 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, self._test_input_fn_iterator(None, None, num_gpus, input_fn, expected_values) + def testGlobalStepUpdate(self): + strategy = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=context.num_gpus()) + self._test_global_step_update(strategy) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index b9e37617382..d94ebdc06b8 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -25,10 +25,13 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribution_strategy_context as ds_context from tensorflow.python.training import optimizer @@ -263,3 +266,24 @@ class DistributionTestBase(test.TestCase): [values.select_device(d, next_element) for d in devices]) self.assertEqual(expected_value, computed_value) + def _test_global_step_update(self, strategy): + with strategy.scope(): + global_step = variable_scope.get_variable( + "global_step", + shape=[], + dtype=dtypes.int64, + initializer=init_ops.zeros_initializer(), + trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + train_op = global_step.assign_add(1) + value = global_step.read_value() + return train_op, value + + train_ops, value = strategy.call_for_each_replica(model_fn) + self.evaluate(strategy.group(train_ops)) + global_step_tensors = strategy.unwrap(value) + global_step_values = self.evaluate(global_step_tensors) + self.assertEqual([1] * len(global_step_tensors), global_step_values)