# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import tempfile

# Google-internal import(s).
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.lib.io import file_io


_DUMP_ROOT_PREFIX = "tfdbg_"


# TODO(donglin): Remove use_random_config_path after b/137652456 is fixed.
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
  """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.

  This class has all the methods that a `session.Session` object has, in order
  to support debugging with minimal code changes. Invoking its `run()` method
  will launch the command-line interface (CLI) of tfdbg.
  """

  def __init__(self,
               sess,
               dump_root=None,
               log_usage=True,
               ui_type="curses",
               thread_name_filter=None,
               config_file_path=False):
    """Constructor of LocalCLIDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards. If `None`, the debug dumps will
        be at tfdbg_<random_string> under the system temp directory.
      log_usage: (`bool`) whether the usage of this class is to be logged.
      ui_type: (`str`) requested UI type. Currently supported:
        (curses | readline)
      thread_name_filter: Regular-expression white list for thread name. See
        the doc of `BaseDebugWrapperSession` for details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`. If falsy (the default
        `False`), no explicit config object is passed to the UI.

    Raises:
      ValueError: If dump_root is an existing and non-empty directory or if
        dump_root is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter)

    if not dump_root:
      # NOTE(review): tempfile.mktemp only reserves a *name* and is subject
      # to a create-time race; the debugger core creates the directory later
      # during debug run() calls. Confirm this is intended before replacing
      # with tempfile.mkdtemp.
      self._dump_root = tempfile.mktemp(prefix=_DUMP_ROOT_PREFIX)
    else:
      dump_root = os.path.expanduser(dump_root)
      if os.path.isfile(dump_root):
        raise ValueError("dump_root path points to a file: %s" % dump_root)
      elif os.path.isdir(dump_root) and os.listdir(dump_root):
        raise ValueError("dump_root path points to a non-empty directory: %s" %
                         dump_root)

      self._dump_root = dump_root

    self._initialize_argparsers()

    # Registered tensor filters.
    self._tensor_filters = {}
    # Register frequently-used filter(s).
    self.add_tensor_filter("has_inf_or_nan", debug_data.has_inf_or_nan)

    # Below are the state variables of this wrapper object.
    # _active_tensor_filter: what (if any) tensor filter is in effect. If such
    #   a filter is in effect, this object will call run() method of the
    #   underlying TensorFlow Session object until the filter passes. This is
    #   activated by the "-f" flag of the "run" command.
    # _run_through_times: keeps track of how many times the wrapper needs to
    #   run through without stopping at the run-end CLI. It is activated by the
    #   "-t" option of the "run" command.
    # _skip_debug: keeps track of whether the current run should be executed
    #   without debugging. It is activated by the "-n" option of the "run"
    #   command.
    #
    # _run_start_response: keeps track what OnRunStartResponse the wrapper
    #   should return at the next run-start callback. If this information is
    #   unavailable (i.e., is None), the run-start CLI will be launched to ask
    #   the user. This is the case, e.g., right before the first run starts.
    self._active_tensor_filter = None
    self._active_filter_exclude_node_names = None
    self._active_tensor_filter_run_start_response = None
    self._run_through_times = 1
    self._skip_debug = False
    self._run_start_response = None
    self._is_run_start = True
    self._ui_type = ui_type
    self._config = None
    if config_file_path:
      self._config = cli_config.CLIConfig(config_file_path=config_file_path)
def _is_disk_usage_reset_each_run(self):
|
|
# The dumped tensors are all cleaned up after every Session.run
|
|
# in a command-line wrapper.
|
|
return True
|
|
|
|
def _initialize_argparsers(self):
|
|
self._argparsers = {}
|
|
ap = argparse.ArgumentParser(
|
|
description="Run through, with or without debug tensor watching.",
|
|
usage=argparse.SUPPRESS)
|
|
ap.add_argument(
|
|
"-t",
|
|
"--times",
|
|
dest="times",
|
|
type=int,
|
|
default=1,
|
|
help="How many Session.run() calls to proceed with.")
|
|
ap.add_argument(
|
|
"-n",
|
|
"--no_debug",
|
|
dest="no_debug",
|
|
action="store_true",
|
|
help="Run through without debug tensor watching.")
|
|
ap.add_argument(
|
|
"-f",
|
|
"--till_filter_pass",
|
|
dest="till_filter_pass",
|
|
type=str,
|
|
default="",
|
|
help="Run until a tensor in the graph passes the specified filter.")
|
|
ap.add_argument(
|
|
"-fenn",
|
|
"--filter_exclude_node_names",
|
|
dest="filter_exclude_node_names",
|
|
type=str,
|
|
default="",
|
|
help="When applying the tensor filter, exclude node with names "
|
|
"matching the regular expression. Applicable only if --tensor_filter "
|
|
"or -f is used.")
|
|
ap.add_argument(
|
|
"--node_name_filter",
|
|
dest="node_name_filter",
|
|
type=str,
|
|
default="",
|
|
help="Regular-expression filter for node names to be watched in the "
|
|
"run, e.g., loss, reshape.*")
|
|
ap.add_argument(
|
|
"--op_type_filter",
|
|
dest="op_type_filter",
|
|
type=str,
|
|
default="",
|
|
help="Regular-expression filter for op type to be watched in the run, "
|
|
"e.g., (MatMul|Add), Variable.*")
|
|
ap.add_argument(
|
|
"--tensor_dtype_filter",
|
|
dest="tensor_dtype_filter",
|
|
type=str,
|
|
default="",
|
|
help="Regular-expression filter for tensor dtype to be watched in the "
|
|
"run, e.g., (float32|float64), int.*")
|
|
ap.add_argument(
|
|
"-p",
|
|
"--profile",
|
|
dest="profile",
|
|
action="store_true",
|
|
help="Run and profile TensorFlow graph execution.")
|
|
self._argparsers["run"] = ap
|
|
|
|
ap = argparse.ArgumentParser(
|
|
description="Display information about this Session.run() call.",
|
|
usage=argparse.SUPPRESS)
|
|
self._argparsers["run_info"] = ap
|
|
|
|
self._argparsers["print_feed"] = command_parser.get_print_tensor_argparser(
|
|
"Print the value of a feed in feed_dict.")
|
|
|
|
def add_tensor_filter(self, filter_name, tensor_filter):
|
|
"""Add a tensor filter.
|
|
|
|
Args:
|
|
filter_name: (`str`) name of the filter.
|
|
tensor_filter: (`callable`) the filter callable. See the doc string of
|
|
`DebugDumpDir.find()` for more details about its signature.
|
|
"""
|
|
|
|
self._tensor_filters[filter_name] = tensor_filter
|
|
|
|
def on_session_init(self, request):
|
|
"""Overrides on-session-init callback.
|
|
|
|
Args:
|
|
request: An instance of `OnSessionInitRequest`.
|
|
|
|
Returns:
|
|
An instance of `OnSessionInitResponse`.
|
|
"""
|
|
|
|
return framework.OnSessionInitResponse(
|
|
framework.OnSessionInitAction.PROCEED)
|
|
|
|
  def on_run_start(self, request):
    """Overrides on-run-start callback.

    Decides, based on the wrapper's run-control state (-t/-n/-f flags from
    prior "run" commands), whether to launch the run-start CLI or to reuse a
    previously determined response.

    Args:
      request: An instance of `OnRunStartRequest`.

    Returns:
      An instance of `OnRunStartResponse`.
    """
    self._is_run_start = True
    # NOTE: this call also decrements self._run_through_times by one.
    self._update_run_calls_state(
        request.run_call_count, request.fetches, request.feed_dict,
        is_callable_runner=request.is_callable_runner)

    if self._active_tensor_filter:
      # If we are running until a filter passes, we just need to keep running
      # with the previous `OnRunStartResponse`.
      return self._active_tensor_filter_run_start_response

    self._exit_if_requested_by_user()

    if self._run_call_count > 1 and not self._skip_debug:
      if self._run_through_times > 0:
        # Just run through without debugging.
        return framework.OnRunStartResponse(
            framework.OnRunStartAction.NON_DEBUG_RUN, [])
      elif self._run_through_times == 0:
        # It is the run at which the run-end CLI will be launched: activate
        # debugging.
        return (self._run_start_response or
                framework.OnRunStartResponse(
                    framework.OnRunStartAction.DEBUG_RUN,
                    self._get_run_debug_urls()))

    if self._run_start_response is None:
      # No pre-determined response: ask the user via the run-start CLI.
      self._prep_cli_for_run_start()

      self._run_start_response = self._launch_cli()
      if self._active_tensor_filter:
        # "run -f" was just issued; remember this response so subsequent
        # runs reuse it until the filter passes.
        self._active_tensor_filter_run_start_response = self._run_start_response
      if self._run_through_times > 1:
        self._run_through_times -= 1

    self._exit_if_requested_by_user()
    return self._run_start_response
def _exit_if_requested_by_user(self):
|
|
if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
|
|
# Explicit user "exit" command leads to sys.exit(1).
|
|
print(
|
|
"Note: user exited from debugger CLI: Calling sys.exit(1).",
|
|
file=sys.stderr)
|
|
sys.exit(1)
|
|
|
|
def _prep_cli_for_run_start(self):
|
|
"""Prepare (but not launch) the CLI for run-start."""
|
|
self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)
|
|
|
|
help_intro = debugger_cli_common.RichTextLines([])
|
|
if self._run_call_count == 1:
|
|
# Show logo at the onset of the first run.
|
|
help_intro.extend(cli_shared.get_tfdbg_logo())
|
|
help_intro.extend(debugger_cli_common.get_tensorflow_version_lines())
|
|
help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
|
|
help_intro.extend(self._run_info)
|
|
|
|
self._run_cli.set_help_intro(help_intro)
|
|
|
|
# Create initial screen output detailing the run.
|
|
self._title = "run-start: " + self._run_description
|
|
self._init_command = "run_info"
|
|
self._title_color = "blue_on_white"
|
|
|
|
  def on_run_end(self, request):
    """Overrides on-run-end callback.

    Actions taken:
      1) Load the debug dump.
      2) Bring up the Analyzer CLI.

    Args:
      request: An instance of `OnRunEndRequest`.

    Returns:
      An instance of `OnRunEndResponse`.
    """

    self._is_run_start = False
    if request.performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # Prefer partition graphs from run metadata; fall back to the client
      # graph def when metadata is unavailable.
      partition_graphs = None
      if request.run_metadata and request.run_metadata.partition_graphs:
        partition_graphs = request.run_metadata.partition_graphs
      elif request.client_graph_def:
        partition_graphs = [request.client_graph_def]

      if request.tf_error and not os.path.isdir(self._dump_root):
        # It is possible that the dump root may not exist due to errors that
        # have occurred prior to graph execution (e.g., invalid device
        # assignments), in which case we will just raise the exception as the
        # unwrapped Session does.
        raise request.tf_error

      debug_dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=partition_graphs)
      debug_dump.set_python_graph(self._sess.graph)

      passed_filter = None
      passed_filter_exclude_node_names = None
      if self._active_tensor_filter:
        if not debug_dump.find(
            self._tensor_filters[self._active_tensor_filter], first_n=1,
            exclude_node_names=self._active_filter_exclude_node_names):
          # No dumped tensor passes the filter in this run. Clean up the dump
          # directory and move on.
          self._remove_dump_root()
          return framework.OnRunEndResponse()
        else:
          # Some dumped tensor(s) from this run passed the filter.
          passed_filter = self._active_tensor_filter
          passed_filter_exclude_node_names = (
              self._active_filter_exclude_node_names)
          # Deactivate run-till-filter-pass mode now that the filter passed.
          self._active_tensor_filter = None
          self._active_filter_exclude_node_names = None

      self._prep_debug_cli_for_run_end(
          debug_dump, request.tf_error, passed_filter,
          passed_filter_exclude_node_names)

      # The run-end CLI also collects the user's response for the *next*
      # run-start callback.
      self._run_start_response = self._launch_cli()

      # Clean up the dump generated by this run.
      self._remove_dump_root()
    elif request.performed_action == framework.OnRunStartAction.PROFILE_RUN:
      self._prep_profile_cli_for_run_end(self._sess.graph, request.run_metadata)
      self._run_start_response = self._launch_cli()
    else:
      # No debug information to show following a non-debug run() call.
      self._run_start_response = None

    # Return placeholder response that currently holds no additional
    # information.
    return framework.OnRunEndResponse()
def _remove_dump_root(self):
|
|
if os.path.isdir(self._dump_root):
|
|
file_io.delete_recursively(self._dump_root)
|
|
|
|
  def _prep_debug_cli_for_run_end(self,
                                  debug_dump,
                                  tf_error,
                                  passed_filter,
                                  passed_filter_exclude_node_names):
    """Prepare (but not launch) CLI for run-end, with debug dump from the run.

    Sets `self._run_cli`, `self._init_command`, `self._title` and
    `self._title_color` for the upcoming `_launch_cli()` call.

    Args:
      debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
        run.
      tf_error: (None or OpError) OpError that happened during the run() call
        (if any).
      passed_filter: (None or str) Name of the tensor filter that just passed
        and caused the preparation of this run-end CLI (if any).
      passed_filter_exclude_node_names: (None or str) Regular expression used
        with the tensor filter to exclude ops with names matching the regular
        expression.
    """

    if tf_error:
      # The run errored out: open with the error intro and a red title.
      help_intro = cli_shared.get_error_intro(tf_error)

      self._init_command = "help"
      self._title_color = "red_on_white"
    else:
      help_intro = None
      self._init_command = "lt"

      self._title_color = "black_on_white"
      if passed_filter is not None:
        # Some dumped tensor(s) from this run passed the filter.
        self._init_command = "lt -f %s" % passed_filter
        if passed_filter_exclude_node_names:
          self._init_command += (" --filter_exclude_node_names %s" %
                                 passed_filter_exclude_node_names)
        self._title_color = "red_on_white"

    self._run_cli = analyzer_cli.create_analyzer_ui(
        debug_dump,
        self._tensor_filters,
        ui_type=self._ui_type,
        on_ui_exit=self._remove_dump_root,
        config=self._config)

    # Get names of all dumped tensors.
    dumped_tensor_names = []
    for datum in debug_dump.dumped_tensor_data:
      dumped_tensor_names.append("%s:%d" %
                                 (datum.node_name, datum.output_slot))

    # Tab completions for command "print_tensors".
    self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
                                            dumped_tensor_names)

    # Tab completion for commands "node_info", "list_inputs" and
    # "list_outputs". The list comprehension is used below because nodes()
    # output can be unicodes and they need to be converted to strs.
    self._run_cli.register_tab_comp_context(
        ["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
        [str(node_name) for node_name in debug_dump.nodes()])
    # TODO(cais): Reduce API surface area for aliases vis-a-vis tab
    # completion contexts and registered command handlers.

    self._title = "run-end: " + self._run_description

    if help_intro:
      self._run_cli.set_help_intro(help_intro)
def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
|
|
self._init_command = "lp"
|
|
self._run_cli = profile_analyzer_cli.create_profiler_ui(
|
|
py_graph, run_metadata, ui_type=self._ui_type,
|
|
config=self._run_cli.config)
|
|
self._title = "run-end (profiler mode): " + self._run_description
|
|
|
|
def _launch_cli(self):
|
|
"""Launch the interactive command-line interface.
|
|
|
|
Returns:
|
|
The OnRunStartResponse specified by the user using the "run" command.
|
|
"""
|
|
|
|
self._register_this_run_info(self._run_cli)
|
|
response = self._run_cli.run_ui(
|
|
init_command=self._init_command,
|
|
title=self._title,
|
|
title_color=self._title_color)
|
|
|
|
return response
|
|
|
|
def _run_info_handler(self, args, screen_info=None):
|
|
output = debugger_cli_common.RichTextLines([])
|
|
|
|
if self._run_call_count == 1:
|
|
output.extend(cli_shared.get_tfdbg_logo())
|
|
output.extend(debugger_cli_common.get_tensorflow_version_lines())
|
|
output.extend(self._run_info)
|
|
|
|
if (not self._is_run_start and
|
|
debugger_cli_common.MAIN_MENU_KEY in output.annotations):
|
|
menu = output.annotations[debugger_cli_common.MAIN_MENU_KEY]
|
|
if "list_tensors" not in menu.captions():
|
|
menu.insert(
|
|
0, debugger_cli_common.MenuItem("list_tensors", "list_tensors"))
|
|
|
|
return output
|
|
|
|
def _print_feed_handler(self, args, screen_info=None):
|
|
np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
|
|
screen_info)
|
|
|
|
if not self._feed_dict:
|
|
return cli_shared.error(
|
|
"The feed_dict of the current run is None or empty.")
|
|
|
|
parsed = self._argparsers["print_feed"].parse_args(args)
|
|
tensor_name, tensor_slicing = (
|
|
command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
|
|
|
|
feed_key = None
|
|
feed_value = None
|
|
for key in self._feed_dict:
|
|
key_name = common.get_graph_element_name(key)
|
|
if key_name == tensor_name:
|
|
feed_key = key_name
|
|
feed_value = self._feed_dict[key]
|
|
break
|
|
|
|
if feed_key is None:
|
|
return cli_shared.error(
|
|
"The feed_dict of the current run does not contain the key %s" %
|
|
tensor_name)
|
|
else:
|
|
return cli_shared.format_tensor(
|
|
feed_value,
|
|
feed_key + " (feed)",
|
|
np_printoptions,
|
|
print_all=parsed.print_all,
|
|
tensor_slicing=tensor_slicing,
|
|
highlight_options=cli_shared.parse_ranges_highlight(parsed.ranges),
|
|
include_numeric_summary=parsed.numeric_summary)
|
|
|
|
def _run_handler(self, args, screen_info=None):
|
|
"""Command handler for "run" command during on-run-start."""
|
|
|
|
del screen_info # Currently unused.
|
|
|
|
parsed = self._argparsers["run"].parse_args(args)
|
|
parsed.node_name_filter = parsed.node_name_filter or None
|
|
parsed.op_type_filter = parsed.op_type_filter or None
|
|
parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None
|
|
|
|
if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
|
|
raise ValueError(
|
|
"The --filter_exclude_node_names (or -feon) flag is valid only if "
|
|
"the --till_filter_pass (or -f) flag is used.")
|
|
|
|
if parsed.profile:
|
|
raise debugger_cli_common.CommandLineExit(
|
|
exit_token=framework.OnRunStartResponse(
|
|
framework.OnRunStartAction.PROFILE_RUN, []))
|
|
|
|
self._skip_debug = parsed.no_debug
|
|
self._run_through_times = parsed.times
|
|
|
|
if parsed.times > 1 or parsed.no_debug:
|
|
# If requested -t times > 1, the very next run will be a non-debug run.
|
|
action = framework.OnRunStartAction.NON_DEBUG_RUN
|
|
debug_urls = []
|
|
else:
|
|
action = framework.OnRunStartAction.DEBUG_RUN
|
|
debug_urls = self._get_run_debug_urls()
|
|
run_start_response = framework.OnRunStartResponse(
|
|
action,
|
|
debug_urls,
|
|
node_name_regex_allowlist=parsed.node_name_filter,
|
|
op_type_regex_allowlist=parsed.op_type_filter,
|
|
tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)
|
|
|
|
if parsed.till_filter_pass:
|
|
# For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
|
|
# option to access the intermediate tensors, and set the corresponding
|
|
# state flag of the class itself to True.
|
|
if parsed.till_filter_pass in self._tensor_filters:
|
|
action = framework.OnRunStartAction.DEBUG_RUN
|
|
self._active_tensor_filter = parsed.till_filter_pass
|
|
self._active_filter_exclude_node_names = (
|
|
parsed.filter_exclude_node_names)
|
|
self._active_tensor_filter_run_start_response = run_start_response
|
|
else:
|
|
# Handle invalid filter name.
|
|
return debugger_cli_common.RichTextLines(
|
|
["ERROR: tensor filter \"%s\" does not exist." %
|
|
parsed.till_filter_pass])
|
|
|
|
# Raise CommandLineExit exception to cause the CLI to exit.
|
|
raise debugger_cli_common.CommandLineExit(exit_token=run_start_response)
|
|
|
|
def _register_this_run_info(self, curses_cli):
|
|
curses_cli.register_command_handler(
|
|
"run",
|
|
self._run_handler,
|
|
self._argparsers["run"].format_help(),
|
|
prefix_aliases=["r"])
|
|
curses_cli.register_command_handler(
|
|
"run_info",
|
|
self._run_info_handler,
|
|
self._argparsers["run_info"].format_help(),
|
|
prefix_aliases=["ri"])
|
|
curses_cli.register_command_handler(
|
|
"print_feed",
|
|
self._print_feed_handler,
|
|
self._argparsers["print_feed"].format_help(),
|
|
prefix_aliases=["pf"])
|
|
|
|
if self._tensor_filters:
|
|
# Register tab completion for the filter names.
|
|
curses_cli.register_tab_comp_context(["run", "r"],
|
|
list(self._tensor_filters.keys()))
|
|
if self._feed_dict and hasattr(self._feed_dict, "keys"):
|
|
# Register tab completion for feed_dict keys.
|
|
feed_keys = [common.get_graph_element_name(key)
|
|
for key in self._feed_dict.keys()]
|
|
curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)
|
|
|
|
def _get_run_debug_urls(self):
|
|
"""Get the debug_urls value for the current run() call.
|
|
|
|
Returns:
|
|
debug_urls: (list of str) Debug URLs for the current run() call.
|
|
Currently, the list consists of only one URL that is a file:// URL.
|
|
"""
|
|
|
|
return ["file://" + self._dump_root]
|
|
|
|
def _update_run_calls_state(self,
|
|
run_call_count,
|
|
fetches,
|
|
feed_dict,
|
|
is_callable_runner=False):
|
|
"""Update the internal state with regard to run() call history.
|
|
|
|
Args:
|
|
run_call_count: (int) Number of run() calls that have occurred.
|
|
fetches: a node/tensor or a list of node/tensor that are the fetches of
|
|
the run() call. This is the same as the fetches argument to the run()
|
|
call.
|
|
feed_dict: None of a dict. This is the feed_dict argument to the run()
|
|
call.
|
|
is_callable_runner: (bool) whether a runner returned by
|
|
Session.make_callable is being run.
|
|
"""
|
|
|
|
self._run_call_count = run_call_count
|
|
self._feed_dict = feed_dict
|
|
self._run_description = cli_shared.get_run_short_description(
|
|
run_call_count,
|
|
fetches,
|
|
feed_dict,
|
|
is_callable_runner=is_callable_runner)
|
|
self._run_through_times -= 1
|
|
|
|
self._run_info = cli_shared.get_run_start_intro(
|
|
run_call_count,
|
|
fetches,
|
|
feed_dict,
|
|
self._tensor_filters,
|
|
is_callable_runner=is_callable_runner)
|