From 52ae0565666c5722226518982db80bf4c57300d3 Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Fri, 25 Jan 2019 04:48:47 -0800
Subject: [PATCH] Implement `TPUStrategy.experimental_run`.

PiperOrigin-RevId: 230883811
---
 .../contrib/distribute/python/tpu_strategy.py | 61 +++++++++++++++++--
 1 file changed, 56 insertions(+), 5 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 4863bb29dae..4387210062e 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -188,6 +188,51 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
     """DEPRECATED: use .extended.steps_per_run instead."""
     return self._extended.steps_per_run
 
+  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
+  # can use the default implementation.
+  # This implementation runs a single step. It does not use infeed or outfeed.
+  def experimental_run(self, fn, input_iterator=None):
+    """See base class."""
+    if context.executing_eagerly():
+      raise NotImplementedError("Eager mode not supported in TPUStrategy.")
+
+    if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "`experimental_run` is not compatible with "
+          "`_disable_training_loop_on_host=True`")
+
+    if input_iterator is None:
+      inputs = []
+    else:
+      inputs = input_iterator.get_next()
+
+    result = [None]
+    def replicated_fn(replica_id, inputs):
+      """Wraps user function to provide replica ID and `Tensor` inputs."""
+      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
+        if input_iterator is None:
+          result[0] = fn()
+        else:
+          result[0] = fn(inputs)
+      return result[0]
+
+    replicate_inputs = []  # By replica.
+    for i in range(self.num_replicas_in_sync):
+      replicate_inputs.append(
+          [constant_op.constant(i, dtype=dtypes.int32),
+           values.select_replica(i, inputs)])
+
+    with self.scope():
+      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)
+
+    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
+    replicate_outputs = [
+        nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
+        for replica_outputs in replicate_outputs]
+
+    device_map = self.extended._device_map  # pylint: disable=protected-access
+    return values.regroup(device_map, replicate_outputs)
+
 
 class TPUExtended(distribute_lib.DistributionStrategyExtended):
   """Implementation of TPUStrategy."""
@@ -783,18 +828,24 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
   """Replication Context class for TPU Strategy."""
 
   # TODO(sourabhbajaj): Call for each replica should be updating this.
-  def __init__(self, strategy):
-    # TODO(b/118385803): properly initialize replica_id, instead of always 0
-    replica_id = constant_op.constant(0, dtypes.int32)
+  # TODO(b/118385803): Always properly initialize replica_id.
+  def __init__(self, strategy, replica_id_in_sync_group=None):
+    if replica_id_in_sync_group is None:
+      replica_id_in_sync_group = constant_op.constant(0, dtypes.int32)
     distribute_lib.ReplicaContext.__init__(
-        self, strategy, replica_id_in_sync_group=replica_id)
+        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)
 
   @property
   def devices(self):
     distribute_lib.require_replica_context(self)
     ds = self._strategy
     replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
-    return (ds.extended.worker_devices[replica_id],)
+
+    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
+      # TODO(cjfj): Return other devices when model parallelism is supported.
+      return (tpu.core(0),)
+    else:
+      return (ds.extended.worker_devices[replica_id],)
 
 
 def _get_host_for_device(device):
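Usage sketch (reviewer note, not part of the patch): a minimal graph-mode example of driving the new `experimental_run` without an input iterator, in which case the patch calls `fn()` with no arguments on every replica. The surrounding setup, `tf.contrib.cluster_resolver.TPUClusterResolver`, `tf.contrib.distribute.TPUStrategy`, `tf.contrib.tpu.initialize_system`/`shutdown_system`, `Strategy.unwrap`, and the "my-tpu" name, is assumed contrib-era TF 1.x API and may differ at this revision.

import tensorflow as tf

# Assumed setup: contrib-era endpoints and a hypothetical TPU name.
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
strategy = tf.contrib.distribute.TPUStrategy(resolver)

def step_fn():
  # Per-replica computation; `experimental_run` wraps this in `tpu.replicate`,
  # so it runs once on every core in the sync group.
  return tf.reduce_sum(tf.ones([2, 2]))

# With `input_iterator=None`, `fn` is invoked with no arguments per replica.
per_replica = strategy.experimental_run(step_fn)

with tf.Session(resolver.master()) as sess:
  sess.run(tf.contrib.tpu.initialize_system())
  print(sess.run(strategy.unwrap(per_replica)))  # One result per replica.
  sess.run(tf.contrib.tpu.shutdown_system())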