diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index e7884c0eef3..dd82d2d24d8 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -741,6 +741,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # Flag to turn on VariablePolicy
     self._use_var_policy = False
 
+    # Flag to enable TF2 SPMD
+    self._use_spmd_for_xla_partitioning = False
+
   def _validate_colocate_with_variable(self, colocate_with_variable):
     distribute_utils.validate_colocate(colocate_with_variable, self)
 
@@ -900,8 +903,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         run_fn,
         replicate_inputs,
         device_assignment=self._device_assignment,
-        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
-
+        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
+                                   ._use_spmd_for_xla_partitioning))
     # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
     # will flatten it in this case. If run_fn has no tensor outputs,
     # tpu.replicate returns a list of no_ops, we will keep the output as it
@@ -1361,7 +1364,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         device_assignment=self._device_assignment,
         maximum_shapes=maximum_shapes,
         padding_spec=padding_spec,
-        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
+        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
+                                   ._use_spmd_for_xla_partitioning))
 
     # Remove all no ops that may have been added during 'tpu.replicate()'
     if isinstance(result[0], list):
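
For context, a minimal sketch of how the new flag could be exercised from user code. The TPU setup boilerplate is illustrative and assumes a reachable TPU worker; `_use_spmd_for_xla_partitioning` is the private attribute introduced by this change, and the `train_step` function is a placeholder.

```python
# Usage sketch (not part of the patch): toggle the new private flag on the
# strategy's extended object before any tpu.replicate call is traced.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)
# The flag defaults to False, matching the previously hard-coded behavior;
# setting it to True makes both tpu.replicate call sites in the diff pass
# use_spmd_for_xla_partitioning=True via tpu.XLAOptions.
strategy.extended._use_spmd_for_xla_partitioning = True

@tf.function
def train_step(x):
  # Placeholder computation standing in for a real per-replica step.
  return x * 2.0

result = strategy.run(train_step, args=(tf.constant(1.0),))
```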