Fix bug with global_step.assign_add(1) when using multiple GPUs per machine
and `distribute.ParameterServerStrategy`. Add some testing for this case with
`MirroredStrategy` as well.

PiperOrigin-RevId: 221938703
Author: A. Unique TensorFlower, 2018-11-17 14:23:58 -08:00; committed by TensorFlower Gardener
parent 1f33c8d2ed
commit 8697d63b37
5 changed files with 48 additions and 1 deletion
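
For context, a minimal sketch of the dtype mismatch being fixed (hedged; uses the TF 1.x-era API of this commit, and the variable setup is illustrative, not taken from the diff). A Python `1` converted to a tensor defaults to int32, so eagerly broadcasting it yields a tensor that can no longer be added to an int64 global step; passing the raw Python value through instead lets `assign_add` convert it to the variable's dtype.

import tensorflow as tf

# An int64 scalar, like the one tf.train.get_or_create_global_step() creates.
global_step = tf.get_variable(
    "global_step", shape=[], dtype=tf.int64,
    initializer=tf.zeros_initializer(), trainable=False)

one = tf.constant(1)           # eager conversion: dtype defaults to int32
# global_step.assign_add(one)  # would fail: int64 variable, int32 increment
update = global_step.assign_add(1)  # works: the Python int becomes int64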

@@ -587,7 +587,12 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
     return ctx
 
   def _broadcast_to(self, tensor, destinations):
-    if isinstance(tensor, (float, int)):  # Fast path for Python constants.
+    # This is both a fast path for Python constants, and a way to delay
+    # converting Python values to a tensor until we know what type it
+    # should be converted to. Otherwise we have trouble with:
+    #   global_step.assign_add(1)
+    # since the `1` gets broadcast as an int32 but global_step is int64.
+    if isinstance(tensor, (float, int)):
       return tensor
     # TODO(josh11b): In eager mode, use one thread per device, or async mode.
     return self._get_cross_device_ops().broadcast(

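The effect of the new fast path, sketched outside the strategy internals (`broadcast_like` and `do_cross_device_broadcast` are hypothetical stand-ins, not the real API): plain Python scalars pass through unconverted, so dtype selection is deferred to the op that finally consumes the value.

def broadcast_like(value, destinations):
  # Fast path: leave Python scalars alone so the consumer picks the dtype.
  if isinstance(value, (float, int)):
    return value
  return do_cross_device_broadcast(value, destinations)  # hypothetical helper

increment = broadcast_like(1, ["/gpu:0", "/gpu:1"])
# global_step.assign_add(increment) now converts `1` to the variable's
# dtype (int64) rather than receiving a pre-made int32 tensor.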

@@ -162,6 +162,12 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
     iterator = d.make_input_fn_iterator(input_fn)
     self._test_input_fn_iterator(iterator, d.worker_devices, expected_values)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testGlobalStepUpdate(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+    self._test_global_step_update(self._get_distribution_strategy())
+
 
 class MirroredStrategyVariableCreationTest(test.TestCase):

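`GPU_TEST` gates the new test to GPU builds. It is presumably defined near the top of this test file in the usual TensorFlow pattern, which keys off the test binary's name; a sketch, assuming that convention:

import sys

# Assumption: multi-GPU test targets encode "gpu" in the test binary name.
GPU_TEST = "test_gpu" in sys.argv[0]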

@@ -257,6 +257,13 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
         input_fn, worker_device_pairs, [input_context])
 
   def _broadcast_to(self, tensor, destinations):
+    # This is both a fast path for Python constants, and a way to delay
+    # converting Python values to a tensor until we know what type it
+    # should be converted to. Otherwise we have trouble with:
+    #   global_step.assign_add(1)
+    # since the `1` gets broadcast as an int32 but global_step is int64.
+    if isinstance(tensor, (float, int)):
+      return tensor
     if not cross_device_ops_lib.check_destinations(destinations):
       destinations = self._compute_devices
     return self._cross_device_ops.broadcast(tensor, destinations)

@@ -651,6 +651,11 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
     self._test_input_fn_iterator(None, None, num_gpus,
                                  input_fn, expected_values)
 
+  def testGlobalStepUpdate(self):
+    strategy = parameter_server_strategy.ParameterServerStrategy(
+        num_gpus_per_worker=context.num_gpus())
+    self._test_global_step_update(strategy)
+
 
 class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                            parameterized.TestCase):

@@ -25,10 +25,13 @@ from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import core
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.training import distribution_strategy_context as ds_context
 from tensorflow.python.training import optimizer
@@ -263,3 +266,24 @@ class DistributionTestBase(test.TestCase):
           [values.select_device(d, next_element) for d in devices])
       self.assertEqual(expected_value, computed_value)
+
+  def _test_global_step_update(self, strategy):
+    with strategy.scope():
+      global_step = variable_scope.get_variable(
+          "global_step",
+          shape=[],
+          dtype=dtypes.int64,
+          initializer=init_ops.zeros_initializer(),
+          trainable=False,
+          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
+      self.evaluate(variables.global_variables_initializer())
+
+      def model_fn():
+        train_op = global_step.assign_add(1)
+        value = global_step.read_value()
+        return train_op, value
+
+      train_ops, value = strategy.call_for_each_replica(model_fn)
+      self.evaluate(strategy.group(train_ops))
+      global_step_tensors = strategy.unwrap(value)
+      global_step_values = self.evaluate(global_step_tensors)
+      self.assertEqual([1] * len(global_step_tensors), global_step_values)
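
The `ONLY_FIRST_REPLICA` aggregation is what keeps the expected value at 1: every replica runs `assign_add(1)`, but only the first replica's update is applied to the variable. A hedged summary of what other policies would yield for one step on n replicas:

# Expected global_step after one step on n replicas, by aggregation policy
# (sketch of the semantics this test relies on):
#
#   VariableAggregation.ONLY_FIRST_REPLICA -> 1  (replica 0's update wins)
#   VariableAggregation.SUM                -> n  (per-replica increments sum)
#   VariableAggregation.MEAN               -> 1  (mean of n ones)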