Add new kwargs key to escape from creating mirrored variable in TPU strategy scope.
PiperOrigin-RevId: 272926391
This commit is contained in:
parent
fe1178ab88
commit
20ebd8a796
@ -386,6 +386,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
|
||||
if kwargs.pop("tpu_embedding_variable_creator", False):
|
||||
return next_creator(*args, **kwargs)
|
||||
|
||||
colocate_with = kwargs.pop("colocate_with", None)
|
||||
if colocate_with is None:
|
||||
device_map = self._device_map
|
||||
|
Loading…
Reference in New Issue
Block a user