unfinalize the graph at exit of MonitoredSession if it wasn't finalized before MonitoredSession creation.
Change: 139481489
This commit is contained in:
parent
b499ead596
commit
adf0dc7c15
tensorflow/python/training
@ -436,6 +436,7 @@ class MonitoredSession(object):
|
||||
`ChiefSessionCreator` which is the default one.
|
||||
hooks: An iterable of `SessionRunHook' objects.
|
||||
"""
|
||||
self._graph_was_finalized = ops.get_default_graph().finalized
|
||||
self._hooks = hooks or []
|
||||
for h in self._hooks:
|
||||
h.begin()
|
||||
@ -520,6 +521,8 @@ class MonitoredSession(object):
|
||||
self._sess = None
|
||||
self._coordinated_creator.tf_sess = None
|
||||
self._coordinated_creator.coord = None
|
||||
if not self._graph_was_finalized:
|
||||
ops.get_default_graph()._unsafe_unfinalize() # pylint: disable=protected-access
|
||||
|
||||
def _is_closed(self):
|
||||
"""Return True if the supervised session is closed. For tests only.
|
||||
|
@ -995,6 +995,23 @@ class MonitoredSessionTest(tf.test.TestCase):
|
||||
session = tf.train.MonitoredSession()
|
||||
self.assertEqual(g, session.graph)
|
||||
|
||||
def test_graph_finalized_during_run_unfinalized_after_exit(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
a_var = tf.Variable(0)
|
||||
with tf.train.MonitoredSession() as session:
|
||||
self.assertEqual(0, session.run(a_var))
|
||||
self.assertTrue(g.finalized)
|
||||
self.assertFalse(g.finalized)
|
||||
|
||||
def test_keep_finalized_graph_as_finalized(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
a_var = tf.Variable(0)
|
||||
tf.train.Scaffold().finalize()
|
||||
with tf.train.MonitoredSession() as session:
|
||||
self.assertEqual(0, session.run(a_var))
|
||||
self.assertTrue(g.finalized)
|
||||
self.assertTrue(g.finalized)
|
||||
|
||||
def test_merge_run_options_from_hooks(self):
|
||||
"""Test for rewriting RunOptions and observing RunMetadata with hooks."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user