Suppresses AssertionError when using InteractiveSession.

The following code would previously raise an AssertionError in the
garbage collector:

```python
sess = tf.InteractiveSession()
sess = tf.InteractiveSession()
```

This happens because the second session will be installed on the
default session stack before the first session is destroyed, which
leads to the assertion (of strict nesting) failing. Since this is
common behavior in IPython notebooks, we suppress this error for
InteractiveSession objects.

Fixes .
Change: 123040755
This commit is contained in:
Derek Murray 2016-05-23 13:57:58 -08:00 committed by TensorFlower Gardener
parent 4fcf863963
commit aec225cbc1
3 changed files with 42 additions and 4 deletions
tensorflow/python

View File

@ -900,10 +900,12 @@ class InteractiveSession(BaseSession):
super(InteractiveSession, self).__init__(target, graph, config)
self._default_session = self.as_default()
self._default_session.enforce_nesting = False
self._default_session.__enter__()
self._explicit_graph = graph
if self._explicit_graph is not None:
self._default_graph = graph.as_default()
self._default_graph.enforce_nesting = False
self._default_graph.__enter__()
def close(self):

View File

@ -1138,11 +1138,33 @@ class SessionTest(test_util.TensorFlowTestCase):
d = math_ops.mul(c, c)
for step in xrange(120):
run_metadata = config_pb2.RunMetadata()
sess.run(d, feed_dict={a: 1.0}, options=run_options, run_metadata=run_metadata)
sess.run(d, feed_dict={a: 1.0},
options=run_options, run_metadata=run_metadata)
if step == 99:
self.assertTrue(run_metadata.HasField('cost_graph'))
else:
self.assertFalse(run_metadata.HasField('cost_graph'))
def testNonInteractiveSessionNesting(self):
sess1 = session.Session()
sess1_controller = sess1.as_default()
sess1_controller.__enter__()
sess2 = session.Session()
sess2_controller = sess2.as_default()
sess2_controller.__enter__()
with self.assertRaisesRegexp(AssertionError, 'Nesting violated'):
sess1_controller.__exit__(None, None, None)
ops._default_session_stack.reset()
def testInteractiveSessionNesting(self):
sess1 = session.InteractiveSession()
sess2 = session.InteractiveSession()
del sess1
del sess2
if __name__ == '__main__':
googletest.main()

View File

@ -3306,6 +3306,7 @@ class _DefaultStack(threading.local):
def __init__(self):
super(_DefaultStack, self).__init__()
self._enforce_nesting = True
self.stack = []
def get_default(self):
@ -3314,6 +3315,14 @@ class _DefaultStack(threading.local):
def reset(self):
self.stack = []
@property
def enforce_nesting(self):
return self._enforce_nesting
@enforce_nesting.setter
def enforce_nesting(self, value):
self._enforce_nesting = value
@contextlib.contextmanager
def get_controller(self, default):
"""A context manager for manipulating a default stack."""
@ -3321,9 +3330,14 @@ class _DefaultStack(threading.local):
self.stack.append(default)
yield default
finally:
assert self.stack[-1] is default
self.stack.pop()
if self._enforce_nesting:
if self.stack[-1] is not default:
raise AssertionError(
"Nesting violated for default stack of %s objects"
% type(default))
self.stack.pop()
else:
self.stack.remove(default)
_default_session_stack = _DefaultStack()