diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 8303f6eaa76..14814f3700c 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -796,6 +796,14 @@ class TPUMirroredVariable(trackable.Trackable): for v in self._values: v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._common_name = self.primary.name.split(":")[0] + + # Handle id is needed for get_replicated_var_handle to cache the variables + # correctly since in eager mode different variables can have the same name. + if context.executing_eagerly(): + self._handle_id = self._common_name + "_" + str(id(self.primary)) + else: + self._handle_id = self._common_name + self._aggregation = aggregation # Needed for GradientTape self._trainable = self.primary.trainable @@ -927,7 +935,7 @@ class TPUMirroredVariable(trackable.Trackable): tpu_context = _enclosing_tpu_context() if tpu_context is not None: return tpu_context.get_replicated_var_handle( - self._common_name, self._values) + self._handle_id, self._values) device = distribute_lib.get_update_device() if device is None: