diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 894f96c0341..0073f3c0d83 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -386,6 +386,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): def _create_variable(self, next_creator, *args, **kwargs): """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" + if kwargs.pop("tpu_embedding_variable_creator", False): + return next_creator(*args, **kwargs) + colocate_with = kwargs.pop("colocate_with", None) if colocate_with is None: device_map = self._device_map