Make sure the name of the learning phase tensor is independent of any outer name scope.

PiperOrigin-RevId: 231328536
This commit is contained in:
Francois Chollet 2019-01-28 18:43:01 -08:00 committed by TensorFlower Gardener
parent cbc8e83f64
commit e7f1a44a5f
2 changed files with 12 additions and 4 deletions
tensorflow/python/keras

View File

@ -215,8 +215,9 @@ def clear_session():
_SESSION.session = None
graph = get_graph()
with graph.as_default():
phase = array_ops.placeholder_with_default(
False, shape=(), name='keras_learning_phase')
with ops.name_scope(''):
phase = array_ops.placeholder_with_default(
False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[graph] = phase
_GRAPH_VARIABLES.pop(graph, None)
@ -275,8 +276,9 @@ def symbolic_learning_phase():
graph = get_graph()
with graph.as_default():
if graph not in _GRAPH_LEARNING_PHASES:
phase = array_ops.placeholder_with_default(
False, shape=(), name='keras_learning_phase')
with ops.name_scope(''):
phase = array_ops.placeholder_with_default(
False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES[graph] = phase
return _GRAPH_LEARNING_PHASES[graph]

View File

@ -119,6 +119,12 @@ class BackendUtilsTest(test.TestCase):
self.evaluate(variables.global_variables_initializer())
sess.run(y, feed_dict={x: np.random.random((2, 3))})
def test_learning_phase_name(self):
with ops.name_scope('test_scope'):
# Test that outer name scopes do not affect the learning phase's name.
lp = keras.backend.symbolic_learning_phase()
self.assertEqual(lp.name, 'keras_learning_phase:0')
def test_learning_phase_scope(self):
initial_learning_phase = keras.backend.learning_phase()
with keras.backend.learning_phase_scope(1):