Add an internal field in TPUExtended to turn on/off SPMD mode for TF2
PiperOrigin-RevId: 338708212 Change-Id: Id35a59441a0b1e488013431ec827bf59b6d5bb3e
This commit is contained in:
parent
65ce606dbc
commit
fba2679b56
@ -741,6 +741,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
# Flag to turn on VariablePolicy
|
||||
self._use_var_policy = False
|
||||
|
||||
# Flag to enable TF2 SPMD
|
||||
self._use_spmd_for_xla_partitioning = False
|
||||
|
||||
def _validate_colocate_with_variable(self, colocate_with_variable):
|
||||
distribute_utils. validate_colocate(colocate_with_variable, self)
|
||||
|
||||
@ -900,8 +903,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
run_fn,
|
||||
replicate_inputs,
|
||||
device_assignment=self._device_assignment,
|
||||
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
|
||||
|
||||
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
|
||||
._use_spmd_for_xla_partitioning))
|
||||
# If run_fn has tensor outputs, tpu.replicate returns a list of list. We
|
||||
# will flatten it in this case. If run_fn has no tensor outputs,
|
||||
# tpu.replicate returns a list of no_ops, we will keep the output as it
|
||||
@ -1361,7 +1364,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
device_assignment=self._device_assignment,
|
||||
maximum_shapes=maximum_shapes,
|
||||
padding_spec=padding_spec,
|
||||
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))
|
||||
xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
|
||||
._use_spmd_for_xla_partitioning))
|
||||
|
||||
# Remove all no ops that may have been added during 'tpu.replicate()'
|
||||
if isinstance(result[0], list):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user