Ensure ExponentialMovingAverage returns correct average_name when used in scope

Fixes #2740
Change: 126251978
This commit is contained in:
Frank Li 2016-06-29 16:30:37 -08:00 committed by TensorFlower Gardener
parent 3f488101d8
commit a2b9788ce4
3 changed files with 43 additions and 1 deletions

View File

@ -339,7 +339,10 @@ class ExponentialMovingAverage(object):
by the `ExponentialMovingAverage class` to hold the moving average of
`var`.
"""
return var.op.name + "/" + self._name
if var in self._averages:
return self._averages[var].op.name
return ops.get_default_graph().unique_name(
var.op.name + "/" + self._name, mark_as_used=False)
def variables_to_restore(self, moving_avg_variables=None):
"""Returns a map of names to `Variables` to restore.

View File

@ -208,6 +208,37 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
def testAverageVariablesNamesRespectScope(self):
# See discussion on #2740.
with self.test_session():
with tf.variable_scope("scope1"):
v0 = tf.Variable(10.0, name="v0")
v1 = tf.Variable(30.0, name="v1")
# Add a non-trainable variable.
v2 = tf.Variable(20.0, name="v2", trainable=False)
tensor2 = v0 + v1
with tf.variable_scope("scope2"):
ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
self.assertEqual("scope2/scope1/v0/foo_avg", ema.average_name(v0))
self.assertEqual("scope2/scope1/v1/foo_avg", ema.average_name(v1))
self.assertEqual("scope2/scope1/add/foo_avg", ema.average_name(tensor2))
ema.apply([v0, v1, tensor2])
vars_to_restore = ema.variables_to_restore()
# vars_to_restore should contain the following:
# {scope2/scope1/v0/foo_avg : v0,
# scope2/scope1/v1/foo_avg : v1,
# scope2/scope1/add/foo_avg : add/foo_avg
# scope1/v2 : v2}
self.assertEqual(sorted(vars_to_restore.keys()),
sorted([ema.average_name(v0),
ema.average_name(v1),
ema.average_name(tensor2),
v2.op.name]))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2),
ema.average(tensor2).op.name)
def testSubsetAverageVariablesNames(self):
with self.test_session():
v0 = tf.Variable(10.0, name="v0")

View File

@ -73,5 +73,13 @@ class SlotCreatorTest(tf.test.TestCase):
self.assertEqual(slot.dtype.base_dtype, tf.float32)
self.assertAllEqual(slot.eval(), [0.0, 0.0])
def testCreateSlotFromVariableRespectsScope(self):
# See discussion on #2740.
with self.test_session():
with tf.variable_scope("scope"):
v = tf.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
self.assertEqual(slot.op.name, "scope/scope/var/slot")
if __name__ == "__main__":
tf.test.main()