Explicitly disable SPMD in TPU strategy.
Mirrored variables are not yet supported for SPMD. This is in preparation of turning SPMD by default. PiperOrigin-RevId: 324548129 Change-Id: Ie7adf563402bd5ef31b7759232b1cd8f441586c7
This commit is contained in:
parent
151bd5901a
commit
1cb7ce30b7
@ -690,7 +690,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
select_replica, per_replica_inputs),))
|
||||
|
||||
replicate_outputs = tpu.replicate(
|
||||
run_fn, replicate_inputs, device_assignment=self._device_assignment)
|
||||
run_fn,
|
||||
replicate_inputs,
|
||||
device_assignment=self._device_assignment,
|
||||
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
|
||||
|
||||
# If run_fn has tensor outputs, tpu.replicate returns a list of list. We
|
||||
# will flatten it in this case. If run_fn has no tensor outputs,
|
||||
@ -1166,7 +1169,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
replicate_inputs,
|
||||
device_assignment=self._device_assignment,
|
||||
maximum_shapes=maximum_shapes,
|
||||
padding_spec=padding_spec)
|
||||
padding_spec=padding_spec,
|
||||
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
|
||||
|
||||
# Remove all no ops that may have been added during 'tpu.replicate()'
|
||||
if isinstance(result[0], list):
|
||||
|
Loading…
x
Reference in New Issue
Block a user