From fba2679b56193e2b959fd1dbee11e7d0665b27e6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 23 Oct 2020 11:09:26 -0700 Subject: [PATCH] Add an internal field in TPUExtended to turn on/off SPMD mode for TF2 PiperOrigin-RevId: 338708212 Change-Id: Id35a59441a0b1e488013431ec827bf59b6d5bb3e --- tensorflow/python/distribute/tpu_strategy.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index e7884c0eef3..dd82d2d24d8 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -741,6 +741,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): # Flag to turn on VariablePolicy self._use_var_policy = False + # Flag to enable TF2 SPMD + self._use_spmd_for_xla_partitioning = False + def _validate_colocate_with_variable(self, colocate_with_variable): distribute_utils. validate_colocate(colocate_with_variable, self) @@ -900,8 +903,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): run_fn, replicate_inputs, device_assignment=self._device_assignment, - xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False)) - + xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self + ._use_spmd_for_xla_partitioning)) # If run_fn has tensor outputs, tpu.replicate returns a list of list. We # will flatten it in this case. If run_fn has no tensor outputs, # tpu.replicate returns a list of no_ops, we will keep the output as it @@ -1361,7 +1364,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): device_assignment=self._device_assignment, maximum_shapes=maximum_shapes, padding_spec=padding_spec, - xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False)) + xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self + ._use_spmd_for_xla_partitioning)) # Remove all no ops that may have been added during 'tpu.replicate()' if isinstance(result[0], list):