Ensure ExponentialMovingAverage returns correct average_name when used in scope
Fixes #2740 Change: 126251978
This commit is contained in:
parent
3f488101d8
commit
a2b9788ce4
@ -339,7 +339,10 @@ class ExponentialMovingAverage(object):
|
||||
by the `ExponentialMovingAverage class` to hold the moving average of
|
||||
`var`.
|
||||
"""
|
||||
return var.op.name + "/" + self._name
|
||||
if var in self._averages:
|
||||
return self._averages[var].op.name
|
||||
return ops.get_default_graph().unique_name(
|
||||
var.op.name + "/" + self._name, mark_as_used=False)
|
||||
|
||||
def variables_to_restore(self, moving_avg_variables=None):
|
||||
"""Returns a map of names to `Variables` to restore.
|
||||
|
@ -208,6 +208,37 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
|
||||
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
|
||||
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
|
||||
|
||||
def testAverageVariablesNamesRespectScope(self):
|
||||
# See discussion on #2740.
|
||||
with self.test_session():
|
||||
with tf.variable_scope("scope1"):
|
||||
v0 = tf.Variable(10.0, name="v0")
|
||||
v1 = tf.Variable(30.0, name="v1")
|
||||
# Add a non-trainable variable.
|
||||
v2 = tf.Variable(20.0, name="v2", trainable=False)
|
||||
tensor2 = v0 + v1
|
||||
with tf.variable_scope("scope2"):
|
||||
ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
|
||||
self.assertEqual("scope2/scope1/v0/foo_avg", ema.average_name(v0))
|
||||
self.assertEqual("scope2/scope1/v1/foo_avg", ema.average_name(v1))
|
||||
self.assertEqual("scope2/scope1/add/foo_avg", ema.average_name(tensor2))
|
||||
ema.apply([v0, v1, tensor2])
|
||||
vars_to_restore = ema.variables_to_restore()
|
||||
# vars_to_restore should contain the following:
|
||||
# {scope2/scope1/v0/foo_avg : v0,
|
||||
# scope2/scope1/v1/foo_avg : v1,
|
||||
# scope2/scope1/add/foo_avg : add/foo_avg
|
||||
# scope1/v2 : v2}
|
||||
self.assertEqual(sorted(vars_to_restore.keys()),
|
||||
sorted([ema.average_name(v0),
|
||||
ema.average_name(v1),
|
||||
ema.average_name(tensor2),
|
||||
v2.op.name]))
|
||||
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
|
||||
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
|
||||
self.assertEqual(ema.average_name(tensor2),
|
||||
ema.average(tensor2).op.name)
|
||||
|
||||
def testSubsetAverageVariablesNames(self):
|
||||
with self.test_session():
|
||||
v0 = tf.Variable(10.0, name="v0")
|
||||
|
@ -73,5 +73,13 @@ class SlotCreatorTest(tf.test.TestCase):
|
||||
self.assertEqual(slot.dtype.base_dtype, tf.float32)
|
||||
self.assertAllEqual(slot.eval(), [0.0, 0.0])
|
||||
|
||||
def testCreateSlotFromVariableRespectsScope(self):
|
||||
# See discussion on #2740.
|
||||
with self.test_session():
|
||||
with tf.variable_scope("scope"):
|
||||
v = tf.Variable([1.0, 2.5], name="var")
|
||||
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
|
||||
self.assertEqual(slot.op.name, "scope/scope/var/slot")
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user