From a2b9788ce440c350d4e3fef53fe0c51ba1c10c1a Mon Sep 17 00:00:00 2001 From: Frank Li <lif@google.com> Date: Wed, 29 Jun 2016 16:30:37 -0800 Subject: [PATCH] Ensure ExponentialMovingAverage returns correct average_name when used in scope Fixes #2740 Change: 126251978 --- tensorflow/python/training/moving_averages.py | 5 ++- .../python/training/moving_averages_test.py | 31 +++++++++++++++++++ .../python/training/slot_creator_test.py | 8 +++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 9cc1d917ab1..4bbb89ff897 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -339,7 +339,10 @@ class ExponentialMovingAverage(object): by the `ExponentialMovingAverage class` to hold the moving average of `var`. """ - return var.op.name + "/" + self._name + if var in self._averages: + return self._averages[var].op.name + return ops.get_default_graph().unique_name( + var.op.name + "/" + self._name, mark_as_used=False) def variables_to_restore(self, moving_avg_variables=None): """Returns a map of names to `Variables` to restore. diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 437f2170fb3..413c38e601a 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -208,6 +208,37 @@ class ExponentialMovingAverageTest(tf.test.TestCase): self.assertEqual(ema.average_name(v1), ema.average(v1).op.name) self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name) + def testAverageVariablesNamesRespectScope(self): + # See discussion on #2740. + with self.test_session(): + with tf.variable_scope("scope1"): + v0 = tf.Variable(10.0, name="v0") + v1 = tf.Variable(30.0, name="v1") + # Add a non-trainable variable. + v2 = tf.Variable(20.0, name="v2", trainable=False) + tensor2 = v0 + v1 + with tf.variable_scope("scope2"): + ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg") + self.assertEqual("scope2/scope1/v0/foo_avg", ema.average_name(v0)) + self.assertEqual("scope2/scope1/v1/foo_avg", ema.average_name(v1)) + self.assertEqual("scope2/scope1/add/foo_avg", ema.average_name(tensor2)) + ema.apply([v0, v1, tensor2]) + vars_to_restore = ema.variables_to_restore() + # vars_to_restore should contain the following: + # {scope2/scope1/v0/foo_avg : v0, + # scope2/scope1/v1/foo_avg : v1, + # scope2/scope1/add/foo_avg : add/foo_avg + # scope1/v2 : v2} + self.assertEqual(sorted(vars_to_restore.keys()), + sorted([ema.average_name(v0), + ema.average_name(v1), + ema.average_name(tensor2), + v2.op.name])) + self.assertEqual(ema.average_name(v0), ema.average(v0).op.name) + self.assertEqual(ema.average_name(v1), ema.average(v1).op.name) + self.assertEqual(ema.average_name(tensor2), + ema.average(tensor2).op.name) + def testSubsetAverageVariablesNames(self): with self.test_session(): v0 = tf.Variable(10.0, name="v0") diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py index 8a11e44d1e6..4e2b980751d 100644 --- a/tensorflow/python/training/slot_creator_test.py +++ b/tensorflow/python/training/slot_creator_test.py @@ -73,5 +73,13 @@ class SlotCreatorTest(tf.test.TestCase): self.assertEqual(slot.dtype.base_dtype, tf.float32) self.assertAllEqual(slot.eval(), [0.0, 0.0]) + def testCreateSlotFromVariableRespectsScope(self): + # See discussion on #2740. + with self.test_session(): + with tf.variable_scope("scope"): + v = tf.Variable([1.0, 2.5], name="var") + slot = slot_creator.create_slot(v, v.initialized_value(), name="slot") + self.assertEqual(slot.op.name, "scope/scope/var/slot") + if __name__ == "__main__": tf.test.main()