Add an internal field in TPUExtended to turn on/off SPMD mode for TF2

PiperOrigin-RevId: 338708212
Change-Id: Id35a59441a0b1e488013431ec827bf59b6d5bb3e
This commit is contained in:
A. Unique TensorFlower 2020-10-23 11:09:26 -07:00 committed by TensorFlower Gardener
parent 65ce606dbc
commit fba2679b56

View File

@ -741,6 +741,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
# Flag to turn on VariablePolicy
self._use_var_policy = False
# Flag to enable TF2 SPMD
self._use_spmd_for_xla_partitioning = False
def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils. validate_colocate(colocate_with_variable, self)
@ -900,8 +903,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
run_fn,
replicate_inputs,
device_assignment=self._device_assignment,
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
._use_spmd_for_xla_partitioning))
# If run_fn has tensor outputs, tpu.replicate returns a list of list. We
# will flatten it in this case. If run_fn has no tensor outputs,
# tpu.replicate returns a list of no_ops, we will keep the output as it
@ -1361,7 +1364,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
device_assignment=self._device_assignment,
maximum_shapes=maximum_shapes,
padding_spec=padding_spec,
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
._use_spmd_for_xla_partitioning))
# Remove all no ops that may have been added during 'tpu.replicate()'
if isinstance(result[0], list):