Multiple mirrored variables created in eager mode with the same name should be handled correctly inside TPUMirroredVariable.
PiperOrigin-RevId: 241829490
This commit is contained in:
parent
bb5880c426
commit
24405dd523
@ -796,6 +796,14 @@ class TPUMirroredVariable(trackable.Trackable):
|
||||
for v in self._values:
|
||||
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
|
||||
self._common_name = self.primary.name.split(":")[0]
|
||||
|
||||
# Handle id is needed for get_replicated_var_handle to cache the variables
|
||||
# correctly since in eager mode different variables can have the same name.
|
||||
if context.executing_eagerly():
|
||||
self._handle_id = self._common_name + "_" + str(id(self.primary))
|
||||
else:
|
||||
self._handle_id = self._common_name
|
||||
|
||||
self._aggregation = aggregation
|
||||
# Needed for GradientTape
|
||||
self._trainable = self.primary.trainable
|
||||
@ -927,7 +935,7 @@ class TPUMirroredVariable(trackable.Trackable):
|
||||
tpu_context = _enclosing_tpu_context()
|
||||
if tpu_context is not None:
|
||||
return tpu_context.get_replicated_var_handle(
|
||||
self._common_name, self._values)
|
||||
self._handle_id, self._values)
|
||||
|
||||
device = distribute_lib.get_update_device()
|
||||
if device is None:
|
||||
|
Loading…
Reference in New Issue
Block a user