tfdbg CLI: consolidate run-start and run-end CLIs

Before this CL, there are two CLIs launched for each Session.run() call: the run-start CLI and the run-end one. They have different sets of commands, which is potentially confusing. This CL consolidates the two.

* The run-start CLI is launched only before the first Session.run() call.
* In subsequent Session.run() calls, only the run-end CLI is shown. As of this CL, the same "run" command as in the run-start CLI works there as well.
* Added the "-t" option flag to the "run" command to allow running through a number of Session.run() calls without pausing at the CLI.
* Added the "run_info" (shorthand: "ri") command to allow inspection of the run fetches and feeds.
* Added more unit test coverage for the LocalCLIDebugWrapperSession class.
* Made corresponding doc updates.
* Added "tfdbg" logo to the "splash screen" of the CLI.
Change: 141221720
This commit is contained in:
Shanqing Cai 2016-12-06 13:53:54 -08:00 committed by TensorFlower Gardener
parent 4d51b55cf8
commit aa6ab8b962
6 changed files with 547 additions and 183 deletions

View File

@ -349,6 +349,7 @@ py_test(
],
srcs_version = "PY2AND3",
deps = [
":debugger_cli_common",
":local_cli_wrapper",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",

View File

@ -71,7 +71,22 @@ def _recommend_command(command, description, indent=2):
return debugger_cli_common.RichTextLines(lines, font_attr_segs=font_attr_segs)
def get_run_start_intro(run_call_count, fetches, feed_dict, tensor_filters):
def get_tfdbg_logo():
  """Make an ASCII-art "TFDBG" logo for the CLI splash screen.

  Returns:
    (debugger_cli_common.RichTextLines) The logo lines followed by one blank
    line, with no font attributes.
  """
  lines = [
      "TTTTTT FFFF DDD BBBB GGG ",
      " TT F D D B B G ",
      " TT FFF D D BBBB G GG",
      " TT F D D B B G G",
      " TT F DDD BBBB GGG ",
      "",
  ]
  return debugger_cli_common.RichTextLines(lines)
def get_run_start_intro(run_call_count,
fetches,
feed_dict,
tensor_filters):
"""Generate formatted intro for run-start UI.
Args:
@ -101,8 +116,8 @@ def get_run_start_intro(run_call_count, fetches, feed_dict, tensor_filters):
intro_lines = [
"======================================",
"About to enter Session run() call #%d:" % run_call_count, "",
"Fetch(es):"
"Session.run() call #%d:" % run_call_count,
"", "Fetch(es):"
]
intro_lines.extend([" " + line for line in fetch_lines])
intro_lines.extend(["", "Feed dict(s):"])
@ -120,11 +135,16 @@ def get_run_start_intro(run_call_count, fetches, feed_dict, tensor_filters):
out.extend(
_recommend_command(
"run -n", "Execute the run() call without debug tensor-watching"))
out.extend(
_recommend_command(
"run -t <T>",
"Execute run() calls (T - 1) times without debugging, then "
"execute run() one more time and drop back to the CLI"))
out.extend(
_recommend_command(
"run -f <filter_name>",
"Keep executing run() calls until a dumped tensor passes a given, "
"registered filter (conditional breakpoint mode)."))
"registered filter (conditional breakpoint mode)"))
more_font_attr_segs = {}
more_lines = [" Registered filter(s):"]

View File

@ -42,8 +42,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
# Verify line about run() call number.
self.assertEqual("About to enter Session run() call #12:",
run_start_intro.lines[1])
self.assertTrue(run_start_intro.lines[1].endswith("run() call #12:"))
# Verify line about fetch.
const_a_name_line = run_start_intro.lines[4]
@ -58,8 +57,10 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
self.assertEqual([(2, 5, "bold")], run_start_intro.font_attr_segs[11])
self.assertEqual("run -n:", run_start_intro.lines[13][2:])
self.assertEqual([(2, 8, "bold")], run_start_intro.font_attr_segs[13])
self.assertEqual("run -f <filter_name>:", run_start_intro.lines[15][2:])
self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[15])
self.assertEqual("run -t <T>:", run_start_intro.lines[15][2:])
self.assertEqual([(2, 12, "bold")], run_start_intro.font_attr_segs[15])
self.assertEqual("run -f <filter_name>:", run_start_intro.lines[17][2:])
self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[17])
# Verify short description.
description = cli_shared.get_run_short_description(12, self.const_a, None)
@ -179,8 +180,8 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
# Verify the listed names of the tensor filters.
filter_names = set()
filter_names.add(run_start_intro.lines[18].split(" ")[-1])
filter_names.add(run_start_intro.lines[19].split(" ")[-1])
filter_names.add(run_start_intro.lines[20].split(" ")[-1])
filter_names.add(run_start_intro.lines[21].split(" ")[-1])
self.assertEqual({"filter_a", "filter_b"}, filter_names)
@ -218,14 +219,14 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase):
self.assertEqual(2, error_intro.lines[8].index("lt"))
self.assertEqual([(2, 4, "bold")], error_intro.font_attr_segs[8])
self.assertTrue(error_intro.lines[11].startswith("Op name:"))
self.assertStartsWith(error_intro.lines[11], "Op name:")
self.assertTrue(error_intro.lines[11].endswith("a/Assign"))
self.assertTrue(error_intro.lines[12].startswith("Error type:"))
self.assertStartsWith(error_intro.lines[12], "Error type:")
self.assertTrue(error_intro.lines[12].endswith(str(type(tf_error))))
self.assertEqual("Details:", error_intro.lines[14])
self.assertTrue(error_intro.lines[15].startswith("foo description"))
self.assertStartsWith(error_intro.lines[15], "foo description")
if __name__ == "__main__":

View File

@ -1,12 +1,16 @@
# TensorFlow Debugger (tfdbg) Command-Line-Interface Tutorial: MNIST
**(Under development, subject to change)**
**(Experimental)**
This tutorial showcases the features of TensorFlow Debugger (**tfdbg**)
command-line interface.
It contains an example of how to debug a frequently encountered problem in
TensorFlow model development: bad numerical values (`nan`s and `inf`s) causing
training to fail.
TensorFlow debugger (**tfdbg**) is a specialized debugger for TensorFlow. It
provides visibility into the internal structure and states of running
TensorFlow graphs. The insight gained from this visibility should facilitate
debugging of various types of model bugs during training and inference.
This tutorial showcases the features of tfdbg
command-line interface (CLI), by focusing on how to debug a
type of frequently-encountered bug in TensorFlow model development:
bad numerical values (`nan`s and `inf`s) causing training to fail.
To **observe** such an issue, run the following code without the debugger:
@ -25,11 +29,7 @@ Accuracy at step 1: 0.3183
Accuracy at step 2: 0.098
Accuracy at step 3: 0.098
Accuracy at step 4: 0.098
Accuracy at step 5: 0.098
Accuracy at step 6: 0.098
Accuracy at step 7: 0.098
Accuracy at step 8: 0.098
Accuracy at step 9: 0.098
...
```
Scratching your head, you suspect that certain nodes in the training graph
@ -122,7 +122,9 @@ output.
As the screen output indicates, the first `run()` call calculates the accuracy
using a test data set—i.e., a forward pass on the graph. You can enter the
command `run` to launch the `run()` call. This will bring up another screen
command `run` (or its shorthand `r`) to launch the `run()` call.
This will bring up another screen
right after the `run()` call has ended, which will display all dumped
intermediate tensors from the run. (These tensors can also be obtained by
running the command `lt` after you executed `run`.) This is called the
@ -167,27 +169,21 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
| `lo -r hidden/Relu:0` | List the recipients of the output of the node `hidden/Relu`, recursively—i.e., the output recipient tree. |
| `lt -n softmax.*` | List all dumped tensors whose names match the regular-expression pattern `softmax.*`. |
| `lt -t MatMul` | List all dumped tensors whose node type is `MatMul`. |
| `run_info` or `ri` | Display information about the current run, including fetches and feeds. |
| `help` | Print general help information listing all available **tfdbg** commands and their flags. |
| `help lt` | Print the help information for the `lt` command. |
In this first `run()` call, there happen to be no problematic numerical values.
You can exit the run-end UI by entering the command `exit`. Then you will be at
the second run-start UI:
You can move on to the next run by using the command `run` or its shorthand `r`.
```none
--- run-start: run #2: fetch: train/Adam; 2 feeds --------------
======================================
About to enter Session run() call #2:
Fetch(es):
train/Adam
Feed dict(s):
input/x-input:0
input/y-input:0
======================================
...
```
> TIP: If you enter `run` or `r` repeatedly, you will be able to move through the
> `run()` calls in a sequential manner.
>
> You can also use the `-t` flag to move ahead a number of `run()` calls at a time, for example:
>
> ```
> tfdbg> run -t 10
> ```
Instead of entering `run` repeatedly and manually searching for `nan`s and
`inf`s in the run-end UI after every `run()` call, you can use the following

View File

@ -70,20 +70,44 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._dump_root = dump_root
# State flag for running till a tensor filter is passed.
self._run_till_filter_pass = None
self._initialize_argparsers()
# State related to tensor filters.
# Registered tensor filters.
self._tensor_filters = {}
# Options for the on-run-start hook:
# 1) run (DEBUG_RUN)
# 2) run --nodebug (NON_DEBUG_RUN)
# 3) invoke_stepper (INVOKE_STEPPER, not implemented)
self._on_run_start_parsers = {}
# Below are the state variables of this wrapper object.
# _active_tensor_filter: what (if any) tensor filter is in effect. If such
# a filter is in effect, this object will call run() method of the
# underlying TensorFlow Session object until the filter passes. This is
# activated by the "-f" flag of the "run" command.
# _run_through_times: keeps track of how many times the wrapper needs to
# run through without stopping at the run-end CLI. It is activated by the
# "-t" option of the "run" command.
# _skip_debug: keeps track of whether the current run should be executed
# without debugging. It is activated by the "-n" option of the "run"
# command.
#
# _run_start_response: keeps track what OnRunStartResponse the wrapper
# should return at the next run-start callback. If this information is
# unavailable (i.e., is None), the run-start CLI will be launched to ask
# the user. This is the case, e.g., right before the first run starts.
self._active_tensor_filter = None
self._run_through_times = 1
self._skip_debug = False
self._run_start_response = None
def _initialize_argparsers(self):
self._argparsers = {}
ap = argparse.ArgumentParser(
description="Run through, with or without debug tensor watching.",
usage=argparse.SUPPRESS)
ap.add_argument(
"-t",
"--times",
dest="times",
type=int,
default=1,
help="How many Session.run() calls to proceed with.")
ap.add_argument(
"-n",
"--no_debug",
@ -97,12 +121,17 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
type=str,
default="",
help="Run until a tensor in the graph passes the specified filter.")
self._on_run_start_parsers["run"] = ap
self._argparsers["run"] = ap
ap = argparse.ArgumentParser(
description="Invoke stepper (cont, step, breakpoint, etc.)",
usage=argparse.SUPPRESS)
self._on_run_start_parsers["invoke_stepper"] = ap
self._argparsers["invoke_stepper"] = ap
ap = argparse.ArgumentParser(
description="Display information about this Session.run() call.",
usage=argparse.SUPPRESS)
self._argparsers["run_info"] = ap
def add_tensor_filter(self, filter_name, tensor_filter):
"""Add a tensor filter.
@ -151,46 +180,58 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._update_run_calls_state(request.run_call_count, request.fetches,
request.feed_dict)
if self._run_till_filter_pass:
if self._active_tensor_filter:
# If we are running till a filter passes, we just need to keep running
# with the DEBUG_RUN option.
return framework.OnRunStartResponse(framework.OnRunStartAction.DEBUG_RUN,
self._get_run_debug_urls())
run_start_cli = curses_ui.CursesUI()
if self._run_call_count > 1 and not self._skip_debug:
if self._run_through_times > 0:
# Just run through without debugging.
return framework.OnRunStartResponse(
framework.OnRunStartAction.NON_DEBUG_RUN, [])
elif self._run_through_times == 0:
# It is the run at which the run-end CLI will be launched: activate
# debugging.
return framework.OnRunStartResponse(
framework.OnRunStartAction.DEBUG_RUN,
self._get_run_debug_urls())
run_start_cli.register_command_handler(
"run",
self._on_run_start_run_handler,
self._on_run_start_parsers["run"].format_help(),
prefix_aliases=["r"])
run_start_cli.register_command_handler(
"invoke_stepper",
self._on_run_start_step_handler,
self._on_run_start_parsers["invoke_stepper"].format_help(),
prefix_aliases=["s"])
if self._run_start_response is None:
self._prep_cli_for_run_start()
if self._tensor_filters:
# Register tab completion for the filter names.
run_start_cli.register_tab_comp_context(["run", "r"],
list(self._tensor_filters.keys()))
self._run_start_response = self._launch_cli(is_run_start=True)
if self._run_through_times > 1:
self._run_through_times -= 1
run_start_cli.set_help_intro(
cli_shared.get_run_start_intro(request.run_call_count, request.fetches,
request.feed_dict, self._tensor_filters))
# Create initial screen output detailing the run.
title = "run-start: " + self._run_description
response = run_start_cli.run_ui(
init_command="help", title=title, title_color="blue_on_white")
if response == debugger_cli_common.EXPLICIT_USER_EXIT:
if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
# Explicit user "exit" command leads to sys.exit(1).
print(
"Note: user exited from debugger CLI: Calling sys.exit(1).",
file=sys.stderr)
sys.exit(1)
return response
return self._run_start_response
def _prep_cli_for_run_start(self):
  """Prepare (but not launch) the CLI for run-start.

  Creates a fresh CursesUI and sets the help intro, title, initial command
  and title color that _launch_cli() will use.
  """

  self._run_cli = curses_ui.CursesUI()

  help_intro = debugger_cli_common.RichTextLines([])
  if self._run_call_count == 1:
    # Show logo at the onset of the first run.
    help_intro.extend(cli_shared.get_tfdbg_logo())
  help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
  # _run_info is the formatted run-start intro (fetches, feeds and tensor
  # filters) built in _update_run_calls_state().
  help_intro.extend(self._run_info)

  self._run_cli.set_help_intro(help_intro)

  # Create initial screen output detailing the run.
  self._title = "run-start: " + self._run_description
  self._init_command = "help"
  self._title_color = "blue_on_white"
def on_run_end(self, request):
"""Overrides on-run-end callback.
@ -216,111 +257,150 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
debug_dump = debug_data.DebugDumpDir(
self._dump_root, partition_graphs=partition_graphs)
if request.tf_error:
help_intro = cli_shared.get_error_intro(request.tf_error)
passed_filter = None
if self._active_tensor_filter:
if not debug_dump.find(
self._tensor_filters[self._active_tensor_filter], first_n=1):
# No dumped tensor passes the filter in this run. Clean up the dump
# directory and move on.
self._remove_dump_root()
return framework.OnRunEndResponse()
else:
# Some dumped tensor(s) from this run passed the filter.
passed_filter = self._active_tensor_filter
self._active_tensor_filter = None
init_command = "help"
title_color = "red_on_white"
else:
help_intro = None
init_command = "lt"
self._prep_cli_for_run_end(debug_dump, request.tf_error, passed_filter)
title_color = "black_on_white"
if self._run_till_filter_pass:
if not debug_dump.find(
self._tensor_filters[self._run_till_filter_pass], first_n=1):
# No dumped tensor passes the filter in this run. Clean up the dump
# directory and move on.
shutil.rmtree(self._dump_root)
return framework.OnRunEndResponse()
else:
# Some dumped tensor(s) from this run passed the filter.
init_command = "lt -f %s" % self._run_till_filter_pass
title_color = "red_on_white"
self._run_till_filter_pass = None
self._run_start_response = self._launch_cli()
analyzer = analyzer_cli.DebugAnalyzer(debug_dump)
# Supply all the available tensor filters.
for filter_name in self._tensor_filters:
analyzer.add_tensor_filter(filter_name,
self._tensor_filters[filter_name])
run_end_cli = curses_ui.CursesUI()
run_end_cli.register_command_handler(
"list_tensors",
analyzer.list_tensors,
analyzer.get_help("list_tensors"),
prefix_aliases=["lt"])
run_end_cli.register_command_handler(
"node_info",
analyzer.node_info,
analyzer.get_help("node_info"),
prefix_aliases=["ni"])
run_end_cli.register_command_handler(
"list_inputs",
analyzer.list_inputs,
analyzer.get_help("list_inputs"),
prefix_aliases=["li"])
run_end_cli.register_command_handler(
"list_outputs",
analyzer.list_outputs,
analyzer.get_help("list_outputs"),
prefix_aliases=["lo"])
run_end_cli.register_command_handler(
"print_tensor",
analyzer.print_tensor,
analyzer.get_help("print_tensor"),
prefix_aliases=["pt"])
run_end_cli.register_command_handler(
"run",
self._run_end_run_command_handler,
"Helper command for incorrectly entered run command at the run-end "
"prompt.",
prefix_aliases=["r"]
)
# Get names of all dumped tensors.
dumped_tensor_names = []
for datum in debug_dump.dumped_tensor_data:
dumped_tensor_names.append("%s:%d" %
(datum.node_name, datum.output_slot))
# Tab completions for command "print_tensors".
run_end_cli.register_tab_comp_context(["print_tensor", "pt"],
dumped_tensor_names)
# Tab completion for commands "node_info", "list_inputs" and
# "list_outputs". The list comprehension is used below because nodes()
# output can be unicodes and they need to be converted to strs.
run_end_cli.register_tab_comp_context(
["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
[str(node_name) for node_name in debug_dump.nodes()])
# TODO(cais): Reduce API surface area for aliases vis-a-vis tab
# completion contexts and registered command handlers.
title = "run-end: " + self._run_description
if help_intro:
run_end_cli.set_help_intro(help_intro)
run_end_cli.run_ui(
init_command=init_command, title=title, title_color=title_color)
# Clean up the dump directory.
shutil.rmtree(self._dump_root)
# Clean up the dump generated by this run.
self._remove_dump_root()
else:
print("No debug information to show following a non-debug run() call.")
# No debug information to show following a non-debug run() call.
self._run_start_response = None
# Return placeholder response that currently holds no additional
# information.
return framework.OnRunEndResponse()
def _on_run_start_run_handler(self, args, screen_info=None):
def _remove_dump_root(self):
  """Delete this wrapper's debug-dump directory, if it exists on disk."""
  dump_path = self._dump_root
  if not os.path.isdir(dump_path):
    # Nothing to clean up (e.g., a non-debug run produced no dump).
    return
  shutil.rmtree(dump_path)
def _prep_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
  """Prepare (but not launch) CLI for run-end, with debug dump from the run.

  Sets up a new CursesUI (self._run_cli) with analyzer command handlers and
  tab completions, plus the title, init command and title color that
  _launch_cli() will use.

  Args:
    debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
      run.
    tf_error: (None or OpError) OpError that happened during the run() call
      (if any).
    passed_filter: (None or str) Name of the tensor filter that just passed
      and caused the preparation of this run-end CLI (if any).
  """

  if tf_error:
    # An error occurred during the run: start at the help screen with an
    # error intro, in red.
    help_intro = cli_shared.get_error_intro(tf_error)

    self._init_command = "help"
    self._title_color = "red_on_white"
  else:
    help_intro = None
    self._init_command = "lt"

    self._title_color = "black_on_white"
    if passed_filter is not None:
      # Some dumped tensor(s) from this run passed the filter.
      self._init_command = "lt -f %s" % passed_filter
      self._title_color = "red_on_white"

  analyzer = analyzer_cli.DebugAnalyzer(debug_dump)

  # Supply all the available tensor filters.
  for filter_name in self._tensor_filters:
    analyzer.add_tensor_filter(filter_name,
                               self._tensor_filters[filter_name])

  self._run_cli = curses_ui.CursesUI()
  self._run_cli.register_command_handler(
      "list_tensors",
      analyzer.list_tensors,
      analyzer.get_help("list_tensors"),
      prefix_aliases=["lt"])
  self._run_cli.register_command_handler(
      "node_info",
      analyzer.node_info,
      analyzer.get_help("node_info"),
      prefix_aliases=["ni"])
  self._run_cli.register_command_handler(
      "list_inputs",
      analyzer.list_inputs,
      analyzer.get_help("list_inputs"),
      prefix_aliases=["li"])
  self._run_cli.register_command_handler(
      "list_outputs",
      analyzer.list_outputs,
      analyzer.get_help("list_outputs"),
      prefix_aliases=["lo"])
  self._run_cli.register_command_handler(
      "print_tensor",
      analyzer.print_tensor,
      analyzer.get_help("print_tensor"),
      prefix_aliases=["pt"])

  # Get names of all dumped tensors.
  dumped_tensor_names = []
  for datum in debug_dump.dumped_tensor_data:
    dumped_tensor_names.append("%s:%d" %
                               (datum.node_name, datum.output_slot))

  # Tab completions for command "print_tensors".
  self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
                                          dumped_tensor_names)

  # Tab completion for commands "node_info", "list_inputs" and
  # "list_outputs". The list comprehension is used below because nodes()
  # output can be unicodes and they need to be converted to strs.
  self._run_cli.register_tab_comp_context(
      ["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
      [str(node_name) for node_name in debug_dump.nodes()])
  # TODO(cais): Reduce API surface area for aliases vis-a-vis tab
  # completion contexts and registered command handlers.

  self._title = "run-end: " + self._run_description

  if help_intro:
    self._run_cli.set_help_intro(help_intro)
def _launch_cli(self, is_run_start=False):
  """Launch the interactive command-line interface.

  Assumes self._run_cli, self._init_command, self._title and
  self._title_color have been set by a preceding call to
  _prep_cli_for_run_start() or _prep_cli_for_run_end().

  Args:
    is_run_start: (bool) whether this CLI launch occurs at a run-start
      callback.

  Returns:
    The OnRunStartResponse specified by the user using the "run" command.
  """

  # Register the per-run "run"/"invoke_stepper"/"run_info" handlers on the
  # freshly prepared CLI before running its UI loop.
  self._register_this_run_info(self._run_cli)
  response = self._run_cli.run_ui(
      init_command=self._init_command,
      title=self._title,
      title_color=self._title_color)

  return response
def _run_info_handler(self, args, screen_info=None):
  """Command handler for "run_info" ("ri"): show info about the current run.

  Args:
    args: Command-line arguments (unused; the "run_info" parser takes none).
    screen_info: Screen information (unused).

  Returns:
    (debugger_cli_common.RichTextLines) The formatted run intro (fetches,
    feeds and tensor filters) stored in self._run_info.
  """
  return self._run_info
def _run_handler(self, args, screen_info=None):
"""Command handler for "run" command during on-run-start."""
_ = screen_info # Currently unused.
parsed = self._on_run_start_parsers["run"].parse_args(args)
parsed = self._argparsers["run"].parse_args(args)
if parsed.till_filter_pass:
# For the run-till-bad-numerical-value-appears mode, use the DEBUG_RUN
@ -328,14 +408,18 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
# state flag of the class itself to True.
if parsed.till_filter_pass in self._tensor_filters:
action = framework.OnRunStartAction.DEBUG_RUN
self._run_till_filter_pass = parsed.till_filter_pass
self._active_tensor_filter = parsed.till_filter_pass
else:
# Handle invalid filter name.
return debugger_cli_common.RichTextLines(
["ERROR: tensor filter \"%s\" does not exist." %
parsed.till_filter_pass])
if parsed.no_debug:
self._skip_debug = parsed.no_debug
self._run_through_times = parsed.times
if parsed.times > 1 or parsed.no_debug:
# If requested -t times > 1, the very next run will be a non-debug run.
action = framework.OnRunStartAction.NON_DEBUG_RUN
debug_urls = []
else:
@ -346,6 +430,28 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
raise debugger_cli_common.CommandLineExit(
exit_token=framework.OnRunStartResponse(action, debug_urls))
def _register_this_run_info(self, curses_cli):
  """Register run-related command handlers and tab completions on a CLI.

  Registers the "run" ("r"), "invoke_stepper" ("s") and "run_info" ("ri")
  commands, and — if tensor filters exist — tab completion of filter names
  for the "run" command.

  Args:
    curses_cli: (curses_ui.CursesUI) The CLI instance to register on.
  """
  curses_cli.register_command_handler(
      "run",
      self._run_handler,
      self._argparsers["run"].format_help(),
      prefix_aliases=["r"])
  curses_cli.register_command_handler(
      "invoke_stepper",
      self._on_run_start_step_handler,
      self._argparsers["invoke_stepper"].format_help(),
      prefix_aliases=["s"])
  curses_cli.register_command_handler(
      "run_info",
      self._run_info_handler,
      self._argparsers["run_info"].format_help(),
      prefix_aliases=["ri"])

  if self._tensor_filters:
    # Register tab completion for the filter names.
    curses_cli.register_tab_comp_context(["run", "r"],
                                         list(self._tensor_filters.keys()))
def _on_run_start_step_handler(self, args, screen_info=None):
"""Command handler for "invoke_stepper" command during on-run-start."""
@ -359,18 +465,6 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
exit_token=framework.OnRunStartResponse(
framework.OnRunStartAction.INVOKE_STEPPER, []))
def _run_end_run_command_handler(self, args, screen_info=None):
  """Handler for incorrectly entered run command at run-end prompt.

  Args:
    args: Command-line arguments (unused).
    screen_info: Screen information (unused).

  Returns:
    (debugger_cli_common.RichTextLines) An error message instructing the
    user how to proceed to the next run.
  """
  _ = screen_info  # Currently unused.

  return debugger_cli_common.RichTextLines([
      "ERROR: the \"run\" command is invalid for the run-end prompt.", "",
      "To proceed to the next run, ",
      " 1) exit this run-end prompt using the command \"exit\"",
      " 2) enter the command \"run\" at the next run-start prompt.",
  ])
