unfinalize the graph at exit of MonitoredSession if it wasn't finalized before MonitoredSession creation.

Change: 139481489
This commit is contained in:
Mustafa Ispir 2016-11-17 11:02:24 -08:00 committed by TensorFlower Gardener
parent b499ead596
commit adf0dc7c15
2 changed files with 20 additions and 0 deletions

View File

@ -436,6 +436,7 @@ class MonitoredSession(object):
`ChiefSessionCreator` which is the default one.
hooks: An iterable of `SessionRunHook' objects.
"""
self._graph_was_finalized = ops.get_default_graph().finalized
self._hooks = hooks or []
for h in self._hooks:
h.begin()
@ -520,6 +521,8 @@ class MonitoredSession(object):
self._sess = None
self._coordinated_creator.tf_sess = None
self._coordinated_creator.coord = None
if not self._graph_was_finalized:
ops.get_default_graph()._unsafe_unfinalize() # pylint: disable=protected-access
def _is_closed(self):
"""Return True if the supervised session is closed. For tests only.

View File

@ -995,6 +995,23 @@ class MonitoredSessionTest(tf.test.TestCase):
session = tf.train.MonitoredSession()
self.assertEqual(g, session.graph)
def test_graph_finalized_during_run_unfinalized_after_exit(self):
with tf.Graph().as_default() as g:
a_var = tf.Variable(0)
with tf.train.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
self.assertTrue(g.finalized)
self.assertFalse(g.finalized)
def test_keep_finalized_graph_as_finalized(self):
with tf.Graph().as_default() as g:
a_var = tf.Variable(0)
tf.train.Scaffold().finalize()
with tf.train.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
self.assertTrue(g.finalized)
self.assertTrue(g.finalized)
def test_merge_run_options_from_hooks(self):
"""Test for rewriting RunOptions and observing RunMetadata with hooks."""