tfdbg: let session wrappers handle empty fetches correctly
Fixes: #15882 PiperOrigin-RevId: 183685645
This commit is contained in:
parent
0905a7ed03
commit
1f26c65254
@ -389,6 +389,11 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
|||||||
r"mode\."):
|
r"mode\."):
|
||||||
sess.invoke_node_stepper(node_stepper)
|
sess.invoke_node_stepper(node_stepper)
|
||||||
|
|
||||||
|
def testDumpingWrapperWithEmptyFetchWorks(self):
|
||||||
|
sess = dumping_wrapper.DumpingDebugWrapperSession(
|
||||||
|
self.sess, session_root=self.session_root, log_usage=False)
|
||||||
|
sess.run([])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -121,7 +121,9 @@ from tensorflow.python.debug.lib import debug_utils
|
|||||||
from tensorflow.python.debug.lib import stepper
|
from tensorflow.python.debug.lib import stepper
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.platform import tf_logging
|
||||||
from tensorflow.python.training import monitored_session
|
from tensorflow.python.training import monitored_session
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
# Helper function.
|
# Helper function.
|
||||||
@ -439,7 +441,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
|||||||
"callable_runner and fetches/feed_dict are mutually exclusive, but "
|
"callable_runner and fetches/feed_dict are mutually exclusive, but "
|
||||||
"are used simultaneously.")
|
"are used simultaneously.")
|
||||||
|
|
||||||
if self._is_disabled_thread():
|
empty_fetches = not nest.flatten(fetches)
|
||||||
|
if empty_fetches:
|
||||||
|
tf_logging.info(
|
||||||
|
"Due to empty fetches, tfdbg Session wrapper is letting a "
|
||||||
|
"Session.run pass through without any debugging actions.")
|
||||||
|
if self._is_disabled_thread() or empty_fetches:
|
||||||
if callable_runner:
|
if callable_runner:
|
||||||
return callable_runner(*callable_runner_args)
|
return callable_runner(*callable_runner_args)
|
||||||
else:
|
else:
|
||||||
|
@ -664,6 +664,20 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
|||||||
[["run"], ["run"]], monitored_sess)
|
[["run"], ["run"]], monitored_sess)
|
||||||
self.assertFalse(wrapped_monitored_sess.should_stop())
|
self.assertFalse(wrapped_monitored_sess.should_stop())
|
||||||
|
|
||||||
|
def testRunsWithEmptyFetchWorks(self):
|
||||||
|
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||||
|
[["run"]], self.sess, dump_root="")
|
||||||
|
|
||||||
|
run_output = wrapped_sess.run([])
|
||||||
|
self.assertEqual([], run_output)
|
||||||
|
|
||||||
|
def testRunsWithEmptyNestedFetchWorks(self):
|
||||||
|
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||||
|
[["run"]], self.sess, dump_root="")
|
||||||
|
|
||||||
|
run_output = wrapped_sess.run({"foo": {"baz": []}, "bar": ()})
|
||||||
|
self.assertEqual({"foo": {"baz": []}, "bar": ()}, run_output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user