tfdbg CLI: consolidate run-start and run-end CLIs
Before this CL, there are two CLIs launched for each Session.run() call: the run-start CLI and the run-end one. They have different sets of commands, which is potentially confusing. This CL consolidates the two. * The run-start CLI is launched only before the first Session.run() call. * In subsequent Session.run() calls, only the run-end CLI is shown. At this CL, the same "run" command as in the run-start CLI works. * Added the "-t" option flag to the "run" command to allow running through a number of Session.run() calls without pausing at the CLI. * Adde the "run_info" (shorthand: "ri") command to allow inspection of the run fetches and feeds. * Added more unit test coverage for the LocalCLIDebugWrapperSession class. * Made corresponding doc updates. * Added "tfdbg" logo to the "splash screen" of the CLI. Change: 141221720
This commit is contained in:
parent
4d51b55cf8
commit
aa6ab8b962
tensorflow/python/debug
@ -349,6 +349,7 @@ py_test(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":debugger_cli_common",
|
||||
":local_cli_wrapper",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
|
@ -71,7 +71,22 @@ def _recommend_command(command, description, indent=2):
|
||||
return debugger_cli_common.RichTextLines(lines, font_attr_segs=font_attr_segs)
|
||||
|
||||
|
||||
def get_run_start_intro(run_call_count, fetches, feed_dict, tensor_filters):
|
||||
def get_tfdbg_logo():
|
||||
lines = [
|
||||
"TTTTTT FFFF DDD BBBB GGG ",
|
||||
" TT F D D B B G ",
|
||||
" TT FFF D D BBBB G GG",
|
||||
" TT F D D B B G G",
|
||||
" TT F DDD BBBB GGG ",
|
||||
"",
|
||||
]
|
||||
return debugger_cli_common.RichTextLines(lines)
|
||||
|
||||
|
||||
def get_run_start_intro(run_call_count,
|
||||
fetches,
|
||||
feed_dict,
|
||||
tensor_filters):
|
||||
"""Generate formatted intro for run-start UI.
|
||||
|
||||
Args:
|
||||
@ -101,8 +116,8 @@ def get_run_start_intro(run_call_count, fetches, feed_dict, tensor_filters):
|
||||
|
||||
intro_lines = [
|
||||
"======================================",
|
||||
"About to enter Session run() call #%d:" % run_call_count, "",
|
||||
"Fetch(es):"
|
||||
"Session.run() call #%d:" % run_call_count,
|
||||
"", "Fetch(es):"
|
||||
]
|
||||
intro_lines.extend([" " + line for line in fetch_lines])
|
||||
intro_lines.extend(["", "Feed dict(s):"])
|
||||
@ -120,11 +135,16 @@ def get_run_start_intro(run_call_count, fetches, feed_dict, tensor_filters):
|
||||
out.extend(
|
||||
_recommend_command(
|
||||
"run -n", "Execute the run() call without debug tensor-watching"))
|
||||
out.extend(
|
||||
_recommend_command(
|
||||
"run -t <T>",
|
||||
"Execute run() calls (T - 1) times without debugging, then "
|
||||
"execute run() one more time and drop back to the CLI"))
|
||||
out.extend(
|
||||
_recommend_command(
|
||||
"run -f <filter_name>",
|
||||
"Keep executing run() calls until a dumped tensor passes a given, "
|
||||
"registered filter (conditional breakpoint mode)."))
|
||||
"registered filter (conditional breakpoint mode)"))
|
||||
|
||||
more_font_attr_segs = {}
|
||||
more_lines = [" Registered filter(s):"]
|
||||
|
@ -42,8 +42,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
|
||||
|
||||
# Verify line about run() call number.
|
||||
self.assertEqual("About to enter Session run() call #12:",
|
||||
run_start_intro.lines[1])
|
||||
self.assertTrue(run_start_intro.lines[1].endswith("run() call #12:"))
|
||||
|
||||
# Verify line about fetch.
|
||||
const_a_name_line = run_start_intro.lines[4]
|
||||
@ -58,8 +57,10 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual([(2, 5, "bold")], run_start_intro.font_attr_segs[11])
|
||||
self.assertEqual("run -n:", run_start_intro.lines[13][2:])
|
||||
self.assertEqual([(2, 8, "bold")], run_start_intro.font_attr_segs[13])
|
||||
self.assertEqual("run -f <filter_name>:", run_start_intro.lines[15][2:])
|
||||
self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[15])
|
||||
self.assertEqual("run -t <T>:", run_start_intro.lines[15][2:])
|
||||
self.assertEqual([(2, 12, "bold")], run_start_intro.font_attr_segs[15])
|
||||
self.assertEqual("run -f <filter_name>:", run_start_intro.lines[17][2:])
|
||||
self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[17])
|
||||
|
||||
# Verify short description.
|
||||
description = cli_shared.get_run_short_description(12, self.const_a, None)
|
||||
@ -179,8 +180,8 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Verify the listed names of the tensor filters.
|
||||
filter_names = set()
|
||||
filter_names.add(run_start_intro.lines[18].split(" ")[-1])
|
||||
filter_names.add(run_start_intro.lines[19].split(" ")[-1])
|
||||
filter_names.add(run_start_intro.lines[20].split(" ")[-1])
|
||||
filter_names.add(run_start_intro.lines[21].split(" ")[-1])
|
||||
|
||||
self.assertEqual({"filter_a", "filter_b"}, filter_names)
|
||||
|
||||
@ -218,14 +219,14 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(2, error_intro.lines[8].index("lt"))
|
||||
self.assertEqual([(2, 4, "bold")], error_intro.font_attr_segs[8])
|
||||
|
||||
self.assertTrue(error_intro.lines[11].startswith("Op name:"))
|
||||
self.assertStartsWith(error_intro.lines[11], "Op name:")
|
||||
self.assertTrue(error_intro.lines[11].endswith("a/Assign"))
|
||||
|
||||
self.assertTrue(error_intro.lines[12].startswith("Error type:"))
|
||||
self.assertStartsWith(error_intro.lines[12], "Error type:")
|
||||
self.assertTrue(error_intro.lines[12].endswith(str(type(tf_error))))
|
||||
|
||||
self.assertEqual("Details:", error_intro.lines[14])
|
||||
self.assertTrue(error_intro.lines[15].startswith("foo description"))
|
||||
self.assertStartsWith(error_intro.lines[15], "foo description")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,12 +1,16 @@
|
||||
# TensorFlow Debugger (tfdbg) Command-Line-Interface Tutorial: MNIST
|
||||
|
||||
**(Under development, subject to change)**
|
||||
**(Experimental)**
|
||||
|
||||
This tutorial showcases the features of TensorFlow Debugger (**tfdbg**)
|
||||
command-line interface.
|
||||
It contains an example of how to debug a frequently encountered problem in
|
||||
TensorFlow model development: bad numerical values (`nan`s and `inf`s) causing
|
||||
training to fail.
|
||||
TensorFlow debugger (**tfdbg**) is a specialized debugger for TensorFlow. It
|
||||
provides visibility into the internal structure and states of running
|
||||
TensorFlow graphs. The insight gained from this visibility should facilitate
|
||||
debugging of various types of model bugs during training and inference.
|
||||
|
||||
This tutorial showcases the features of tfdbg
|
||||
command-line interface (CLI), by focusing on how to debug a
|
||||
type of frequently-encountered bug in TensorFlow model development:
|
||||
bad numerical values (`nan`s and `inf`s) causing training to fail.
|
||||
|
||||
To **observe** such an issue, run the following code without the debugger:
|
||||
|
||||
@ -25,11 +29,7 @@ Accuracy at step 1: 0.3183
|
||||
Accuracy at step 2: 0.098
|
||||
Accuracy at step 3: 0.098
|
||||
Accuracy at step 4: 0.098
|
||||
Accuracy at step 5: 0.098
|
||||
Accuracy at step 6: 0.098
|
||||
Accuracy at step 7: 0.098
|
||||
Accuracy at step 8: 0.098
|
||||
Accuracy at step 9: 0.098
|
||||
...
|
||||
```
|
||||
|
||||
Scratching your head, you suspect that certain nodes in the training graph
|
||||
@ -122,7 +122,9 @@ output.
|
||||
|
||||
As the screen output indicates, the first `run()` call calculates the accuracy
|
||||
using a test data set—i.e., a forward pass on the graph. You can enter the
|
||||
command `run` to launch the `run()` call. This will bring up another screen
|
||||
command `run` (or its shorthand `r`) to launch the `run()` call.
|
||||
|
||||
This will bring up another screen
|
||||
right after the `run()` call has ended, which will display all dumped
|
||||
intermedate tensors from the run. (These tensors can also be obtained by
|
||||
running the command `lt` after you executed `run`.) This is called the
|
||||
@ -167,27 +169,21 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
|
||||
| `lo -r hidden/Relu:0` | List the recipients of the output of the node `hidden/Relu`, recursively—i.e., the output recipient tree. |
|
||||
| `lt -n softmax.*` | List all dumped tensors whose names match the regular-expression pattern `softmax.*`. |
|
||||
| `lt -t MatMul` | List all dumped tensors whose node type is `MatMul`. |
|
||||
| `run_info` or `ri` | Display information about the current run, including fetches and feeds. |
|
||||
| `help` | Print general help information listing all available **tfdbg** commands and their flags. |
|
||||
| `help lt` | Print the help information for the `lt` command. |
|
||||
|
||||
In this first `run()` call, there happen to be no problematic numerical values.
|
||||
You can exit the run-end UI by entering the command `exit`. Then you will be at
|
||||
the second run-start UI:
|
||||
You can move on to the next run by using the command `run` or its shorthand `r`.
|
||||
|
||||
```none
|
||||
--- run-start: run #2: fetch: train/Adam; 2 feeds --------------
|
||||
======================================
|
||||
About to enter Session run() call #2:
|
||||
|
||||
Fetch(es):
|
||||
train/Adam
|
||||
|
||||
Feed dict(s):
|
||||
input/x-input:0
|
||||
input/y-input:0
|
||||
======================================
|
||||
...
|
||||
```
|
||||
> TIP: If you enter `run` or `r` repeatedly, you will be able to move through the
|
||||
> `run()` calls in a sequential manner.
|
||||
>
|
||||
> You can also use the `-t` flag to move ahead a number of `run()` calls at a time, for example:
|
||||
>
|
||||
> ```
|
||||
> tfdbg> run -t 10
|
||||
> ```
|
||||
|
||||
Instead of entering `run` repeatedly and manually searching for `nan`s and
|
||||
`inf`s in the run-end UI after every `run()` call, you can use the following
|
||||
|
@ -70,20 +70,44 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
|
||||
self._dump_root = dump_root
|
||||
|
||||
# State flag for running till a tensor filter is passed.
|
||||
self._run_till_filter_pass = None
|
||||
self._initialize_argparsers()
|
||||
|
||||
# State related to tensor filters.
|
||||
# Registered tensor filters.
|
||||
self._tensor_filters = {}
|
||||
|
||||
# Options for the on-run-start hook:
|
||||
# 1) run (DEBUG_RUN)
|
||||
# 2) run --nodebug (NON_DEBUG_RUN)
|
||||
# 3) invoke_stepper (INVOKE_STEPPER, not implemented)
|
||||
self._on_run_start_parsers = {}
|
||||
# Below are the state variables of this wrapper object.
|
||||
# _active_tensor_filter: what (if any) tensor filter is in effect. If such
|
||||
# a filter is in effect, this object will call run() method of the
|
||||
# underlying TensorFlow Session object until the filter passes. This is
|
||||
# activated by the "-f" flag of the "run" command.
|
||||
# _run_through_times: keeps track of how many times the wrapper needs to
|
||||
# run through without stopping at the run-end CLI. It is activated by the
|
||||
# "-t" option of the "run" command.
|
||||
# _skip_debug: keeps track of whether the current run should be executed
|
||||
# without debugging. It is activated by the "-n" option of the "run"
|
||||
# command.
|
||||
#
|
||||
# _run_start_response: keeps track what OnRunStartResponse the wrapper
|
||||
# should return at the next run-start callback. If this information is
|
||||
# unavailable (i.e., is None), the run-start CLI will be launched to ask
|
||||
# the user. This is the case, e.g., right before the first run starts.
|
||||
self._active_tensor_filter = None
|
||||
self._run_through_times = 1
|
||||
self._skip_debug = False
|
||||
self._run_start_response = None
|
||||
|
||||
def _initialize_argparsers(self):
|
||||
self._argparsers = {}
|
||||
ap = argparse.ArgumentParser(
|
||||
description="Run through, with or without debug tensor watching.",
|
||||
usage=argparse.SUPPRESS)
|
||||
ap.add_argument(
|
||||
"-t",
|
||||
"--times",
|
||||
dest="times",
|
||||
type=int,
|
||||
default=1,
|
||||
help="How many Session.run() calls to proceed with.")
|
||||
ap.add_argument(
|
||||
"-n",
|
||||
"--no_debug",
|
||||
@ -97,12 +121,17 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
type=str,
|
||||
default="",
|
||||
help="Run until a tensor in the graph passes the specified filter.")
|
||||
self._on_run_start_parsers["run"] = ap
|
||||
self._argparsers["run"] = ap
|
||||
|
||||
ap = argparse.ArgumentParser(
|
||||
description="Invoke stepper (cont, step, breakpoint, etc.)",
|
||||
usage=argparse.SUPPRESS)
|
||||
self._on_run_start_parsers["invoke_stepper"] = ap
|
||||
self._argparsers["invoke_stepper"] = ap
|
||||
|
||||
ap = argparse.ArgumentParser(
|
||||
description="Display information about this Session.run() call.",
|
||||
usage=argparse.SUPPRESS)
|
||||
self._argparsers["run_info"] = ap
|
||||
|
||||
def add_tensor_filter(self, filter_name, tensor_filter):
|
||||
"""Add a tensor filter.
|
||||
@ -151,46 +180,58 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
self._update_run_calls_state(request.run_call_count, request.fetches,
|
||||
request.feed_dict)
|
||||
|
||||
if self._run_till_filter_pass:
|
||||
if self._active_tensor_filter:
|
||||
# If we are running till a filter passes, we just need to keep running
|
||||
# with the DEBUG_RUN option.
|
||||
return framework.OnRunStartResponse(framework.OnRunStartAction.DEBUG_RUN,
|
||||
self._get_run_debug_urls())
|
||||
|
||||
run_start_cli = curses_ui.CursesUI()
|
||||
if self._run_call_count > 1 and not self._skip_debug:
|
||||
if self._run_through_times > 0:
|
||||
# Just run through without debugging.
|
||||
return framework.OnRunStartResponse(
|
||||
framework.OnRunStartAction.NON_DEBUG_RUN, [])
|
||||
elif self._run_through_times == 0:
|
||||
# It is the run at which the run-end CLI will be launched: activate
|
||||
# debugging.
|
||||
return framework.OnRunStartResponse(
|
||||
framework.OnRunStartAction.DEBUG_RUN,
|
||||
self._get_run_debug_urls())
|
||||
|
||||
run_start_cli.register_command_handler(
|
||||
"run",
|
||||
self._on_run_start_run_handler,
|
||||
self._on_run_start_parsers["run"].format_help(),
|
||||
prefix_aliases=["r"])
|
||||
run_start_cli.register_command_handler(
|
||||
"invoke_stepper",
|
||||
self._on_run_start_step_handler,
|
||||
self._on_run_start_parsers["invoke_stepper"].format_help(),
|
||||
prefix_aliases=["s"])
|
||||
if self._run_start_response is None:
|
||||
self._prep_cli_for_run_start()
|
||||
|
||||
if self._tensor_filters:
|
||||
# Register tab completion for the filter names.
|
||||
run_start_cli.register_tab_comp_context(["run", "r"],
|
||||
list(self._tensor_filters.keys()))
|
||||
self._run_start_response = self._launch_cli(is_run_start=True)
|
||||
if self._run_through_times > 1:
|
||||
self._run_through_times -= 1
|
||||
|
||||
run_start_cli.set_help_intro(
|
||||
cli_shared.get_run_start_intro(request.run_call_count, request.fetches,
|
||||
request.feed_dict, self._tensor_filters))
|
||||
|
||||
# Create initial screen output detailing the run.
|
||||
title = "run-start: " + self._run_description
|
||||
response = run_start_cli.run_ui(
|
||||
init_command="help", title=title, title_color="blue_on_white")
|
||||
if response == debugger_cli_common.EXPLICIT_USER_EXIT:
|
||||
if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
|
||||
# Explicit user "exit" command leads to sys.exit(1).
|
||||
print(
|
||||
"Note: user exited from debugger CLI: Calling sys.exit(1).",
|
||||
file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return response
|
||||
return self._run_start_response
|
||||
|
||||
def _prep_cli_for_run_start(self):
|
||||
"""Prepare (but not launch) the CLI for run-start."""
|
||||
|
||||
self._run_cli = curses_ui.CursesUI()
|
||||
|
||||
help_intro = debugger_cli_common.RichTextLines([])
|
||||
if self._run_call_count == 1:
|
||||
# Show logo at the onset of the first run.
|
||||
help_intro.extend(cli_shared.get_tfdbg_logo())
|
||||
help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
|
||||
help_intro.extend(self._run_info)
|
||||
|
||||
self._run_cli.set_help_intro(help_intro)
|
||||
|
||||
# Create initial screen output detailing the run.
|
||||
self._title = "run-start: " + self._run_description
|
||||
self._init_command = "help"
|
||||
self._title_color = "blue_on_white"
|
||||
|
||||
def on_run_end(self, request):
|
||||
"""Overrides on-run-end callback.
|
||||
@ -216,111 +257,150 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
debug_dump = debug_data.DebugDumpDir(
|
||||
self._dump_root, partition_graphs=partition_graphs)
|
||||
|
||||
if request.tf_error:
|
||||
help_intro = cli_shared.get_error_intro(request.tf_error)
|
||||
passed_filter = None
|
||||
if self._active_tensor_filter:
|
||||
if not debug_dump.find(
|
||||
self._tensor_filters[self._active_tensor_filter], first_n=1):
|
||||
# No dumped tensor passes the filter in this run. Clean up the dump
|
||||
# directory and move on.
|
||||
self._remove_dump_root()
|
||||
return framework.OnRunEndResponse()
|
||||
else:
|
||||
# Some dumped tensor(s) from this run passed the filter.
|
||||
passed_filter = self._active_tensor_filter
|
||||
self._active_tensor_filter = None
|
||||
|
||||
init_command = "help"
|
||||
title_color = "red_on_white"
|
||||
else:
|
||||
help_intro = None
|
||||
init_command = "lt"
|
||||
self._prep_cli_for_run_end(debug_dump, request.tf_error, passed_filter)
|
||||
|
||||
title_color = "black_on_white"
|
||||
if self._run_till_filter_pass:
|
||||
if not debug_dump.find(
|
||||
self._tensor_filters[self._run_till_filter_pass], first_n=1):
|
||||
# No dumped tensor passes the filter in this run. Clean up the dump
|
||||
# directory and move on.
|
||||
shutil.rmtree(self._dump_root)
|
||||
return framework.OnRunEndResponse()
|
||||
else:
|
||||
# Some dumped tensor(s) from this run passed the filter.
|
||||
init_command = "lt -f %s" % self._run_till_filter_pass
|
||||
title_color = "red_on_white"
|
||||
self._run_till_filter_pass = None
|
||||
self._run_start_response = self._launch_cli()
|
||||
|
||||
analyzer = analyzer_cli.DebugAnalyzer(debug_dump)
|
||||
|
||||
# Supply all the available tensor filters.
|
||||
for filter_name in self._tensor_filters:
|
||||
analyzer.add_tensor_filter(filter_name,
|
||||
self._tensor_filters[filter_name])
|
||||
|
||||
run_end_cli = curses_ui.CursesUI()
|
||||
run_end_cli.register_command_handler(
|
||||
"list_tensors",
|
||||
analyzer.list_tensors,
|
||||
analyzer.get_help("list_tensors"),
|
||||
prefix_aliases=["lt"])
|
||||
run_end_cli.register_command_handler(
|
||||
"node_info",
|
||||
analyzer.node_info,
|
||||
analyzer.get_help("node_info"),
|
||||
prefix_aliases=["ni"])
|
||||
run_end_cli.register_command_handler(
|
||||
"list_inputs",
|
||||
analyzer.list_inputs,
|
||||
analyzer.get_help("list_inputs"),
|
||||
prefix_aliases=["li"])
|
||||
run_end_cli.register_command_handler(
|
||||
"list_outputs",
|
||||
analyzer.list_outputs,
|
||||
analyzer.get_help("list_outputs"),
|
||||
prefix_aliases=["lo"])
|
||||
run_end_cli.register_command_handler(
|
||||
"print_tensor",
|
||||
analyzer.print_tensor,
|
||||
analyzer.get_help("print_tensor"),
|
||||
prefix_aliases=["pt"])
|
||||
|
||||
run_end_cli.register_command_handler(
|
||||
"run",
|
||||
self._run_end_run_command_handler,
|
||||
"Helper command for incorrectly entered run command at the run-end "
|
||||
"prompt.",
|
||||
prefix_aliases=["r"]
|
||||
)
|
||||
|
||||
# Get names of all dumped tensors.
|
||||
dumped_tensor_names = []
|
||||
for datum in debug_dump.dumped_tensor_data:
|
||||
dumped_tensor_names.append("%s:%d" %
|
||||
(datum.node_name, datum.output_slot))
|
||||
|
||||
# Tab completions for command "print_tensors".
|
||||
run_end_cli.register_tab_comp_context(["print_tensor", "pt"],
|
||||
dumped_tensor_names)
|
||||
|
||||
# Tab completion for commands "node_info", "list_inputs" and
|
||||
# "list_outputs". The list comprehension is used below because nodes()
|
||||
# output can be unicodes and they need to be converted to strs.
|
||||
run_end_cli.register_tab_comp_context(
|
||||
["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
|
||||
[str(node_name) for node_name in debug_dump.nodes()])
|
||||
# TODO(cais): Reduce API surface area for aliases vis-a-vis tab
|
||||
# completion contexts and registered command handlers.
|
||||
|
||||
title = "run-end: " + self._run_description
|
||||
if help_intro:
|
||||
run_end_cli.set_help_intro(help_intro)
|
||||
run_end_cli.run_ui(
|
||||
init_command=init_command, title=title, title_color=title_color)
|
||||
|
||||
# Clean up the dump directory.
|
||||
shutil.rmtree(self._dump_root)
|
||||
# Clean up the dump generated by this run.
|
||||
self._remove_dump_root()
|
||||
else:
|
||||
print("No debug information to show following a non-debug run() call.")
|
||||
# No debug information to show following a non-debug run() call.
|
||||
self._run_start_response = None
|
||||
|
||||
# Return placeholder response that currently holds no additional
|
||||
# information.
|
||||
return framework.OnRunEndResponse()
|
||||
|
||||
def _on_run_start_run_handler(self, args, screen_info=None):
|
||||
def _remove_dump_root(self):
|
||||
if os.path.isdir(self._dump_root):
|
||||
shutil.rmtree(self._dump_root)
|
||||
|
||||
def _prep_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
|
||||
"""Prepare (but not launch) CLI for run-end, with debug dump from the run.
|
||||
|
||||
Args:
|
||||
debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
|
||||
run.
|
||||
tf_error: (None or OpError) OpError that happened during the run() call
|
||||
(if any).
|
||||
passed_filter: (None or str) Name of the tensor filter that just passed
|
||||
and caused the preparation of this run-end CLI (if any).
|
||||
"""
|
||||
|
||||
if tf_error:
|
||||
help_intro = cli_shared.get_error_intro(tf_error)
|
||||
|
||||
self._init_command = "help"
|
||||
self._title_color = "red_on_white"
|
||||
else:
|
||||
help_intro = None
|
||||
self._init_command = "lt"
|
||||
|
||||
self._title_color = "black_on_white"
|
||||
if passed_filter is not None:
|
||||
# Some dumped tensor(s) from this run passed the filter.
|
||||
self._init_command = "lt -f %s" % passed_filter
|
||||
self._title_color = "red_on_white"
|
||||
|
||||
analyzer = analyzer_cli.DebugAnalyzer(debug_dump)
|
||||
|
||||
# Supply all the available tensor filters.
|
||||
for filter_name in self._tensor_filters:
|
||||
analyzer.add_tensor_filter(filter_name,
|
||||
self._tensor_filters[filter_name])
|
||||
|
||||
self._run_cli = curses_ui.CursesUI()
|
||||
self._run_cli.register_command_handler(
|
||||
"list_tensors",
|
||||
analyzer.list_tensors,
|
||||
analyzer.get_help("list_tensors"),
|
||||
prefix_aliases=["lt"])
|
||||
self._run_cli.register_command_handler(
|
||||
"node_info",
|
||||
analyzer.node_info,
|
||||
analyzer.get_help("node_info"),
|
||||
prefix_aliases=["ni"])
|
||||
self._run_cli.register_command_handler(
|
||||
"list_inputs",
|
||||
analyzer.list_inputs,
|
||||
analyzer.get_help("list_inputs"),
|
||||
prefix_aliases=["li"])
|
||||
self._run_cli.register_command_handler(
|
||||
"list_outputs",
|
||||
analyzer.list_outputs,
|
||||
analyzer.get_help("list_outputs"),
|
||||
prefix_aliases=["lo"])
|
||||
self._run_cli.register_command_handler(
|
||||
"print_tensor",
|
||||
analyzer.print_tensor,
|
||||
analyzer.get_help("print_tensor"),
|
||||
prefix_aliases=["pt"])
|
||||
|
||||
# Get names of all dumped tensors.
|
||||
dumped_tensor_names = []
|
||||
for datum in debug_dump.dumped_tensor_data:
|
||||
dumped_tensor_names.append("%s:%d" %
|
||||
(datum.node_name, datum.output_slot))
|
||||
|
||||
# Tab completions for command "print_tensors".
|
||||
self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
|
||||
dumped_tensor_names)
|
||||
|
||||
# Tab completion for commands "node_info", "list_inputs" and
|
||||
# "list_outputs". The list comprehension is used below because nodes()
|
||||
# output can be unicodes and they need to be converted to strs.
|
||||
self._run_cli.register_tab_comp_context(
|
||||
["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
|
||||
[str(node_name) for node_name in debug_dump.nodes()])
|
||||
# TODO(cais): Reduce API surface area for aliases vis-a-vis tab
|
||||
# completion contexts and registered command handlers.
|
||||
|
||||
self._title = "run-end: " + self._run_description
|
||||
|
||||
if help_intro:
|
||||
self._run_cli.set_help_intro(help_intro)
|
||||
|
||||
def _launch_cli(self, is_run_start=False):
|
||||
"""Launch the interactive command-line interface.
|
||||
|
||||
Args:
|
||||
is_run_start: (bool) whether this CLI launch occurs at a run-start
|
||||
callback.
|
||||
|
||||
Returns:
|
||||
The OnRunStartResponse specified by the user using the "run" command.
|
||||
"""
|
||||
|
||||
self._register_this_run_info(self._run_cli)
|
||||
response = self._run_cli.run_ui(
|
||||
init_command=self._init_command,
|
||||
title=self._title,
|
||||
title_color=self._title_color)
|
||||
|
||||
return response
|
||||
|
||||
def _run_info_handler(self, args, screen_info=None):
|
||||
return self._run_info
|
||||
|
||||
def _run_handler(self, args, screen_info=None):
|
||||
"""Command handler for "run" command during on-run-start."""
|
||||
|
||||
_ = screen_info # Currently unused.
|
||||
|
||||
parsed = self._on_run_start_parsers["run"].parse_args(args)
|
||||
parsed = self._argparsers["run"].parse_args(args)
|
||||
|
||||
if parsed.till_filter_pass:
|
||||
# For the run-till-bad-numerical-value-appears mode, use the DEBUG_RUN
|
||||
@ -328,14 +408,18 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
# state flag of the class itself to True.
|
||||
if parsed.till_filter_pass in self._tensor_filters:
|
||||
action = framework.OnRunStartAction.DEBUG_RUN
|
||||
self._run_till_filter_pass = parsed.till_filter_pass
|
||||
self._active_tensor_filter = parsed.till_filter_pass
|
||||
else:
|
||||
# Handle invalid filter name.
|
||||
return debugger_cli_common.RichTextLines(
|
||||
["ERROR: tensor filter \"%s\" does not exist." %
|
||||
parsed.till_filter_pass])
|
||||
|
||||
if parsed.no_debug:
|
||||
self._skip_debug = parsed.no_debug
|
||||
self._run_through_times = parsed.times
|
||||
|
||||
if parsed.times > 1 or parsed.no_debug:
|
||||
# If requested -t times > 1, the very next run will be a non-debug run.
|
||||
action = framework.OnRunStartAction.NON_DEBUG_RUN
|
||||
debug_urls = []
|
||||
else:
|
||||
@ -346,6 +430,28 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
raise debugger_cli_common.CommandLineExit(
|
||||
exit_token=framework.OnRunStartResponse(action, debug_urls))
|
||||
|
||||
def _register_this_run_info(self, curses_cli):
|
||||
curses_cli.register_command_handler(
|
||||
"run",
|
||||
self._run_handler,
|
||||
self._argparsers["run"].format_help(),
|
||||
prefix_aliases=["r"])
|
||||
curses_cli.register_command_handler(
|
||||
"invoke_stepper",
|
||||
self._on_run_start_step_handler,
|
||||
self._argparsers["invoke_stepper"].format_help(),
|
||||
prefix_aliases=["s"])
|
||||
curses_cli.register_command_handler(
|
||||
"run_info",
|
||||
self._run_info_handler,
|
||||
self._argparsers["run_info"].format_help(),
|
||||
prefix_aliases=["ri"])
|
||||
|
||||
if self._tensor_filters:
|
||||
# Register tab completion for the filter names.
|
||||
curses_cli.register_tab_comp_context(["run", "r"],
|
||||
list(self._tensor_filters.keys()))
|
||||
|
||||
def _on_run_start_step_handler(self, args, screen_info=None):
|
||||
"""Command handler for "invoke_stepper" command during on-run-start."""
|
||||
|
||||
@ -359,18 +465,6 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
exit_token=framework.OnRunStartResponse(
|
||||
framework.OnRunStartAction.INVOKE_STEPPER, []))
|
||||
|
||||
def _run_end_run_command_handler(self, args, screen_info=None):
|
||||
"""Handler for incorrectly entered run command at run-end prompt."""
|
||||
|
||||
_ = screen_info # Currently unused.
|
||||
|
||||
return debugger_cli_common.RichTextLines([
|
||||
"ERROR: the \"run\" command is invalid for the run-end prompt.", "",
|
||||
"To proceed to the next run, ",
|
||||
" 1) exit this run-end prompt using the command \"exit\"",
|
||||
" 2) enter the command \"run\" at the next run-start prompt.",
|
||||
])
|
||||
|
||||
def _get_run_debug_urls(self):
|
||||
"""Get the debug_urls value for the current run() call.
|
||||
|
||||
@ -397,3 +491,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
self._run_description = cli_shared.get_run_short_description(run_call_count,
|
||||
fetches,
|
||||
feed_dict)
|
||||
self._run_through_times -= 1
|
||||
|
||||
self._run_info = cli_shared.get_run_start_intro(run_call_count,
|
||||
fetches,
|
||||
feed_dict,
|
||||
self._tensor_filters)
|
||||
|
@ -21,18 +21,96 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class LocalCLIDebuggerWrapperSessionForTest(
|
||||
local_cli_wrapper.LocalCLIDebugWrapperSession):
|
||||
"""Subclasses the wrapper class for testing.
|
||||
|
||||
Overrides its CLI-related methods for headless testing environments.
|
||||
Inserts observer variables for assertions.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
command_args_sequence,
|
||||
sess,
|
||||
dump_root=None):
|
||||
"""Constructor of the for-test subclass.
|
||||
|
||||
Args:
|
||||
command_args_sequence: (list of list of str) A list of arguments for the
|
||||
"run" command.
|
||||
sess: See the doc string of LocalCLIDebugWrapperSession.__init__.
|
||||
dump_root: See the doc string of LocalCLIDebugWrapperSession.__init__.
|
||||
"""
|
||||
|
||||
local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
|
||||
self, sess, dump_root=dump_root, log_usage=False)
|
||||
|
||||
self._command_args_sequence = command_args_sequence
|
||||
self._response_pointer = 0
|
||||
|
||||
# Observer variables.
|
||||
self.observers = {
|
||||
"debug_dumps": [],
|
||||
"tf_errors": [],
|
||||
"run_start_cli_run_numbers": [],
|
||||
"run_end_cli_run_numbers": [],
|
||||
}
|
||||
|
||||
def _prep_cli_for_run_start(self):
|
||||
pass
|
||||
|
||||
def _prep_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
|
||||
self.observers["debug_dumps"].append(debug_dump)
|
||||
self.observers["tf_errors"].append(tf_error)
|
||||
|
||||
def _launch_cli(self, is_run_start=False):
|
||||
if is_run_start:
|
||||
self.observers["run_start_cli_run_numbers"].append(self._run_call_count)
|
||||
else:
|
||||
self.observers["run_end_cli_run_numbers"].append(self._run_call_count)
|
||||
|
||||
command_args = self._command_args_sequence[self._response_pointer]
|
||||
self._response_pointer += 1
|
||||
|
||||
try:
|
||||
self._run_handler(command_args)
|
||||
except debugger_cli_common.CommandLineExit as e:
|
||||
response = e.exit_token
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
  def setUp(self):
    # NOTE(review): tempfile.mktemp() returns a path without creating it and
    # is race-prone; presumably chosen because the wrapper requires a
    # not-yet-existing dump root — confirm before switching to mkdtemp().
    self._tmp_dir = tempfile.mktemp()

    # A variable plus an op that increments it by 1.0 per run; used by the
    # debug/non-debug run tests.
    self.v = tf.Variable(10.0, name="v")
    self.delta = tf.constant(1.0, name="delta")
    self.inc_v = tf.assign_add(self.v, self.delta, name="inc_v")

    # A matmul whose validity depends on the shape fed to "ph"; feeding an
    # incompatible shape triggers a runtime error (see
    # testRuntimeErrorShouldBeCaught).
    self.ph = tf.placeholder(tf.float32, name="ph")
    self.xph = tf.transpose(self.ph, name="xph")
    self.m = tf.constant(
        [[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=tf.float32, name="m")
    self.y = tf.matmul(self.m, self.xph, name="y")

    self.sess = tf.Session()

    # Initialize variable.
    self.sess.run(self.v.initializer)
||||
  def tearDown(self):
    # Clear the default graph so each test builds its ops from scratch.
    tf.reset_default_graph()
    # Remove the debug-dump root if any debugged Session.run() created it.
    if os.path.isdir(self._tmp_dir):
      shutil.rmtree(self._tmp_dir)
||||
@ -68,6 +146,174 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
local_cli_wrapper.LocalCLIDebugWrapperSession(
|
||||
session.Session(), dump_root=file_path, log_usage=False)
|
||||
|
||||
def testRunsUnderDebugMode(self):
|
||||
# Test command sequence: run; run; run;
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[[], [], []], self.sess, dump_root=self._tmp_dir)
|
||||
|
||||
# run under debug mode twice.
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
|
||||
# Verify that the assign_add op did take effect.
|
||||
self.assertAllClose(12.0, self.sess.run(self.v))
|
||||
|
||||
# Assert correct run call numbers for which the CLI has been launched at
|
||||
# run-start and run-end.
|
||||
self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
self.assertEqual([1, 2], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
|
||||
# Verify that the dumps have been generated and picked up during run-end.
|
||||
self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
|
||||
|
||||
# Verify that the TensorFlow runtime errors are picked up and in this case,
|
||||
# they should be both None.
|
||||
self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
|
||||
|
||||
def testRunsUnderNonDebugMode(self):
|
||||
# Test command sequence: run -n; run -n; run -n;
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["-n"], ["-n"], ["-n"]],
|
||||
self.sess,
|
||||
dump_root=self._tmp_dir)
|
||||
|
||||
# run three times.
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
|
||||
self.assertAllClose(13.0, self.sess.run(self.v))
|
||||
|
||||
self.assertEqual([1, 2, 3],
|
||||
wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
|
||||
def testRunsUnderNonDebugThenDebugMode(self):
|
||||
# Test command sequence: run -n; run -n; run; run;
|
||||
# Do two NON_DEBUG_RUNs, followed by DEBUG_RUNs.
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["-n"], ["-n"], [], []],
|
||||
self.sess,
|
||||
dump_root=self._tmp_dir)
|
||||
|
||||
# run three times.
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
|
||||
self.assertAllClose(13.0, self.sess.run(self.v))
|
||||
|
||||
self.assertEqual([1, 2, 3],
|
||||
wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
|
||||
# Here, the CLI should have been launched only under the third run,
|
||||
# because the first and second runs are NON_DEBUG.
|
||||
self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
|
||||
self.assertEqual([None], wrapped_sess.observers["tf_errors"])
|
||||
|
||||
def testRunMultipleTimesWithinLimit(self):
|
||||
# Test command sequence: run -t 3; run;
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["-t", "3"], []], self.sess, dump_root=self._tmp_dir)
|
||||
|
||||
# run three times.
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
|
||||
self.assertAllClose(13.0, self.sess.run(self.v))
|
||||
|
||||
self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
|
||||
self.assertEqual([None], wrapped_sess.observers["tf_errors"])
|
||||
|
||||
def testRunMultipleTimesOverLimit(self):
|
||||
# Test command sequence: run -t 3;
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["-t", "3"]], self.sess, dump_root=self._tmp_dir)
|
||||
|
||||
# run twice, which is less than the number of times specified by the
|
||||
# command.
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
|
||||
self.assertAllClose(12.0, self.sess.run(self.v))
|
||||
|
||||
self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))
|
||||
self.assertEqual([], wrapped_sess.observers["tf_errors"])
|
||||
|
||||
def testRunMixingDebugModeAndMultpleTimes(self):
|
||||
# Test command sequence: run -n; run -t 2; run; run;
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["-n"], ["-t", "2"], [], []],
|
||||
self.sess,
|
||||
dump_root=self._tmp_dir)
|
||||
|
||||
# run four times.
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
|
||||
self.assertAllClose(14.0, self.sess.run(self.v))
|
||||
|
||||
self.assertEqual([1, 2],
|
||||
wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
self.assertEqual([3, 4], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
|
||||
self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
|
||||
|
||||
def testRuntimeErrorShouldBeCaught(self):
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[[], []], self.sess, dump_root=self._tmp_dir)
|
||||
|
||||
# Do a run that should lead to an TensorFlow runtime error.
|
||||
wrapped_sess.run(self.y, feed_dict={self.ph: [[0.0], [1.0], [2.0]]})
|
||||
|
||||
self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
self.assertEqual([1], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
|
||||
|
||||
# Verify that the runtime error is caught by the wrapped session properly.
|
||||
self.assertEqual(1, len(wrapped_sess.observers["tf_errors"]))
|
||||
tf_error = wrapped_sess.observers["tf_errors"][0]
|
||||
self.assertEqual("y", tf_error.op.name)
|
||||
|
||||
def testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self):
|
||||
# Test command sequence:
|
||||
# run -f greater_than_twelve; run -f greater_than_twelve; run;
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["-f", "v_greater_than_twelve"], ["-f", "v_greater_than_twelve"], []],
|
||||
self.sess, dump_root=self._tmp_dir)
|
||||
|
||||
def v_greater_than_twelve(datum, tensor):
|
||||
return datum.node_name == "v" and tensor > 12.0
|
||||
|
||||
wrapped_sess.add_tensor_filter(
|
||||
"v_greater_than_twelve", v_greater_than_twelve)
|
||||
|
||||
# run five times.
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
wrapped_sess.run(self.inc_v)
|
||||
|
||||
self.assertAllClose(15.0, self.sess.run(self.v))
|
||||
|
||||
self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
|
||||
# run-end CLI should NOT have been launched for run #2 and #3, because only
|
||||
# starting from run #4 v becomes greater than 12.0.
|
||||
self.assertEqual([4, 5], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
|
||||
self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
|
||||
self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Run all test cases in this module under the Google test runner.
  googletest.main()
Loading…
Reference in New Issue
Block a user