Implement TPUStrategy.experimental_run.

PiperOrigin-RevId: 230883811
This commit is contained in:
Chris Jones 2019-01-25 04:48:47 -08:00 committed by TensorFlower Gardener
parent 7cbe113239
commit 52ae056566

View File

@ -188,6 +188,51 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
"""DEPRECATED: use .extended.steps_per_run instead."""
return self._extended.steps_per_run
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run(self, fn, input_iterator=None):
"""See base class."""
if context.executing_eagerly():
raise NotImplementedError("Eager mode not supported in TPUStrategy.")
if self.extended._disable_training_loop_on_host: # pylint: disable=protected-access
raise NotImplementedError(
"`experimental_run` is not compatible with "
"`_disable_training_loop_on_host=True`")
if input_iterator is None:
inputs = []
else:
inputs = input_iterator.get_next()
result = [None]
def replicated_fn(replica_id, inputs):
"""Wraps user function to provide replica ID and `Tensor` inputs."""
with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
if input_iterator is None:
result[0] = fn()
else:
result[0] = fn(inputs)
return result[0]
replicate_inputs = [] # By replica.
for i in range(self.num_replicas_in_sync):
replicate_inputs.append(
[constant_op.constant(i, dtype=dtypes.int32),
values.select_replica(i, inputs)])
with self.scope():
replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)
# Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
replicate_outputs = [
nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
for replica_outputs in replicate_outputs]
device_map = self.extended._device_map # pylint: disable=protected-access
return values.regroup(device_map, replicate_outputs)
class TPUExtended(distribute_lib.DistributionStrategyExtended):
"""Implementation of TPUStrategy."""
@ -783,17 +828,23 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
"""Replication Context class for TPU Strategy."""
# TODO(sourabhbajaj): Call for each replica should be updating this.
def __init__(self, strategy):
# TODO(b/118385803): properly initialize replica_id, instead of always 0
replica_id = constant_op.constant(0, dtypes.int32)
# TODO(b/118385803): Always properly initialize replica_id.
def __init__(self, strategy, replica_id_in_sync_group=None):
if replica_id_in_sync_group is None:
replica_id_in_sync_group = constant_op.constant(0, dtypes.int32)
distribute_lib.ReplicaContext.__init__(
self, strategy, replica_id_in_sync_group=replica_id)
self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)
@property
def devices(self):
distribute_lib.require_replica_context(self)
ds = self._strategy
replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`.
# TODO(cjfj): Return other devices when model parallelism is supported.
return (tpu.core(0),)
else:
return (ds.extended.worker_devices[replica_id],)