From 20ebd8a796467b229d3a735096931efb33f95710 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Oct 2019 11:53:06 -0700 Subject: [PATCH] Add new kwargs key to escape from creating mirrored variable in TPU strategy scope. PiperOrigin-RevId: 272926391 --- tensorflow/python/distribute/tpu_strategy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 894f96c0341..0073f3c0d83 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -386,6 +386,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): def _create_variable(self, next_creator, *args, **kwargs): """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" + if kwargs.pop("tpu_embedding_variable_creator", False): + return next_creator(*args, **kwargs) + colocate_with = kwargs.pop("colocate_with", None) if colocate_with is None: device_map = self._device_map