Implement TPUStrategy.experimental_run.
PiperOrigin-RevId: 230883811
parent 7cbe113239
commit 52ae056566
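For orientation, a minimal sketch of how the new method might be driven follows. It is illustrative only: the strategy and iterator construction are placeholders and not part of this change; only `experimental_run(fn, input_iterator)` itself is added by the diff below.

# Sketch only: TF 1.x graph mode assumed; the strategy and iterator below are
# placeholders, not part of this commit.
import tensorflow as tf

strategy = ...   # a configured TPUStrategy instance (construction elided)
iterator = ...   # an iterator whose get_next() yields the per-step inputs

def step_fn(inputs):
  # Executed once per replica, inside a _TPUReplicaContext.
  return tf.reduce_sum(inputs)

# Builds one replicated step via tpu.replicate; no infeed, outfeed, or loop.
per_replica_outputs = strategy.experimental_run(step_fn, iterator)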
@@ -188,6 +188,51 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
     """DEPRECATED: use .extended.steps_per_run instead."""
     return self._extended.steps_per_run
 
+  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
+  # can use the default implementation.
+  # This implementation runs a single step. It does not use infeed or outfeed.
+  def experimental_run(self, fn, input_iterator=None):
+    """See base class."""
+    if context.executing_eagerly():
+      raise NotImplementedError("Eager mode not supported in TPUStrategy.")
+
+    if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "`experimental_run` is not compatible with "
+          "`_disable_training_loop_on_host=True`")
+
+    if input_iterator is None:
+      inputs = []
+    else:
+      inputs = input_iterator.get_next()
+
+    result = [None]
+    def replicated_fn(replica_id, inputs):
+      """Wraps user function to provide replica ID and `Tensor` inputs."""
+      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
+        if input_iterator is None:
+          result[0] = fn()
+        else:
+          result[0] = fn(inputs)
+      return result[0]
+
+    replicate_inputs = []  # By replica.
+    for i in range(self.num_replicas_in_sync):
+      replicate_inputs.append(
+          [constant_op.constant(i, dtype=dtypes.int32),
+           values.select_replica(i, inputs)])
+
+    with self.scope():
+      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)
+
+    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
+    replicate_outputs = [
+        nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
+        for replica_outputs in replicate_outputs]
+
+    device_map = self.extended._device_map  # pylint: disable=protected-access
+    return values.regroup(device_map, replicate_outputs)
+
 
 class TPUExtended(distribute_lib.DistributionStrategyExtended):
   """Implementation of TPUStrategy."""
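The inputs handed to `tpu.replicate` above are a list with one entry per replica, each entry being the argument list for that replica's call of `replicated_fn` (its integer replica id plus its slice of the iterator output). A plain-Python sketch of that shape, with a hypothetical stand-in for `values.select_replica`:

# Plain-Python sketch of the per-replica argument layout built above.
# `fake_select_replica` is a hypothetical stand-in for values.select_replica,
# which extracts replica i's component from distributed (PerReplica) values.
num_replicas = 2
per_replica_batches = [[1.0, 2.0], [3.0, 4.0]]  # hypothetical data

def fake_select_replica(i, inputs):
  return inputs[i]

replicate_inputs = []  # one [replica_id, inputs] argument list per replica
for i in range(num_replicas):
  replicate_inputs.append([i, fake_select_replica(i, per_replica_batches)])

print(replicate_inputs)  # [[0, [1.0, 2.0]], [1, [3.0, 4.0]]]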
@@ -783,18 +828,24 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
   """Replication Context class for TPU Strategy."""
 
   # TODO(sourabhbajaj): Call for each replica should be updating this.
-  def __init__(self, strategy):
-    # TODO(b/118385803): Always properly initialize replica_id.
-    replica_id = constant_op.constant(0, dtypes.int32)
+  # TODO(b/118385803): properly initialize replica_id, instead of always 0
+  def __init__(self, strategy, replica_id_in_sync_group=None):
+    if replica_id_in_sync_group is None:
+      replica_id_in_sync_group = constant_op.constant(0, dtypes.int32)
     distribute_lib.ReplicaContext.__init__(
-        self, strategy, replica_id_in_sync_group=replica_id)
+        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)
 
   @property
   def devices(self):
     distribute_lib.require_replica_context(self)
     ds = self._strategy
     replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
-    return (ds.extended.worker_devices[replica_id],)
+
+    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
+      # TODO(cjfj): Return other devices when model parallelism is supported.
+      return (tpu.core(0),)
+    else:
+      return (ds.extended.worker_devices[replica_id],)
 
 
 def _get_host_for_device(device):
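The `devices` change hinges on `tensor_util.constant_value`: inside `tpu.replicate` the replica id is a genuine `Tensor`, its value cannot be folded to a Python integer, and the property now falls back to `tpu.core(0)`. A small graph-mode illustration (assuming a TF 1.x environment; the placeholder stands in for the non-constant replica id):

# Why the None branch is needed: constant_value() only folds statically known
# tensors. Assumes TF 1.x graph mode.
import tensorflow as tf
from tensorflow.python.framework import tensor_util

with tf.Graph().as_default():
  static_id = tf.constant(0, dtype=tf.int32)
  dynamic_id = tf.placeholder(tf.int32, shape=())  # stand-in for the replica id
                                                   # inside tpu.replicate

  print(tensor_util.constant_value(static_id))   # 0 -> index worker_devices
  print(tensor_util.constant_value(dynamic_id))  # None -> return (tpu.core(0),)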