Implement TPUStrategy.experimental_run
PiperOrigin-RevId: 230883811
parent 7cbe113239
commit 52ae056566
@@ -188,6 +188,51 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
     """DEPRECATED: use .extended.steps_per_run instead."""
     return self._extended.steps_per_run
 
+  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
+  # can use the default implementation.
+  # This implementation runs a single step. It does not use infeed or outfeed.
+  def experimental_run(self, fn, input_iterator=None):
+    """See base class."""
+    if context.executing_eagerly():
+      raise NotImplementedError("Eager mode not supported in TPUStrategy.")
+
+    if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "`experimental_run` is not compatible with "
+          "`_disable_training_loop_on_host=True`")
+
+    if input_iterator is None:
+      inputs = []
+    else:
+      inputs = input_iterator.get_next()
+
+    result = [None]
+    def replicated_fn(replica_id, inputs):
+      """Wraps user function to provide replica ID and `Tensor` inputs."""
+      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
+        if input_iterator is None:
+          result[0] = fn()
+        else:
+          result[0] = fn(inputs)
+      return result[0]
+
+    replicate_inputs = []  # By replica.
+    for i in range(self.num_replicas_in_sync):
+      replicate_inputs.append(
+          [constant_op.constant(i, dtype=dtypes.int32),
+           values.select_replica(i, inputs)])
+
+    with self.scope():
+      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)
+
+    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
+    replicate_outputs = [
+        nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
+        for replica_outputs in replicate_outputs]
+
+    device_map = self.extended._device_map  # pylint: disable=protected-access
+    return values.regroup(device_map, replicate_outputs)
+
 
 class TPUExtended(distribute_lib.DistributionStrategyExtended):
   """Implementation of TPUStrategy."""
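For context, a minimal usage sketch of the new method, assuming a constructed `strategy` (a `TPUStrategy` instance) and a distributed `input_iterator` whose `get_next()` yields per-replica tensors, as the implementation above expects; `model` and `loss_fn` stand in for hypothetical user code:

    # Hedged sketch: runs one replicated step in graph mode (eager raises above).
    def step_fn(inputs):
      features, labels = inputs       # Per-replica tensors from the iterator.
      logits = model(features)        # `model` is hypothetical user code.
      return loss_fn(labels, logits)  # `loss_fn` is likewise hypothetical.

    # Returns the per-replica results, regrouped via `values.regroup`.
    per_replica_loss = strategy.experimental_run(step_fn, input_iterator)

Each replica receives its own slice of the iterator output (chosen with `values.select_replica`), and the whole step is compiled once through `tpu.replicate`.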
@@ -783,17 +828,23 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
   """Replication Context class for TPU Strategy."""
 
-  # TODO(sourabhbajaj): Call for each replica should be updating this.
-  def __init__(self, strategy):
-    # TODO(b/118385803): properly initialize replica_id, instead of always 0
-    replica_id = constant_op.constant(0, dtypes.int32)
+  # TODO(b/118385803): Always properly initialize replica_id.
+  def __init__(self, strategy, replica_id_in_sync_group=None):
+    if replica_id_in_sync_group is None:
+      replica_id_in_sync_group = constant_op.constant(0, dtypes.int32)
     distribute_lib.ReplicaContext.__init__(
-        self, strategy, replica_id_in_sync_group=replica_id)
+        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)
 
   @property
   def devices(self):
     distribute_lib.require_replica_context(self)
     ds = self._strategy
     replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
-    return (ds.extended.worker_devices[replica_id],)
+
+    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
+      # TODO(cjfj): Return other devices when model parallelism is supported.
+      return (tpu.core(0),)
+    else:
+      return (ds.extended.worker_devices[replica_id],)
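A short sketch of how the new `replica_id_in_sync_group` plumbing surfaces to user code run under `experimental_run`, assuming `distribution_strategy_context.get_replica_context()` as the accessor (an assumption; it is not shown in this diff):

    # Hedged sketch: inside `replicated_fn`, `_TPUReplicaContext` is active and
    # the replica ID is a non-constant `Tensor` produced by `tpu.replicate`, so
    # `ctx.devices` falls back to `(tpu.core(0),)` rather than a worker device.
    def fn():
      ctx = distribution_strategy_context.get_replica_context()
      return ctx.replica_id_in_sync_group  # The `Tensor` passed in above.

    per_replica_ids = strategy.experimental_run(fn)  # No iterator: fn() takes no args.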