Suppresses AssertionError when using InteractiveSession.
The following code would previously raise an AssertionError in the garbage collector: ```python sess = tf.InteractiveSession() sess = tf.InteractiveSession() ``` This happens because the second session will be installed on the default session stack before the first session is destroyed, which leads to the assertion (of strict nesting) failing. Since this is common behavior in IPython notebooks, we suppress this error for InteractiveSession objects. Fixes #2474. Change: 123040755
This commit is contained in:
parent
4fcf863963
commit
aec225cbc1
tensorflow/python
@ -900,10 +900,12 @@ class InteractiveSession(BaseSession):
|
||||
|
||||
super(InteractiveSession, self).__init__(target, graph, config)
|
||||
self._default_session = self.as_default()
|
||||
self._default_session.enforce_nesting = False
|
||||
self._default_session.__enter__()
|
||||
self._explicit_graph = graph
|
||||
if self._explicit_graph is not None:
|
||||
self._default_graph = graph.as_default()
|
||||
self._default_graph.enforce_nesting = False
|
||||
self._default_graph.__enter__()
|
||||
|
||||
def close(self):
|
||||
|
@ -1138,11 +1138,33 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
d = math_ops.mul(c, c)
|
||||
for step in xrange(120):
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
sess.run(d, feed_dict={a: 1.0}, options=run_options, run_metadata=run_metadata)
|
||||
sess.run(d, feed_dict={a: 1.0},
|
||||
options=run_options, run_metadata=run_metadata)
|
||||
if step == 99:
|
||||
self.assertTrue(run_metadata.HasField('cost_graph'))
|
||||
else:
|
||||
self.assertFalse(run_metadata.HasField('cost_graph'))
|
||||
|
||||
def testNonInteractiveSessionNesting(self):
|
||||
sess1 = session.Session()
|
||||
sess1_controller = sess1.as_default()
|
||||
sess1_controller.__enter__()
|
||||
|
||||
sess2 = session.Session()
|
||||
sess2_controller = sess2.as_default()
|
||||
sess2_controller.__enter__()
|
||||
|
||||
with self.assertRaisesRegexp(AssertionError, 'Nesting violated'):
|
||||
sess1_controller.__exit__(None, None, None)
|
||||
|
||||
ops._default_session_stack.reset()
|
||||
|
||||
def testInteractiveSessionNesting(self):
|
||||
sess1 = session.InteractiveSession()
|
||||
sess2 = session.InteractiveSession()
|
||||
del sess1
|
||||
del sess2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
@ -3306,6 +3306,7 @@ class _DefaultStack(threading.local):
|
||||
|
||||
def __init__(self):
|
||||
super(_DefaultStack, self).__init__()
|
||||
self._enforce_nesting = True
|
||||
self.stack = []
|
||||
|
||||
def get_default(self):
|
||||
@ -3314,6 +3315,14 @@ class _DefaultStack(threading.local):
|
||||
def reset(self):
|
||||
self.stack = []
|
||||
|
||||
@property
|
||||
def enforce_nesting(self):
|
||||
return self._enforce_nesting
|
||||
|
||||
@enforce_nesting.setter
|
||||
def enforce_nesting(self, value):
|
||||
self._enforce_nesting = value
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_controller(self, default):
|
||||
"""A context manager for manipulating a default stack."""
|
||||
@ -3321,9 +3330,14 @@ class _DefaultStack(threading.local):
|
||||
self.stack.append(default)
|
||||
yield default
|
||||
finally:
|
||||
assert self.stack[-1] is default
|
||||
self.stack.pop()
|
||||
|
||||
if self._enforce_nesting:
|
||||
if self.stack[-1] is not default:
|
||||
raise AssertionError(
|
||||
"Nesting violated for default stack of %s objects"
|
||||
% type(default))
|
||||
self.stack.pop()
|
||||
else:
|
||||
self.stack.remove(default)
|
||||
|
||||
_default_session_stack = _DefaultStack()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user