From 24405dd5238326e2c544a04025ac4cfcdbdee1a8 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Wed, 3 Apr 2019 16:30:03 -0700 Subject: [PATCH] Multiple mirrored variables created in eager mode with the same name should be handled correctly inside TPUMirroredVariable. PiperOrigin-RevId: 241829490 --- tensorflow/python/distribute/values.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 8303f6eaa76..14814f3700c 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -796,6 +796,14 @@ class TPUMirroredVariable(trackable.Trackable): for v in self._values: v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._common_name = self.primary.name.split(":")[0] + + # Handle id is needed for get_replicated_var_handle to cache the variables + # correctly since in eager mode different variables can have the same name. + if context.executing_eagerly(): + self._handle_id = self._common_name + "_" + str(id(self.primary)) + else: + self._handle_id = self._common_name + self._aggregation = aggregation # Needed for GradientTape self._trainable = self.primary.trainable @@ -927,7 +935,7 @@ class TPUMirroredVariable(trackable.Trackable): tpu_context = _enclosing_tpu_context() if tpu_context is not None: return tpu_context.get_replicated_var_handle( - self._common_name, self._values) + self._handle_id, self._values) device = distribute_lib.get_update_device() if device is None: