Multiple mirrored variables created in eager mode with the same name should be handled correctly inside TPUMirroredVariable.

PiperOrigin-RevId: 241829490
This commit is contained in:
Sourabh Bajaj 2019-04-03 16:30:03 -07:00 committed by TensorFlower Gardener
parent bb5880c426
commit 24405dd523

View File

@ -796,6 +796,14 @@ class TPUMirroredVariable(trackable.Trackable):
for v in self._values:
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._common_name = self.primary.name.split(":")[0]
# Handle id is needed for get_replicated_var_handle to cache the variables
# correctly since in eager mode different variables can have the same name.
if context.executing_eagerly():
self._handle_id = self._common_name + "_" + str(id(self.primary))
else:
self._handle_id = self._common_name
self._aggregation = aggregation
# Needed for GradientTape
self._trainable = self.primary.trainable
@ -927,7 +935,7 @@ class TPUMirroredVariable(trackable.Trackable):
tpu_context = _enclosing_tpu_context()
if tpu_context is not None:
return tpu_context.get_replicated_var_handle(
self._common_name, self._values)
self._handle_id, self._values)
device = distribute_lib.get_update_device()
if device is None: