Add new kwargs key to escape from creating mirrored variable in TPU strategy scope.

PiperOrigin-RevId: 272926391
This commit is contained in:
A. Unique TensorFlower 2019-10-04 11:53:06 -07:00 committed by TensorFlower Gardener
parent fe1178ab88
commit 20ebd8a796

View File

@ -386,6 +386,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
if kwargs.pop("tpu_embedding_variable_creator", False):
return next_creator(*args, **kwargs)
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
device_map = self._device_map