tfdbg: two small bug fixes
1. Make the BaseDebugWrapperSession class capable of working as a context manager, as the non-debug Session. Also let the wrapper class support the close() method. 2. Handle the case in which a fetch is an object without the "name" attribute, e.g., a SparseTensor. Change: 141308246
This commit is contained in:
parent
c2bd4034ae
commit
507376940d
@ -24,6 +24,20 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
def _get_fetch_name(fetch):
|
||||
"""Obtain the name or string representation of a fetch.
|
||||
|
||||
Args:
|
||||
fetch: The fetch in question.
|
||||
|
||||
Returns:
|
||||
If the attribute 'name' is available, return the name. Otherwise, return
|
||||
str(fetch).
|
||||
"""
|
||||
|
||||
return fetch.name if hasattr(fetch, "name") else str(fetch)
|
||||
|
||||
|
||||
def _get_fetch_names(fetches):
|
||||
"""Get a flattened list of the names in run() call fetches.
|
||||
|
||||
@ -46,7 +60,7 @@ def _get_fetch_names(fetches):
|
||||
else:
|
||||
# This ought to be a Tensor, an Operation or a Variable, for which the name
|
||||
# attribute should be available. (Bottom-out condition of the recursion.)
|
||||
lines.append(fetches.name)
|
||||
lines.append(_get_fetch_name(fetches))
|
||||
|
||||
return lines
|
||||
|
||||
@ -190,7 +204,7 @@ def get_run_short_description(run_call_count, fetches, feed_dict):
|
||||
description = "run #%d: " % run_call_count
|
||||
|
||||
if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
|
||||
description += "1 fetch (%s); " % fetches.name
|
||||
description += "1 fetch (%s); " % _get_fetch_name(fetches)
|
||||
else:
|
||||
# Could be (nested) list, tuple, dict or namedtuple.
|
||||
num_fetches = len(_get_fetch_names(fetches))
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.debug.cli import cli_shared
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -35,6 +36,9 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
self.const_b = constant_op.constant(22.0, name="b")
|
||||
self.const_c = constant_op.constant(33.0, name="c")
|
||||
|
||||
self.sparse_d = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3])
|
||||
|
||||
def tearDown(self):
|
||||
ops.reset_default_graph()
|
||||
|
||||
@ -66,6 +70,10 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
description = cli_shared.get_run_short_description(12, self.const_a, None)
|
||||
self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description)
|
||||
|
||||
def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
|
||||
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
|
||||
|
||||
def testTwoFetchesListNoFeeds(self):
|
||||
fetches = [self.const_a, self.const_b]
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
|
||||
|
@ -501,5 +501,14 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
"""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self._sess.__enter__()
|
||||
|
||||
def __exit__(self, exec_type, exec_value, exec_tb):
|
||||
self._sess.__exit__(exec_type, exec_value, exec_tb)
|
||||
|
||||
def close(self):
|
||||
self._sess.close()
|
||||
|
||||
# TODO(cais): Add _node_name_regex_whitelist and
|
||||
# _node_op_type_regex_whitelist.
|
||||
|
@ -299,6 +299,18 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(
|
||||
isinstance(self._observer["tf_error"], errors.InvalidArgumentError))
|
||||
|
||||
def testUsingWrappedSessionShouldWorkAsContextManager(self):
|
||||
wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
|
||||
self._observer)
|
||||
|
||||
with wrapper as sess:
|
||||
sess.run(self._s)
|
||||
|
||||
def testWrapperShouldSupportSessionClose(self):
|
||||
wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
|
||||
self._observer)
|
||||
wrapper.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user