Explicitly disable SPMD in TPU strategy.

Mirrored variables are not yet supported for SPMD. This is in preparation of turning SPMD by default.

PiperOrigin-RevId: 324548129
Change-Id: Ie7adf563402bd5ef31b7759232b1cd8f441586c7
This commit is contained in:
Yuanzhong Xu 2020-08-02 23:17:55 -07:00 committed by TensorFlower Gardener
parent 151bd5901a
commit 1cb7ce30b7

View File

@ -690,7 +690,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
select_replica, per_replica_inputs),))
replicate_outputs = tpu.replicate(
run_fn, replicate_inputs, device_assignment=self._device_assignment)
run_fn,
replicate_inputs,
device_assignment=self._device_assignment,
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
# If run_fn has tensor outputs, tpu.replicate returns a list of list. We
# will flatten it in this case. If run_fn has no tensor outputs,
@ -1166,7 +1169,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
replicate_inputs,
device_assignment=self._device_assignment,
maximum_shapes=maximum_shapes,
padding_spec=padding_spec)
padding_spec=padding_spec,
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
# Remove all no ops that may have been added during 'tpu.replicate()'
if isinstance(result[0], list):