diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 3446f78288d..22aeb37ff7c 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -690,7 +690,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
             select_replica, per_replica_inputs),))
 
     replicate_outputs = tpu.replicate(
-        run_fn, replicate_inputs, device_assignment=self._device_assignment)
+        run_fn,
+        replicate_inputs,
+        device_assignment=self._device_assignment,
+        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
 
     # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
     # will flatten it in this case. If run_fn has no tensor outputs,
@@ -1166,7 +1169,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         replicate_inputs,
         device_assignment=self._device_assignment,
         maximum_shapes=maximum_shapes,
-        padding_spec=padding_spec)
+        padding_spec=padding_spec,
+        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
 
     # Remove all no ops that may have been added during 'tpu.replicate()'
     if isinstance(result[0], list):