Add new kwargs key to escape from creating mirrored variable in TPU strategy scope.
PiperOrigin-RevId: 272926391
This commit is contained in:
parent
fe1178ab88
commit
20ebd8a796
@ -386,6 +386,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
|
|
||||||
def _create_variable(self, next_creator, *args, **kwargs):
|
def _create_variable(self, next_creator, *args, **kwargs):
|
||||||
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
|
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
|
||||||
|
if kwargs.pop("tpu_embedding_variable_creator", False):
|
||||||
|
return next_creator(*args, **kwargs)
|
||||||
|
|
||||||
colocate_with = kwargs.pop("colocate_with", None)
|
colocate_with = kwargs.pop("colocate_with", None)
|
||||||
if colocate_with is None:
|
if colocate_with is None:
|
||||||
device_map = self._device_map
|
device_map = self._device_map
|
||||||
|
Loading…
Reference in New Issue
Block a user