tfdbg: two small bug fixes

1. Make the BaseDebugWrapperSession class capable of working as a context manager, as the non-debug Session. Also let the wrapper class support the close() method.

2. Handle the case in which a fetch is an object without the "name" attribute, e.g., a SparseTensor.
Change: 141308246
This commit is contained in:
Shanqing Cai 2016-12-07 07:53:45 -08:00 committed by TensorFlower Gardener
parent c2bd4034ae
commit 507376940d
4 changed files with 45 additions and 2 deletions

View File

@ -24,6 +24,20 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
def _get_fetch_name(fetch):
"""Obtain the name or string representation of a fetch.
Args:
fetch: The fetch in question.
Returns:
If the attribute 'name' is available, return the name. Otherwise, return
str(fetch).
"""
return fetch.name if hasattr(fetch, "name") else str(fetch)
def _get_fetch_names(fetches):
"""Get a flattened list of the names in run() call fetches.
@ -46,7 +60,7 @@ def _get_fetch_names(fetches):
else:
# This ought to be a Tensor, an Operation or a Variable, for which the name
# attribute should be available. (Bottom-out condition of the recursion.)
lines.append(fetches.name)
lines.append(_get_fetch_name(fetches))
return lines
@ -190,7 +204,7 @@ def get_run_short_description(run_call_count, fetches, feed_dict):
description = "run #%d: " % run_call_count
if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
description += "1 fetch (%s); " % fetches.name
description += "1 fetch (%s); " % _get_fetch_name(fetches)
else:
# Could be (nested) list, tuple, dict or namedtuple.
num_fetches = len(_get_fetch_names(fetches))

View File

@ -23,6 +23,7 @@ from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@ -35,6 +36,9 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
self.const_b = constant_op.constant(22.0, name="b")
self.const_c = constant_op.constant(33.0, name="c")
self.sparse_d = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3])
def tearDown(self):
ops.reset_default_graph()
@ -66,6 +70,10 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
description = cli_shared.get_run_short_description(12, self.const_a, None)
self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description)
def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
def testTwoFetchesListNoFeeds(self):
fetches = [self.const_a, self.const_b]
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

View File

@ -501,5 +501,14 @@ class BaseDebugWrapperSession(session.SessionInterface):
"""
pass
def __enter__(self):
return self._sess.__enter__()
def __exit__(self, exec_type, exec_value, exec_tb):
self._sess.__exit__(exec_type, exec_value, exec_tb)
def close(self):
self._sess.close()
# TODO(cais): Add _node_name_regex_whitelist and
# _node_op_type_regex_whitelist.

View File

@ -299,6 +299,18 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(
isinstance(self._observer["tf_error"], errors.InvalidArgumentError))
def testUsingWrappedSessionShouldWorkAsContextManager(self):
wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
self._observer)
with wrapper as sess:
sess.run(self._s)
def testWrapperShouldSupportSessionClose(self):
wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
self._observer)
wrapper.close()
if __name__ == "__main__":
googletest.main()