tfdbg: let session wrappers handle empty fetches correctly

Fixes: #15882
PiperOrigin-RevId: 183685645
This commit is contained in:
Shanqing Cai 2018-01-29 10:22:10 -08:00 committed by TensorFlower Gardener
parent 0905a7ed03
commit 1f26c65254
3 changed files with 27 additions and 1 deletions

View File

@ -389,6 +389,11 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
r"mode\."):
sess.invoke_node_stepper(node_stepper)
def testDumpingWrapperWithEmptyFetchWorks(self):
sess = dumping_wrapper.DumpingDebugWrapperSession(
self.sess, session_root=self.session_root, log_usage=False)
sess.run([])
if __name__ == "__main__":
googletest.main()

View File

@ -121,7 +121,9 @@ from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import stepper
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import monitored_session
from tensorflow.python.util import nest
# Helper function.
@ -439,7 +441,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
"callable_runner and fetches/feed_dict are mutually exclusive, but "
"are used simultaneously.")
if self._is_disabled_thread():
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
tf_logging.info(
"Due to empty fetches, tfdbg Session wrapper is letting a "
"Session.run pass through without any debugging actions.")
if self._is_disabled_thread() or empty_fetches:
if callable_runner:
return callable_runner(*callable_runner_args)
else:

View File

@ -664,6 +664,20 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
[["run"], ["run"]], monitored_sess)
self.assertFalse(wrapped_monitored_sess.should_stop())
def testRunsWithEmptyFetchWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"]], self.sess, dump_root="")
run_output = wrapped_sess.run([])
self.assertEqual([], run_output)
def testRunsWithEmptyNestedFetchWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"]], self.sess, dump_root="")
run_output = wrapped_sess.run({"foo": {"baz": []}, "bar": ()})
self.assertEqual({"foo": {"baz": []}, "bar": ()}, run_output)
if __name__ == "__main__":
googletest.main()