def _get_run_debug_urls(self):
"""Get the debug_urls value for the current run() call.
@ -397,3 +491,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._run_description = cli_shared.get_run_short_description(run_call_count,
fetches,
feed_dict)
self._run_through_times -= 1
self._run_info = cli_shared.get_run_start_intro(run_call_count,
fetches,
feed_dict,
self._tensor_filters)

View File

@ -21,18 +21,96 @@ import os
import shutil
import tempfile
import tensorflow as tf
from tensorflow.python.client import session
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class LocalCLIDebuggerWrapperSessionForTest(
    local_cli_wrapper.LocalCLIDebugWrapperSession):
  """Subclasses the wrapper class for testing.

  Overrides its CLI-related methods for headless testing environments.
  Instead of launching a curses UI, each CLI launch replays the next canned
  "run" command from a pre-supplied sequence. Inserts observer variables for
  assertions.
  """

  def __init__(self,
               command_args_sequence,
               sess,
               dump_root=None):
    """Constructor of the for-test subclass.

    Args:
      command_args_sequence: (list of list of str) A list of arguments for the
        "run" command.
      sess: See the doc string of LocalCLIDebugWrapperSession.__init__.
      dump_root: See the doc string of LocalCLIDebugWrapperSession.__init__.
    """
    local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
        self, sess, dump_root=dump_root, log_usage=False)

    self._command_args_sequence = command_args_sequence
    # Index of the next canned "run" command to replay in _launch_cli().
    self._response_pointer = 0

    # Observer variables.
    self.observers = {
        "debug_dumps": [],
        "tf_errors": [],
        "run_start_cli_run_numbers": [],
        "run_end_cli_run_numbers": [],
    }

  def _prep_cli_for_run_start(self):
    # No-op: no curses UI is created in the headless test environment.
    pass

  def _prep_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
    # Record run-end state for test assertions instead of building a CLI.
    self.observers["debug_dumps"].append(debug_dump)
    self.observers["tf_errors"].append(tf_error)

  def _launch_cli(self, is_run_start=False):
    # Simulate a CLI launch: record which run it happened at, then replay the
    # next canned "run" command and capture the OnRunStartResponse carried by
    # the CommandLineExit that the handler raises.
    if is_run_start:
      self.observers["run_start_cli_run_numbers"].append(self._run_call_count)
    else:
      self.observers["run_end_cli_run_numbers"].append(self._run_call_count)

    command_args = self._command_args_sequence[self._response_pointer]
    self._response_pointer += 1

    # NOTE(review): `response` is only bound if _run_handler raises
    # CommandLineExit; a plain return from the handler (e.g., an invalid
    # filter name) would cause an UnboundLocalError here — confirm intended.
    try:
      self._run_handler(command_args)
    except debugger_cli_common.CommandLineExit as e:
      response = e.exit_token

    return response
class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def setUp(self):
  # Path for the debug dump root; tempfile.mktemp() returns a path without
  # creating the directory.
  self._tmp_dir = tempfile.mktemp()

  # A simple variable-increment graph: inc_v adds delta (1.0) to v (10.0).
  self.v = tf.Variable(10.0, name="v")
  self.delta = tf.constant(1.0, name="delta")
  self.inc_v = tf.assign_add(self.v, self.delta, name="inc_v")

  # A matmul graph on a fed placeholder; feeding a wrongly shaped value
  # triggers a runtime error (used by testRuntimeErrorShouldBeCaught).
  self.ph = tf.placeholder(tf.float32, name="ph")
  self.xph = tf.transpose(self.ph, name="xph")
  self.m = tf.constant(
      [[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=tf.float32, name="m")
  self.y = tf.matmul(self.m, self.xph, name="y")

  self.sess = tf.Session()

  # Initialize variable.
  self.sess.run(self.v.initializer)
def tearDown(self):
  # Reset the default graph so each test starts from a clean state, and
  # remove any dump directory left behind by the wrapped session.
  tf.reset_default_graph()
  if os.path.isdir(self._tmp_dir):
    shutil.rmtree(self._tmp_dir)
@ -68,6 +146,174 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
local_cli_wrapper.LocalCLIDebugWrapperSession(
session.Session(), dump_root=file_path, log_usage=False)
def testRunsUnderDebugMode(self):
  # Test command sequence: run; run; run;
  # (Three canned commands: one for the run-start CLI of run #1 and one for
  # each of the two run-end CLIs.)
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [[], [], []], self.sess, dump_root=self._tmp_dir)

  # run under debug mode twice.
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)

  # Verify that the assign_add op did take effect.
  self.assertAllClose(12.0, self.sess.run(self.v))

  # Assert correct run call numbers for which the CLI has been launched at
  # run-start and run-end.
  self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
  self.assertEqual([1, 2], wrapped_sess.observers["run_end_cli_run_numbers"])

  # Verify that the dumps have been generated and picked up during run-end.
  self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))

  # Verify that the TensorFlow runtime errors are picked up and in this case,
  # they should be both None.
  self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
def testRunsUnderNonDebugMode(self):
  # Test command sequence: run -n; run -n; run -n;
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [["-n"], ["-n"], ["-n"]],
      self.sess,
      dump_root=self._tmp_dir)

  # run three times.
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)

  self.assertAllClose(13.0, self.sess.run(self.v))

  # All three runs are NON_DEBUG, so the run-end CLI is never launched.
  self.assertEqual([1, 2, 3],
                   wrapped_sess.observers["run_start_cli_run_numbers"])
  self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
def testRunsUnderNonDebugThenDebugMode(self):
  # Test command sequence: run -n; run -n; run; run;
  # Do two NON_DEBUG_RUNs, followed by DEBUG_RUNs.
  # (Four canned commands: three for run-start CLIs, one for the run-end CLI
  # of run #3.)
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [["-n"], ["-n"], [], []],
      self.sess,
      dump_root=self._tmp_dir)

  # run three times.
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)

  self.assertAllClose(13.0, self.sess.run(self.v))

  self.assertEqual([1, 2, 3],
                   wrapped_sess.observers["run_start_cli_run_numbers"])

  # Here, the CLI should have been launched only under the third run,
  # because the first and second runs are NON_DEBUG.
  self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
  self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
  self.assertEqual([None], wrapped_sess.observers["tf_errors"])
def testRunMultipleTimesWithinLimit(self):
  # Test command sequence: run -t 3; run;
  # "run -t 3" runs through two calls without stopping and debugs the third.
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [["-t", "3"], []], self.sess, dump_root=self._tmp_dir)

  # run three times.
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)

  self.assertAllClose(13.0, self.sess.run(self.v))

  # Run-start CLI appears only once (run #1); run-end CLI only at run #3.
  self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
  self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
  self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
  self.assertEqual([None], wrapped_sess.observers["tf_errors"])
def testRunMultipleTimesOverLimit(self):
  # Test command sequence: run -t 3;
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [["-t", "3"]], self.sess, dump_root=self._tmp_dir)

  # run twice, which is less than the number of times specified by the
  # command.
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)

  self.assertAllClose(12.0, self.sess.run(self.v))

  # Since the "-t 3" budget is never exhausted, no run-end CLI, dumps or
  # errors should have been observed.
  self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
  self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
  self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))
  self.assertEqual([], wrapped_sess.observers["tf_errors"])
def testRunMixingDebugModeAndMultpleTimes(self):
  # NOTE: "Multple" in the method name is a typo for "Multiple" (kept to
  # avoid renaming the test).
  # Test command sequence: run -n; run -t 2; run; run;
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [["-n"], ["-t", "2"], [], []],
      self.sess,
      dump_root=self._tmp_dir)

  # run four times.
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)

  self.assertAllClose(14.0, self.sess.run(self.v))

  # Run-start CLIs at runs #1 and #2; "run -t 2" then carries execution to
  # run #3, where the run-end CLI resumes.
  self.assertEqual([1, 2],
                   wrapped_sess.observers["run_start_cli_run_numbers"])
  self.assertEqual([3, 4], wrapped_sess.observers["run_end_cli_run_numbers"])
  self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
  self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
def testRuntimeErrorShouldBeCaught(self):
  # Test command sequence: run; run;
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [[], []], self.sess, dump_root=self._tmp_dir)

  # Do a run that should lead to an TensorFlow runtime error.
  # (The fed value has shape 3x1; after transpose it is 1x3, which cannot be
  # matmul'ed with the 2x3 constant m.)
  wrapped_sess.run(self.y, feed_dict={self.ph: [[0.0], [1.0], [2.0]]})

  self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
  self.assertEqual([1], wrapped_sess.observers["run_end_cli_run_numbers"])
  self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))

  # Verify that the runtime error is caught by the wrapped session properly.
  self.assertEqual(1, len(wrapped_sess.observers["tf_errors"]))
  tf_error = wrapped_sess.observers["tf_errors"][0]
  self.assertEqual("y", tf_error.op.name)
def testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self):
  # Test command sequence:
  # run -f v_greater_than_twelve; run -f v_greater_than_twelve; run;
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [["-f", "v_greater_than_twelve"], ["-f", "v_greater_than_twelve"], []],
      self.sess, dump_root=self._tmp_dir)

  def v_greater_than_twelve(datum, tensor):
    # Passes only for the dump of node "v" once its value exceeds 12.0.
    return datum.node_name == "v" and tensor > 12.0

  wrapped_sess.add_tensor_filter(
      "v_greater_than_twelve", v_greater_than_twelve)

  # run five times.
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)
  wrapped_sess.run(self.inc_v)

  self.assertAllClose(15.0, self.sess.run(self.v))

  self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])

  # run-end CLI should NOT have been launched for run #2 and #3, because only
  # starting from run #4 v becomes greater than 12.0.
  self.assertEqual([4, 5], wrapped_sess.observers["run_end_cli_run_numbers"])

  self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
  self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
if __name__ == "__main__":
googletest.main()