tfdbg: initial check-in of stepper CLI

This check-in implements the following commands (and aliases) in the stepper CLI:

* cont (ct, c)
* step (st, s)
* list_sorted_nodes (lt)
* print_tensor (pt)
* inject_value (inject, override)

To activate the stepper CLI, enter command "step" or "s" at the run-start CLI.

Not implemented yet:

* Documentation in g3doc/how_tos/debugger will be submitted in a later CL.
* Commands for inspecting node information and graph structure in the stepper CLI, including ni, li and lo.
Change: 141518037
This commit is contained in:
Shanqing Cai 2016-12-08 20:47:26 -08:00 committed by TensorFlower Gardener
parent 4a29f9f6a0
commit f37c468252
14 changed files with 1405 additions and 145 deletions

View File

@ -48,6 +48,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":debug_data", ":debug_data",
"//tensorflow/python:data_flow_ops",
], ],
) )
@ -57,6 +58,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":debug_utils", ":debug_utils",
":stepper",
"//tensorflow/python:session", "//tensorflow/python:session",
], ],
) )
@ -85,7 +87,9 @@ py_library(
srcs = ["cli/cli_shared.py"], srcs = ["cli/cli_shared.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":command_parser",
":debugger_cli_common", ":debugger_cli_common",
":tensor_format",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:variables", "//tensorflow/python:variables",
], ],
@ -96,6 +100,7 @@ py_library(
srcs = ["cli/analyzer_cli.py"], srcs = ["cli/analyzer_cli.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":cli_shared",
":command_parser", ":command_parser",
":debug_data", ":debug_data",
":debugger_cli_common", ":debugger_cli_common",
@ -103,6 +108,19 @@ py_library(
], ],
) )
py_library(
name = "stepper_cli",
srcs = ["cli/stepper_cli.py"],
srcs_version = "PY2AND3",
deps = [
":cli_shared",
":command_parser",
":debugger_cli_common",
":stepper",
":tensor_format",
],
)
py_library( py_library(
name = "curses_ui", name = "curses_ui",
srcs = ["cli/curses_ui.py"], srcs = ["cli/curses_ui.py"],
@ -125,6 +143,7 @@ py_library(
":debug_data", ":debug_data",
":debugger_cli_common", ":debugger_cli_common",
":framework", ":framework",
":stepper_cli",
"//tensorflow/python:session", "//tensorflow/python:session",
], ],
) )
@ -135,6 +154,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":local_cli_wrapper", ":local_cli_wrapper",
":stepper",
"//tensorflow/python:session", "//tensorflow/python:session",
], ],
) )
@ -223,6 +243,7 @@ py_test(
deps = [ deps = [
":debug_data", ":debug_data",
":framework", ":framework",
":stepper",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
@ -341,6 +362,22 @@ cuda_py_test(
], ],
) )
cuda_py_test(
name = "stepper_cli_test",
size = "small",
srcs = [
"cli/stepper_cli_test.py",
],
additional_deps = [
":stepper",
":stepper_cli",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:session",
],
)
py_test( py_test(
name = "local_cli_wrapper_test", name = "local_cli_wrapper_test",
size = "small", size = "small",

View File

@ -27,13 +27,12 @@ import argparse
import copy import copy
import re import re
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.debug import debug_data from tensorflow.python.debug import debug_data
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format
# String constants for the depth-dependent hanging indent at the beginning # String constants for the depth-dependent hanging indent at the beginning
@ -69,10 +68,6 @@ class DebugAnalyzer(object):
# Argument parsers for command handlers. # Argument parsers for command handlers.
self._arg_parsers = {} self._arg_parsers = {}
# Default threshold number of elements above which ellipses will be used
# when printing the value of the tensor.
self.default_ndarray_display_threshold = 2000
# Parser for list_tensors. # Parser for list_tensors.
ap = argparse.ArgumentParser( ap = argparse.ArgumentParser(
description="List dumped intermediate tensors.", description="List dumped intermediate tensors.",
@ -222,11 +217,6 @@ class DebugAnalyzer(object):
# TODO(cais): Implement list_nodes. # TODO(cais): Implement list_nodes.
def _error(self, msg):
full_msg = "ERROR: " + msg
return debugger_cli_common.RichTextLines(
[full_msg], font_attr_segs={0: [(0, len(full_msg), "red")]})
def add_tensor_filter(self, filter_name, filter_callable): def add_tensor_filter(self, filter_name, filter_callable):
"""Add a tensor filter. """Add a tensor filter.
@ -329,7 +319,7 @@ class DebugAnalyzer(object):
try: try:
filter_callable = self.get_tensor_filter(parsed.tensor_filter) filter_callable = self.get_tensor_filter(parsed.tensor_filter)
except ValueError: except ValueError:
return self._error( return cli_shared.error(
"There is no tensor filter named \"%s\"." % parsed.tensor_filter) "There is no tensor filter named \"%s\"." % parsed.tensor_filter)
data_to_show = self._debug_dump.find(filter_callable) data_to_show = self._debug_dump.find(filter_callable)
@ -392,7 +382,7 @@ class DebugAnalyzer(object):
parsed.node_name) parsed.node_name)
if not self._debug_dump.node_exists(node_name): if not self._debug_dump.node_exists(node_name):
return self._error( return cli_shared.error(
"There is no node named \"%s\" in the partition graphs" % node_name) "There is no node named \"%s\" in the partition graphs" % node_name)
# TODO(cais): Provide UI glossary feature to explain to users what the # TODO(cais): Provide UI glossary feature to explain to users what the
@ -487,24 +477,19 @@ class DebugAnalyzer(object):
np_printoptions = {} np_printoptions = {}
# Determine if any range-highlighting is required. # Determine if any range-highlighting is required.
highlight_options = self._parse_ranges_highlight(parsed.ranges) highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)
# Determine if there parsed.tensor_name contains any indexing (slicing). tensor_name, tensor_slicing = (
if parsed.tensor_name.count("[") == 1 and parsed.tensor_name.endswith("]"): command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
tensor_name = parsed.tensor_name[:parsed.tensor_name.index("[")]
tensor_slicing = parsed.tensor_name[parsed.tensor_name.index("["):]
else:
tensor_name = parsed.tensor_name
tensor_slicing = ""
node_name, output_slot = debug_data.parse_node_or_tensor_name(tensor_name) node_name, output_slot = debug_data.parse_node_or_tensor_name(tensor_name)
if output_slot is None: if output_slot is None:
return self._error("\"%s\" is not a valid tensor name" % return cli_shared.error("\"%s\" is not a valid tensor name" %
parsed.tensor_name) parsed.tensor_name)
if (self._debug_dump.loaded_partition_graphs and if (self._debug_dump.loaded_partition_graphs and
not self._debug_dump.node_exists(node_name)): not self._debug_dump.node_exists(node_name)):
return self._error( return cli_shared.error(
"Node \"%s\" does not exist in partition graphs" % node_name) "Node \"%s\" does not exist in partition graphs" % node_name)
watch_keys = self._debug_dump.debug_watch_keys(node_name) watch_keys = self._debug_dump.debug_watch_keys(node_name)
@ -520,12 +505,12 @@ class DebugAnalyzer(object):
if not matching_data: if not matching_data:
# No dump for this tensor. # No dump for this tensor.
return self._error( return cli_shared.error(
"Tensor \"%s\" did not generate any dumps." % parsed.tensor_name) "Tensor \"%s\" did not generate any dumps." % parsed.tensor_name)
elif len(matching_data) == 1: elif len(matching_data) == 1:
# There is only one dump for this tensor. # There is only one dump for this tensor.
if parsed.number <= 0: if parsed.number <= 0:
return self._format_tensor( return cli_shared.format_tensor(
matching_data[0].get_tensor(), matching_data[0].get_tensor(),
matching_data[0].watch_key, matching_data[0].watch_key,
np_printoptions, np_printoptions,
@ -533,7 +518,7 @@ class DebugAnalyzer(object):
tensor_slicing=tensor_slicing, tensor_slicing=tensor_slicing,
highlight_options=highlight_options) highlight_options=highlight_options)
else: else:
return self._error( return cli_shared.error(
"Invalid number (%d) for tensor %s, which generated one dump." % "Invalid number (%d) for tensor %s, which generated one dump." %
(parsed.number, parsed.tensor_name)) (parsed.number, parsed.tensor_name))
else: else:
@ -556,12 +541,12 @@ class DebugAnalyzer(object):
return debugger_cli_common.RichTextLines(lines) return debugger_cli_common.RichTextLines(lines)
elif parsed.number >= len(matching_data): elif parsed.number >= len(matching_data):
return self._error( return cli_shared.error(
"Specified number (%d) exceeds the number of available dumps " "Specified number (%d) exceeds the number of available dumps "
"(%d) for tensor %s" % "(%d) for tensor %s" %
(parsed.number, len(matching_data), parsed.tensor_name)) (parsed.number, len(matching_data), parsed.tensor_name))
else: else:
return self._format_tensor( return cli_shared.format_tensor(
matching_data[parsed.number].get_tensor(), matching_data[parsed.number].get_tensor(),
matching_data[parsed.number].watch_key + " (dump #%d)" % matching_data[parsed.number].watch_key + " (dump #%d)" %
parsed.number, parsed.number,
@ -570,90 +555,6 @@ class DebugAnalyzer(object):
tensor_slicing=tensor_slicing, tensor_slicing=tensor_slicing,
highlight_options=highlight_options) highlight_options=highlight_options)
def _parse_ranges_highlight(self, ranges_string):
"""Process ranges highlight string.
Args:
ranges_string: (str) A string representing a numerical range of a list of
numerical ranges. See the help info of the -r flag of the print_tensor
command for more details.
Returns:
An instance of tensor_format.HighlightOptions, if range_string is a valid
representation of a range or a list of ranges.
"""
ranges = None
def ranges_filter(x):
r = np.zeros(x.shape, dtype=bool)
for rng_start, rng_end in ranges:
r = np.logical_or(r, np.logical_and(x >= rng_start, x <= rng_end))
return r
if ranges_string:
ranges = command_parser.parse_ranges(ranges_string)
return tensor_format.HighlightOptions(
ranges_filter, description=ranges_string)
else:
return None
def _format_tensor(self,
tensor,
watch_key,
np_printoptions,
print_all=False,
tensor_slicing=None,
highlight_options=None):
"""Generate formatted str to represent a tensor or its slices.
Args:
tensor: (numpy ndarray) The tensor value.
watch_key: (str) Tensor debug watch key.
np_printoptions: (dict) Numpy tensor formatting options.
print_all: (bool) Whether the tensor is to be displayed in its entirety,
instead of printing ellipses, even if its number of elements exceeds
the default numpy display threshold.
(Note: Even if this is set to true, the screen output can still be cut
off by the UI frontend if it consist of more lines than the frontend
can handle.)
tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
None, no slicing will be performed on the tensor.
highlight_options: (tensor_format.HighlightOptions) options to highlight
elements of the tensor. See the doc of tensor_format.format_tensor()
for more details.
Returns:
(str) Formatted str representing the (potentially sliced) tensor.
Raises:
ValueError: If tehsor_slicing is not a valid numpy ndarray slicing str.
"""
if tensor_slicing:
# Validate the indexing.
if not command_parser.validate_slicing_string(tensor_slicing):
raise ValueError("Invalid tensor-slicing string.")
value = eval("tensor" + tensor_slicing) # pylint: disable=eval-used
sliced_name = watch_key + tensor_slicing
else:
value = tensor
sliced_name = watch_key
if print_all:
np_printoptions["threshold"] = value.size
else:
np_printoptions["threshold"] = self.default_ndarray_display_threshold
return tensor_format.format_tensor(
value,
sliced_name,
include_metadata=True,
np_printoptions=np_printoptions,
highlight_options=highlight_options)
def list_outputs(self, args, screen_info=None): def list_outputs(self, args, screen_info=None):
"""Command handler for inputs. """Command handler for inputs.
@ -729,7 +630,7 @@ class DebugAnalyzer(object):
# Check if node exists. # Check if node exists.
if not self._debug_dump.node_exists(node_name): if not self._debug_dump.node_exists(node_name):
return self._error( return cli_shared.error(
"There is no node named \"%s\" in the partition graphs" % node_name) "There is no node named \"%s\" in the partition graphs" % node_name)
if recursive: if recursive:

View File

@ -17,13 +17,116 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
import six import six
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
# Default threshold number of elements above which ellipses will be used
# when printing the value of the tensor.
DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000


def parse_ranges_highlight(ranges_string):
  """Process a ranges-highlight string into a HighlightOptions object.

  Args:
    ranges_string: (str) A string representing a numerical range or a list of
      numerical ranges. See the help info of the -r flag of the print_tensor
      command for more details.

  Returns:
    An instance of tensor_format.HighlightOptions, if ranges_string is a valid
      representation of a range or a list of ranges. None if ranges_string is
      empty.
  """

  ranges = None

  def ranges_filter(x):
    # Boolean mask over ndarray x: True wherever an element falls inside any
    # of the (inclusive) [range_start, range_end] intervals.
    r = np.zeros(x.shape, dtype=bool)
    for range_start, range_end in ranges:
      r = np.logical_or(r, np.logical_and(x >= range_start, x <= range_end))
    return r

  if ranges_string:
    ranges = command_parser.parse_ranges(ranges_string)
    return tensor_format.HighlightOptions(
        ranges_filter, description=ranges_string)
  else:
    return None
def format_tensor(tensor,
                  tensor_name,
                  np_printoptions,
                  print_all=False,
                  tensor_slicing=None,
                  highlight_options=None):
  """Generate formatted str to represent a tensor or its slices.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
    np_printoptions: (dict) Numpy tensor formatting options. Note: this dict
      is mutated in place -- the "threshold" key is set by this function.
    print_all: (bool) Whether the tensor is to be displayed in its entirety,
      instead of printing ellipses, even if its number of elements exceeds
      the default numpy display threshold.
      (Note: Even if this is set to true, the screen output can still be cut
      off by the UI frontend if it consists of more lines than the frontend
      can handle.)
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.
    highlight_options: (tensor_format.HighlightOptions) options to highlight
      elements of the tensor. See the doc of tensor_format.format_tensor()
      for more details.

  Returns:
    (str) Formatted str representing the (potentially sliced) tensor.
  """

  if tensor_slicing:
    # Validate the indexing; raises ValueError on an invalid slicing string.
    value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
    sliced_name = tensor_name + tensor_slicing
  else:
    value = tensor
    sliced_name = tensor_name

  if print_all:
    # Raise numpy's summarization threshold to the full element count of the
    # (sliced) value so that no ellipses are inserted.
    np_printoptions["threshold"] = value.size
  else:
    np_printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD

  return tensor_format.format_tensor(
      value,
      sliced_name,
      include_metadata=True,
      np_printoptions=np_printoptions,
      highlight_options=highlight_options)
def error(msg):
  """Generate a RichTextLines output for error.

  The message is prefixed with "ERROR: " and rendered in red.

  Args:
    msg: (str) The error message.

  Returns:
    (debugger_cli_common.RichTextLines) A representation of the error message
      for screen output.
  """

  text = "ERROR: %s" % msg
  red_segment = (0, len(text), "red")
  return debugger_cli_common.RichTextLines(
      [text], font_attr_segs={0: [red_segment]})
def _get_fetch_name(fetch): def _get_fetch_name(fetch):
"""Obtain the name or string representation of a fetch. """Obtain the name or string representation of a fetch.
@ -159,6 +262,12 @@ def get_run_start_intro(run_call_count,
"run -f <filter_name>", "run -f <filter_name>",
"Keep executing run() calls until a dumped tensor passes a given, " "Keep executing run() calls until a dumped tensor passes a given, "
"registered filter (conditional breakpoint mode)")) "registered filter (conditional breakpoint mode)"))
out.extend(
_recommend_command(
"invoke_stepper",
"Use the node-stepper interface, which allows you to interactively "
"step through nodes involved in the graph run() call and "
"inspect/modify their values"))
more_font_attr_segs = {} more_font_attr_segs = {}
more_lines = [" Registered filter(s):"] more_lines = [" Registered filter(s):"]

View File

@ -188,8 +188,8 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
# Verify the listed names of the tensor filters. # Verify the listed names of the tensor filters.
filter_names = set() filter_names = set()
filter_names.add(run_start_intro.lines[20].split(" ")[-1]) filter_names.add(run_start_intro.lines[22].split(" ")[-1])
filter_names.add(run_start_intro.lines[21].split(" ")[-1]) filter_names.add(run_start_intro.lines[23].split(" ")[-1])
self.assertEqual({"filter_a", "filter_b"}, filter_names) self.assertEqual({"filter_a", "filter_b"}, filter_names)

View File

@ -176,3 +176,26 @@ def parse_ranges(range_string):
type(item[0])) type(item[0]))
return ranges return ranges
def evaluate_tensor_slice(tensor, tensor_slicing):
  """Call eval on the slicing of a tensor, with validation.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.

  Returns:
    (numpy ndarray) The sliced tensor.

  Raises:
    ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
  """

  # "tensor" is referenced only inside the eval expression below; this
  # assignment keeps lint tools from flagging the argument as unused.
  _ = tensor

  if not validate_slicing_string(tensor_slicing):
    raise ValueError("Invalid tensor-slicing string.")

  # NOTE(review): eval on the slicing string; validate_slicing_string() above
  # is relied upon to reject anything other than a bracketed indexing
  # expression.
  return eval("tensor" + tensor_slicing)  # pylint: disable=eval-used

View File

@ -62,6 +62,7 @@ class CursesUI(object):
"green": curses.COLOR_GREEN, "green": curses.COLOR_GREEN,
"yellow": curses.COLOR_YELLOW, "yellow": curses.COLOR_YELLOW,
"blue": curses.COLOR_BLUE, "blue": curses.COLOR_BLUE,
"cyan": curses.COLOR_CYAN,
"magenta": curses.COLOR_MAGENTA, "magenta": curses.COLOR_MAGENTA,
"black": curses.COLOR_BLACK, "black": curses.COLOR_BLACK,
} }
@ -76,7 +77,13 @@ class CursesUI(object):
_ERROR_TOAST_COLOR_PAIR = "red_on_white" _ERROR_TOAST_COLOR_PAIR = "red_on_white"
_STATUS_BAR_COLOR_PAIR = "black_on_white" _STATUS_BAR_COLOR_PAIR = "black_on_white"
def __init__(self): def __init__(self, on_ui_exit=None):
"""Constructor of CursesUI.
Args:
on_ui_exit: (Callable) Callback invoked when the UI exits.
"""
self._screen_init() self._screen_init()
self._screen_refresh_size() self._screen_refresh_size()
# TODO(cais): Error out if the size of the screen is too small. # TODO(cais): Error out if the size of the screen is too small.
@ -126,6 +133,9 @@ class CursesUI(object):
# Register signal handler for SIGINT. # Register signal handler for SIGINT.
signal.signal(signal.SIGINT, self._interrupt_handler) signal.signal(signal.SIGINT, self._interrupt_handler)
# Configurable callbacks.
self._on_ui_exit = on_ui_exit
def _init_layout(self): def _init_layout(self):
"""Initialize the layout of UI components. """Initialize the layout of UI components.
@ -272,6 +282,9 @@ class CursesUI(object):
# CLI main loop. # CLI main loop.
exit_token = self._ui_loop() exit_token = self._ui_loop()
if self._on_ui_exit:
self._on_ui_exit()
self._screen_terminate() self._screen_terminate()
return exit_token return exit_token
@ -1031,7 +1044,8 @@ class CursesUI(object):
# Examine whether the index information is available for the specified line # Examine whether the index information is available for the specified line
# number. # number.
pointer = self._output_pad_row + line_index pointer = self._output_pad_row + line_index
if pointer in self._curr_wrapped_output.annotations: if (pointer in self._curr_wrapped_output.annotations and
"i0" in self._curr_wrapped_output.annotations[pointer]):
indices = self._curr_wrapped_output.annotations[pointer]["i0"] indices = self._curr_wrapped_output.annotations[pointer]["i0"]
array_indices_str = self._format_indices(indices) array_indices_str = self._format_indices(indices)

View File

@ -188,6 +188,10 @@ class RichTextLines(object):
else: else:
self._annotations[key] = other.annotations[key] self._annotations[key] = other.annotations[key]
# TODO(cais): Add method append of the signature:
# def append_line(line, line_font_attr_segs)
# and refactor usage in stepper_cli.py.
def regex_find(orig_screen_output, regex, font_attr): def regex_find(orig_screen_output, regex, font_attr):
"""Perform regex match in rich text lines. """Perform regex match in rich text lines.

View File

@ -0,0 +1,593 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CLI Backend for the Node Stepper Part of the Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import numpy as np # pylint: disable=unused-import
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.debug import stepper
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format
class NodeStepperCLI(object):
  """Command-line-interface backend of Node Stepper."""

  # Possible states of an element in the transitive closure of the stepper's
  # fetch(es). Each state is a single-character label displayed in the
  # stepper CLI's node listing.

  # State where the element is already continued-to and a TensorHandle is
  # available for the tensor.
  STATE_CONT = "H"

  # State where an intermediate dump of the tensor is available.
  STATE_INTERMEDIATE = "I"

  # State where the element is already overridden.
  STATE_OVERRIDDEN = "O"

  # State where the element is a placeholder (and hence cannot be continued
  # to).
  STATE_IS_PLACEHOLDER = "P"

  # State where a variable's value has been updated during the lifetime of
  # this NodeStepperCLI instance.
  STATE_DIRTY_VARIABLE = "D"

  # Text marker that points at the next node to be stepped to in listings.
  NEXT_NODE_POINTER_STR = "-->"

  # Templates for error messages shared by multiple command handlers.
  _MESSAGE_TEMPLATES = {
      "NOT_IN_CLOSURE":
          "%s is not in the transitive closure of this stepper instance.",
      "MULTIPLE_TENSORS":
          "Node %s has more than one output tensor. "
          "Please use full tensor name.",
  }
def __init__(self, node_stepper):
self._node_stepper = node_stepper
# Command parsers for the stepper.
self.arg_parsers = {}
# Parser for "list_sorted_nodes".
ap = argparse.ArgumentParser(
description="List the state of the sorted transitive closure of the "
"stepper.",
usage=argparse.SUPPRESS)
ap.add_argument(
"-l",
"--lower_bound",
dest="lower_bound",
type=int,
default=-1,
help="Lower-bound index (0-based)")
ap.add_argument(
"-u",
"--upper_bound",
dest="upper_bound",
type=int,
default=-1,
help="Upper-bound index (0-based)")
self.arg_parsers["list_sorted_nodes"] = ap
# Parser for "cont".
ap = argparse.ArgumentParser(
description="Continue to a tensor or op.", usage=argparse.SUPPRESS)
ap.add_argument(
"target_name",
type=str,
help="Name of the Tensor or Op to continue to.")
ap.add_argument(
"-r",
"--restore_variable_values",
dest="restore_variable_values",
action="store_true",
help="Restore all variables in the transitive closure of the cont "
"target to their initial values (i.e., values when this stepper "
"instance was created.")
self.arg_parsers["cont"] = ap
# Parser for "step".
ap = argparse.ArgumentParser(
description="Step to the next tensor or op in the sorted transitive "
"closure of the stepper's fetch(es).",
usage=argparse.SUPPRESS)
ap.add_argument(
"-t",
"--num_times",
dest="num_times",
type=int,
default=1,
help="Number of times to step (>=1)")
self.arg_parsers["step"] = ap
# Parser for "print_tensor".
ap = argparse.ArgumentParser(
description="Print the value of a tensor, from cached TensorHandle or "
"client-provided overrides.",
usage=argparse.SUPPRESS)
ap.add_argument(
"tensor_name",
type=str,
help="Name of the tensor, followed by any slicing indices, "
"e.g., hidden1/Wx_plus_b/MatMul:0, "
"hidden1/Wx_plus_b/MatMul:0[1, :]")
ap.add_argument(
"-r",
"--ranges",
dest="ranges",
type=str,
default="",
help="Numerical ranges to highlight tensor elements in. "
"Examples: -r 0,1e-8, -r [-0.1,0.1], "
"-r \"[[-inf, -0.1], [0.1, inf]]\"")
ap.add_argument(
"-a",
"--all",
dest="print_all",
action="store_true",
help="Print the tensor in its entirety, i.e., do not use ellipses.")
self.arg_parsers["print_tensor"] = ap
# Parser for inject_value.
ap = argparse.ArgumentParser(
description="Inject (override) the value of a Tensor.",
usage=argparse.SUPPRESS)
ap.add_argument(
"tensor_name",
type=str,
help="Name of the Tensor of which the value is to be overridden.")
ap.add_argument(
"tensor_value_str",
type=str,
help="A string representing the value of the tensor, without any "
"whitespaces, e.g., np.zeros([10,100])")
self.arg_parsers["inject_value"] = ap
self._initialize_state()
def _initialize_state(self):
"""Initialize the state of this stepper CLI."""
# Get the elements in the sorted transitive closure, as a list of str.
self._sorted_nodes = self._node_stepper.sorted_nodes()
self._closure_elements = self._node_stepper.closure_elements()
self._placeholders = self._node_stepper.placeholders()
self._completed_nodes = set()
self._calculate_next()
def _calculate_next(self):
"""Calculate the next target for "step" action based on current state."""
override_names = self._node_stepper.override_names()
next_i = -1
for i in xrange(len(self._sorted_nodes)):
if (i > next_i and (self._sorted_nodes[i] in self._completed_nodes) or
(self._sorted_nodes[i] in override_names)):
next_i = i
next_i += 1
self._next = next_i
  def list_sorted_nodes(self, args, screen_info=None):
    """List the sorted transitive closure of the stepper's fetches.

    Args:
      args: (list of str) Command-line arguments for the "list_sorted_nodes"
        argument parser (supports -l/--lower_bound and -u/--upper_bound).
      screen_info: Information about the screen; currently unused.

    Returns:
      (debugger_cli_common.RichTextLines) One line per closure element in the
        selected index range, each carrying an index prefix, status labels
        (see _get_status_labels) and the element name. In verbose mode (when
        both bounds are not given), a header line and a status-label legend
        are included.
    """

    # TODO(cais): Use pattern such as del args, del screen_info python/debug.
    _ = args
    _ = screen_info

    parsed = self.arg_parsers["list_sorted_nodes"].parse_args(args)

    if parsed.lower_bound != -1 and parsed.upper_bound != -1:
      # Both bounds given: display only the [lower, upper) window, without
      # the header and the legend.
      index_range = [
          max(0, parsed.lower_bound),
          min(len(self._sorted_nodes), parsed.upper_bound)
      ]
      verbose = False
    else:
      index_range = [0, len(self._sorted_nodes)]
      verbose = True

    handle_node_names = self._node_stepper.handle_node_names()
    override_names = self._node_stepper.override_names()
    # Dirty variables are reported as tensor names (e.g., "v:0"); strip the
    # output-slot suffix to obtain node names for the label lookup.
    dirty_variable_names = [
        dirty_variable.split(":")[0]
        for dirty_variable in self._node_stepper.dirty_variables()
    ]

    lines = []
    font_attr_segs = {}
    if verbose:
      lines.extend(
          ["Topologically-sorted transitive input(s) and fetch(es):", ""])

    # line_counter tracks the output line index so that font attribute
    # segments stay aligned with the lines actually appended (skipped
    # indices do not advance it).
    line_counter = len(lines)
    for i, element_name in enumerate(self._sorted_nodes):
      if i < index_range[0] or i >= index_range[1]:
        continue

      font_attr_segs[line_counter] = []

      # TODO(cais): Use fixed-width text to show node index.
      node_prefix = "(%d / %d)" % (i + 1, len(self._sorted_nodes))
      if i == self._next:
        # Mark the next node to be stepped to with the bold "-->" pointer.
        # NOTE(review): the (0, 3) "bold" segment is presumably meant to
        # cover the pointer; verify the literal padding widths below against
        # the rendered output.
        node_prefix = " " + self.NEXT_NODE_POINTER_STR + node_prefix
        font_attr_segs[line_counter].append((0, 3, "bold"))
      else:
        node_prefix = " " + node_prefix

      node_prefix += " ["
      labels, label_font_attr_segs = self._get_status_labels(
          element_name,
          handle_node_names,
          override_names,
          dirty_variable_names,
          len(node_prefix))
      node_prefix += labels
      font_attr_segs[line_counter].extend(label_font_attr_segs)

      lines.append(node_prefix + "] " + element_name)
      line_counter += 1

    output = debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)

    if verbose:
      output.extend(self._node_status_label_legend())

    return output
def _get_status_labels(self,
element_name,
handle_node_names,
override_names,
dirty_variable_names,
offset):
"""Get a string of status labels for a graph element.
A status label indicates that a node has a certain state in this
node-stepper CLI invocation. For example, 1) that the node has been
continued-to and a handle to its output tensor is available to the node
stepper; 2) the node is a Variable and its value has been altered, e.g.,
by continuing to a variable-updating node, since the beginning of this
node-stepper invocation (i.e., "dirty variable").
Args:
element_name: (str) name of the graph element.
handle_node_names: (list of str) Names of the nodes of which the output
tensors' handles are available.
override_names: (list of str) Names of the tensors of which the values
are overridden.
dirty_variable_names: (list of str) Names of the dirty variables.
offset: (int) Initial offset of the font attribute segments.
Returns:
(str) The string made of status labels that currently apply to the graph
element.
(list of tuples) The font attribute segments, with offset applied.
"""
stat_string = ""
font_attr_segs = []
position = offset
node_name = element_name.split(":")[0]
if node_name in self._placeholders:
stat_string += "P"
font_attr_segs.append((position, position + 1, "cyan"))
else:
stat_string += " "
position += 1
if self._node_stepper.is_feedable(str(element_name)):
stat_string += " "
else:
stat_string += "U"
font_attr_segs.append((position, position + 1, "red"))
position += 1
if element_name in handle_node_names:
stat_string += "H"
font_attr_segs.append((position, position + 1, "green"))
else:
stat_string += " "
position += 1
slots = self._node_stepper.output_slots_in_closure(element_name)
has_override = False
for slot in slots:
if element_name + ":%d" % slot in override_names:
has_override = True
break
if has_override:
stat_string += "O"
font_attr_segs.append((position, position + 1, "yellow"))
else:
stat_string += " "
position += 1
if element_name in dirty_variable_names:
stat_string += self.STATE_DIRTY_VARIABLE
font_attr_segs.append((position, position + 1, "magenta"))
else:
stat_string += " "
position += 1
return stat_string, font_attr_segs
def _node_status_label_legend(self):
"""Get legend for node-status labels.
Returns:
(debugger_cli_common.RichTextLines) Legend text.
"""
lines = []
font_attr_segs = {}
line_counter = 0
lines.append("")
lines.append("Legend:")
line_counter += 2
lines.append(" P - Placeholder")
font_attr_segs[line_counter] = [(2, 3, "cyan")]
line_counter += 1
lines.append(" U - Unfeedable")
font_attr_segs[line_counter] = [(2, 3, "red")]
line_counter += 1
lines.append(
" H - Already continued-to; Tensor handle available from output "
"slot(s)"
)
font_attr_segs[line_counter] = [(2, 3, "green")]
line_counter += 1
lines.append(" O - Has overriding (injected) tensor value")
font_attr_segs[line_counter] = [(2, 3, "yellow")]
line_counter += 1
lines.append(
" D - Dirty variable: Variable already updated this node stepper.")
font_attr_segs[line_counter] = [(2, 3, "magenta")]
line_counter += 1
return debugger_cli_common.RichTextLines(
lines, font_attr_segs=font_attr_segs)
def cont(self, args, screen_info=None):
  """Continue-to action on the graph.

  Runs the underlying NodeStepper forward to the target node/tensor, then
  renders: the feeds the stepper used, the value of the continued-to tensor,
  and a windowed view of the sorted transitive closure around the new
  position.

  Args:
    args: (list of str) command-line arguments for the "cont" command.
    screen_info: Information about the screen. Unused by this command.

  Returns:
    (RichTextLines) Screen output for the result of the continue-to action.
  """
  _ = screen_info

  parsed = self.arg_parsers["cont"].parse_args(args)

  # Determine which node is being continued to, so the _next pointer can be
  # set properly.
  node_name = parsed.target_name.split(":")[0]
  if node_name not in self._sorted_nodes:
    return cli_shared.error(self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] %
                            parsed.target_name)
  self._next = self._sorted_nodes.index(node_name)

  cont_result = self._node_stepper.cont(
      parsed.target_name,
      restore_variable_values=parsed.restore_variable_values)
  self._completed_nodes.add(parsed.target_name.split(":")[0])

  feed_types = self._node_stepper.last_feed_types()

  lines = ["Continued to %s:" % parsed.target_name, ""]
  font_attr_segs = {}

  lines.append("Stepper used feeds:")
  # line_counter is the 0-based index of the next line to be appended, so
  # color segments can be attached to the correct output line.
  line_counter = len(lines)
  if feed_types:
    for feed_name in feed_types:
      feed_info_line = " %s : %s" % (feed_name, feed_types[feed_name])
      lines.append(feed_info_line)
      # Color by how the feed was obtained: green for a cached tensor
      # handle, yellow for an injected (override) value.
      if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE:
        font_attr_segs[line_counter] = [
            (len(feed_name) + 2, len(feed_info_line), "green")
        ]
      elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE:
        font_attr_segs[line_counter] = [
            (len(feed_name) + 2, len(feed_info_line), "yellow")
        ]
      line_counter += 1
  else:
    lines.append(" (No feeds)")
  lines.append("")

  screen_output = debugger_cli_common.RichTextLines(
      lines, font_attr_segs=font_attr_segs)

  tensor_output = tensor_format.format_tensor(
      cont_result, parsed.target_name,
      include_metadata=True)
  screen_output.extend(tensor_output)

  # Generate windowed view of the sorted transitive closure on which the
  # stepping is occurring.
  lower_bound = max(0, self._next - 2)
  upper_bound = min(len(self._sorted_nodes), self._next + 3)

  final_output = self.list_sorted_nodes(
      ["-l", str(lower_bound), "-u", str(upper_bound)])
  final_output.extend(debugger_cli_common.RichTextLines([""]))
  final_output.extend(screen_output)

  # Re-calculate the target of the next "step" action.
  self._calculate_next()

  return final_output
def step(self, args, screen_info=None):
  """Step once.

  Args:
    args: (list of str) command-line arguments for the "step" command.
    screen_info: Information about screen. Forwarded to cont().

  Returns:
    (RichTextLines) Screen output for the result of the stepping action:
      either the output of the final cont() call, or an error message for an
      invalid step count / stepping past the end of the sorted transitive
      closure.
  """

  parsed = self.arg_parsers["step"].parse_args(args)

  # Bug fix: reject non-positive counts, not just negative ones. With the
  # previous "< 0" check, "step -t 0" skipped the loop below entirely and
  # then referenced the unbound local screen_output (UnboundLocalError).
  if parsed.num_times < 1:
    return debugger_cli_common.RichTextLines(
        "ERROR: Invalid number of times to step: %d" % parsed.num_times)

  for _ in xrange(parsed.num_times):
    if self._next >= len(self._sorted_nodes):
      return debugger_cli_common.RichTextLines(
          "ERROR: Cannot step any further because the end of the sorted "
          "transitive closure has been reached.")
    else:
      # Each step is a continue-to action on the node at the _next pointer;
      # cont() advances _next as a side effect.
      screen_output = self.cont([self._sorted_nodes[self._next]], screen_info)

  # Show the output of the last step.
  return screen_output
def print_tensor(self, args, screen_info=None):
  """Print the value of a tensor that the stepper has access to."""

  parsed = self.arg_parsers["print_tensor"].parse_args(args)

  # Honor the screen width, when known, for numpy's printed output.
  np_printoptions = ({"linewidth": screen_info["cols"]}
                     if screen_info and "cols" in screen_info else {})

  # Determine if any range-highlighting is required.
  highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)

  tensor_name, tensor_slicing = (
      command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

  # The argument may name a node; resolve it to exactly one tensor.
  resolved_names = self._resolve_tensor_names(tensor_name)
  if not resolved_names:
    return cli_shared.error(
        self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] % tensor_name)
  if len(resolved_names) > 1:
    return cli_shared.error(
        self._MESSAGE_TEMPLATES["MULTIPLE_TENSORS"] % tensor_name)

  tensor_name = resolved_names[0]

  try:
    tensor_value = self._node_stepper.get_tensor_value(tensor_name)
  except ValueError as e:
    # The stepper has no handle/override/feed for this tensor.
    return debugger_cli_common.RichTextLines([str(e)])

  return cli_shared.format_tensor(
      tensor_value,
      tensor_name,
      np_printoptions,
      print_all=parsed.print_all,
      tensor_slicing=tensor_slicing,
      highlight_options=highlight_options)
def inject_value(self, args, screen_info=None):
  """Inject value to a given tensor.

  Args:
    args: (list of str) command-line arguments for the "inject_value"
      command.
    screen_info: Information about the screen; when it carries "cols", the
      value is used as numpy's print line width.

  Returns:
    (RichTextLines) Screen output for the result of the injection action.
  """

  # Bug fix: removed the stale "_ = screen_info  # Currently unused." line
  # that immediately preceded this use of screen_info — the argument IS used
  # (here, and via np_printoptions below).
  if screen_info and "cols" in screen_info:
    np_printoptions = {"linewidth": screen_info["cols"]}
  else:
    np_printoptions = {}

  parsed = self.arg_parsers["inject_value"].parse_args(args)

  tensor_names = self._resolve_tensor_names(parsed.tensor_name)
  if not tensor_names:
    return cli_shared.error(
        self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] % parsed.tensor_name)
  elif len(tensor_names) > 1:
    return cli_shared.error(
        self._MESSAGE_TEMPLATES["MULTIPLE_TENSORS"] % parsed.tensor_name)
  else:
    tensor_name = tensor_names[0]

  # SECURITY NOTE: eval() executes arbitrary Python. This is tolerable only
  # because the value string is typed by the interactive debugger user;
  # never route untrusted input through this command.
  tensor_value = eval(parsed.tensor_value_str)  # pylint: disable=eval-used

  try:
    self._node_stepper.override_tensor(tensor_name, tensor_value)
    lines = [
        "Injected value \"%s\"" % parsed.tensor_value_str,
        " to tensor \"%s\":" % tensor_name, ""
    ]

    tensor_lines = tensor_format.format_tensor(
        tensor_value,
        tensor_name,
        include_metadata=True,
        np_printoptions=np_printoptions).lines
    lines.extend(tensor_lines)

  except ValueError:
    # override_tensor rejected the injection (e.g., tensor not overridable).
    lines = [
        "ERROR: Failed to inject value to tensor %s" % parsed.tensor_name
    ]

  return debugger_cli_common.RichTextLines(lines)
# TODO(cais): Implement list_inputs
# TODO(cais): Implement list_outputs
# TODO(cais): Implement node_info
def _resolve_tensor_names(self, element_name):
  """Resolve tensor name from graph element name.

  Args:
    element_name: (str) Name of the graph element to resolve.

  Returns:
    (list) Name of the tensor(s). If element_name is the name of a tensor in
    the transitive closure, return [element_name]. If element_name is the
    name of a node in the transitive closure, return the list of output
    tensors from the node that are in the transitive closure. Otherwise,
    return empty list.
  """

  looks_like_tensor = ":" in element_name
  in_closure = element_name in self._closure_elements

  # A tensor name already known to the stepper: nothing to resolve.
  if in_closure and looks_like_tensor:
    return [element_name]

  # A node name: expand to its output tensors within the closure.
  if element_name in self._sorted_nodes or (in_closure and
                                            not looks_like_tensor):
    output_slots = self._node_stepper.output_slots_in_closure(element_name)
    return ["%s:%d" % (element_name, slot) for slot in output_slots]

  return []

View File

@ -0,0 +1,444 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests of the Stepper CLI Backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.debug import stepper
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
# Regex pattern for a node line in the stepper CLI output. Such lines contain
# a parenthesized segment followed by a bracketed segment (the status labels
# extracted by _parse_sorted_nodes_list below).
NODE_LINE_PATTERN = re.compile(r".*\(.*\).*\[.*\].*")
def _parse_sorted_nodes_list(lines):
  """Parse a list of lines to extract the node list.

  (Typo fix: docstring previously read "Parsed a list".)

  Args:
    lines: (list of str) Lines from which the node list and associated
      information will be extracted.

  Returns:
    (list of str) The list of node names.
    (list of str) The list of status labels.
    (int) 0-based index among the nodes for the node pointed by the next-node
      pointer. If no such node exists, -1.
  """

  node_names = []
  status_labels = []
  node_pointer = -1

  # Counts only the lines that match the node-line pattern, so the pointer
  # index is relative to the node list, not to all output lines.
  node_line_counter = 0
  for line in lines:
    if NODE_LINE_PATTERN.match(line):
      # The node name is the last whitespace-delimited token on the line.
      node_names.append(line.split(" ")[-1])

      # The status labels are the characters between the square brackets.
      idx_left_bracket = line.index("[")
      idx_right_bracket = line.index("]")
      status_labels.append(line[idx_left_bracket + 1:idx_right_bracket])

      if line.strip().startswith(
          stepper_cli.NodeStepperCLI.NEXT_NODE_POINTER_STR):
        node_pointer = node_line_counter

      node_line_counter += 1

  return node_names, status_labels, node_pointer
def _parsed_used_feeds(lines):
feed_types = {}
begin_line = -1
for i, line in enumerate(lines):
if line.startswith("Stepper used feeds:"):
begin_line = i + 1
break
if begin_line == -1:
return feed_types
for line in lines[begin_line:]:
line = line.strip()
if not line:
return feed_types
else:
feed_name = line.split(" : ")[0].strip()
feed_type = line.split(" : ")[1].strip()
feed_types[feed_name] = feed_type
class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
  """Tests of the stepper CLI on a small graph of variables and math ops."""

  def setUp(self):
    self.a = tf.Variable(10.0, name="a")
    self.b = tf.Variable(20.0, name="b")

    self.c = tf.add(self.a, self.b, name="c")  # Should be 30.0.
    self.d = tf.sub(self.a, self.c, name="d")  # Should be -20.0.
    self.e = tf.mul(self.c, self.d, name="e")  # Should be -600.0.

    self.ph = tf.placeholder(tf.float32, shape=(2, 2), name="ph")
    self.f = tf.mul(self.e, self.ph, name="f")

    self.opt = tf.train.GradientDescentOptimizer(0.1).minimize(
        self.e, name="opt")

    self.sess = tf.Session()

    # Only the variables are initialized here; the fetch targets are run
    # lazily by the stepper in the tests below.
    self.sess.run(self.a.initializer)
    self.sess.run(self.b.initializer)

  def tearDown(self):
    tf.reset_default_graph()

  def _assert_nodes_topologically_sorted_with_target_e(self, node_names):
    """Check the topologically sorted order of the node names.

    Args:
      node_names: (list of str) Node names as listed by the stepper CLI, with
        fetch target e.
    """
    self.assertGreaterEqual(len(node_names), 7)
    self.assertLess(node_names.index("a"), node_names.index("a/read"))
    self.assertLess(node_names.index("b"), node_names.index("b/read"))
    self.assertLess(node_names.index("a/read"), node_names.index("c"))
    self.assertLess(node_names.index("b/read"), node_names.index("c"))
    self.assertLess(node_names.index("a/read"), node_names.index("d"))
    self.assertLess(node_names.index("c"), node_names.index("d"))
    self.assertLess(node_names.index("c"), node_names.index("e"))
    self.assertLess(node_names.index("d"), node_names.index("e"))

  def _assert_nodes_topologically_sorted_with_target_f(self, node_names):
    # Target f adds ph and f on top of the transitive closure of e.
    self._assert_nodes_topologically_sorted_with_target_e(node_names)

    self.assertGreaterEqual(len(node_names), 9)
    self.assertLess(node_names.index("ph"), node_names.index("f"))
    self.assertLess(node_names.index("e"), node_names.index("f"))

  def testListingSortedNodesPresentsTransitveClosure(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self._assert_nodes_topologically_sorted_with_target_e(node_names)
    self.assertEqual(len(node_names), len(stat_labels))
    for stat_label in stat_labels:
      # Nothing has been continued to yet, so every status label is blank.
      self.assertEqual(" ", stat_label)
    self.assertEqual(0, node_pointer)

  def testListingSortedNodesLabelsPlaceholders(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self._assert_nodes_topologically_sorted_with_target_f(node_names)

    # Only the ph node should carry the placeholder ("P") label.
    index_ph = node_names.index("ph")
    self.assertEqual(len(node_names), len(stat_labels))
    for i in xrange(len(stat_labels)):
      if index_ph == i:
        self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
                      stat_labels[i])
      else:
        self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
                         stat_labels[i])

    self.assertEqual(0, node_pointer)

  def testContToNonexistentNodeShouldError(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))

    output = cli.cont(["foobar"])
    self.assertEqual(
        ["ERROR: foobar is not in the transitive closure of this stepper "
         "instance."], output.lines)

  def testContToNodeOutsideTransitiveClosureShouldError(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    # f exists in the graph but is downstream of fetch target e, hence
    # outside this stepper's transitive closure.
    output = cli.cont(["f"])
    self.assertEqual(
        ["ERROR: f is not in the transitive closure of this stepper "
         "instance."], output.lines)

  def testContToValidNodeShouldUpdateStatus(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    index_c = node_names.index("c")
    self.assertEqual(" ", stat_labels[index_c])
    self.assertEqual(0, node_pointer)

    output = cli.cont("c")
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    # After cont("c"), c carries the continued-to label and the next-node
    # pointer points at it.
    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("c", node_names)
    index_c = node_names.index("c")
    self.assertEqual(index_c, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

    output = cli.cont("d")
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    # Continuing to d should have fed c via its cached tensor handle.
    used_feed_types = _parsed_used_feeds(output.lines)
    self.assertEqual({"c:0": "handle"}, used_feed_types)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("d", node_names)
    index_d = node_names.index("d")
    self.assertEqual(index_d, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])

  def testSteppingOneStepAtATimeShouldUpdateStatus(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
    self.assertEqual(0, node_pointer)

    for i in xrange(len(orig_node_names)):
      output = cli.step([])
      node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
          output.lines)

      next_node_name = node_names[node_pointer]
      self.assertEqual(orig_node_names[i], next_node_name)

      self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                    stat_labels[node_pointer])

      # The order in which the nodes are listed should not change as the
      # stepping happens.
      output = cli.list_sorted_nodes([])
      node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
      self.assertEqual(orig_node_names, node_names)

      if i < len(orig_node_names) - 1:
        self.assertEqual(i + 1, node_pointer)
      else:
        # Stepped over the limit. Pointer should be at -1.
        self.assertEqual(-1, node_pointer)

    # Attempt to step once more after the end has been reached should error
    # out.
    output = cli.step([])
    self.assertEqual([
        "ERROR: Cannot step any further because the end of the sorted "
        "transitive closure has been reached."
    ], output.lines)

  def testSteppingMultipleStepsUpdatesStatus(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines)

    output = cli.step(["-t", "3"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self.assertEqual(orig_node_names[2], node_names[node_pointer])

    # Every node up to the pointer has been continued to; none after it has.
    for i in xrange(node_pointer):
      self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])

    for i in xrange(node_pointer + 1, len(stat_labels)):
      self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])

  def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
    node_stepper = stepper.NodeStepper(self.sess, self.opt)

    sorted_nodes = node_stepper.sorted_nodes()
    closure_elements = node_stepper.closure_elements()

    # Find a node which is in the list of sorted nodes, but whose output
    # tensor is not in the transitive closure.
    no_output_node = None
    for node in sorted_nodes:
      if (node + ":0" not in closure_elements and
          node + ":1" not in closure_elements):
        no_output_node = node
        break

    self.assertIsNotNone(no_output_node)

    cli = stepper_cli.NodeStepperCLI(node_stepper)
    output = cli.cont([no_output_node])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self.assertEqual(no_output_node, node_names[node_pointer])
    # No output tensor in the closure means no handle could be cached.
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                     stat_labels[node_pointer])

  def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))

    output = cli.cont(["opt/update_b/ApplyGradientDescent"])

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(
        output.lines)

    # Only b, the variable touched by the update op, should be marked dirty.
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("b")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("a")])

  def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))

    output = cli.cont(["opt/update_a/ApplyGradientDescent"])

    # After cont() call on .../update_a/..., Variable a should have been
    # marked as dirty, whereas b should not have.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(
        output.lines)

    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("a")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("b")])

    output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])

    # After cont() call on .../update_b/... with the -r flag, Variable b
    # should have been marked as dirty, whereas Variable a should not be
    # because it should have been restored.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(
        output.lines)

    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("b")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("a")])

  def testPrintTensorShouldWorkWithTensorName(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    cli.cont("d")
    output = cli.print_tensor(["d:0"])

    self.assertEqual("Tensor \"d:0\":", output.lines[0])
    self.assertEqual("-20.0", output.lines[-1])

  def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    cli.cont("d")
    # A bare node name should resolve to the node's output tensor d:0.
    output = cli.print_tensor(["d"])

    self.assertEqual("Tensor \"d:0\":", output.lines[0])
    self.assertEqual("-20.0", output.lines[-1])

  def testPrintTensorShouldWorkSlicingString(self):
    ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(
        self.sess,
        self.f,
        feed_dict={self.ph: ph_value}))

    output = cli.print_tensor(["ph:0[:, 1]"])
    self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
    self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])

    output = cli.print_tensor(["ph[:, 1]"])
    self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
    self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])

  def testPrintTensorWithNonexistentTensorShouldError(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
    output = cli.print_tensor(["foobar"])
    self.assertEqual(
        ["ERROR: foobar is not in the transitive closure of this stepper "
         "instance."], output.lines)

  def testPrintTensorWithNoHandleShouldError(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
    # e has not been continued to, so no value is available for it yet.
    output = cli.print_tensor("e")
    self.assertEqual(
        ["This stepper instance does not have access to the value of tensor "
         "\"e:0\""], output.lines)

  def testInjectTensorValueByTensorNameShouldBeReflected(self):
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    output = cli.cont(["d"])
    node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
    self.assertEqual("d", node_names[node_pointer])

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    index_d = node_names.index("d")
    self.assertIn(
        stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
    self.assertNotIn(
        stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, stat_labels[index_d])

    self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))

    output = cli.inject_value(["d:0", "20.0"])

    # Verify that the override is available.
    self.assertEqual(["d:0"], node_stepper.override_names())

    # Verify that the list of sorted nodes reflects the existence of the
    # value override (i.e., injection).
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    index_d = node_names.index("d")
    # The override supersedes the cached handle in the status display.
    self.assertNotIn(
        stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
    self.assertIn(
        stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, stat_labels[index_d])

  def testInjectTensorValueByNodeNameShouldBeReflected(self):
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    # A bare node name resolves to the node's output tensor d:0.
    cli.inject_value(["d", "20.0"])
    self.assertEqual(["d:0"], node_stepper.override_names())

  def testInjectToNonexistentTensorShouldError(self):
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    output = cli.inject_value(["foobar:0", "20.0"])
    self.assertEqual(
        ["ERROR: foobar:0 is not in the transitive closure of this stepper "
         "instance."], output.lines)
# Script entry point: run all test cases in this module.
if __name__ == "__main__":
  googletest.main()

View File

@ -19,6 +19,8 @@ from __future__ import print_function
import copy import copy
import six
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug import debug_data from tensorflow.python.debug import debug_data
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -174,7 +176,7 @@ class NodeStepper(object):
fetch_names = [] fetch_names = []
fetch_list = [] fetch_list = []
for fetch in flattened_fetches: for fetch in flattened_fetches:
if isinstance(fetch, str): if isinstance(fetch, six.string_types):
fetch_names.append(fetch) fetch_names.append(fetch)
fetch_list.append(self._sess.graph.as_graph_element(fetch)) fetch_list.append(self._sess.graph.as_graph_element(fetch))
else: else:
@ -345,7 +347,7 @@ class NodeStepper(object):
(bool) whether the graph element is feedable. (bool) whether the graph element is feedable.
""" """
if not isinstance(name, str): if not isinstance(name, six.string_types):
raise TypeError("Expected type str; got type %s" % type(name)) raise TypeError("Expected type str; got type %s" % type(name))
elem = self._sess.graph.as_graph_element(name) elem = self._sess.graph.as_graph_element(name)
@ -363,7 +365,7 @@ class NodeStepper(object):
tree to the fetched graph element of this stepper instance. tree to the fetched graph element of this stepper instance.
""" """
if not isinstance(tensor_name, str): if not isinstance(tensor_name, six.string_types):
raise TypeError("Expected type str; got type %s" % type(tensor_name)) raise TypeError("Expected type str; got type %s" % type(tensor_name))
node_name = self._get_node_name(tensor_name) node_name = self._get_node_name(tensor_name)
@ -442,7 +444,7 @@ class NodeStepper(object):
# The feeds to be used in the Session.run() call. # The feeds to be used in the Session.run() call.
feeds = {} feeds = {}
if isinstance(target, str): if isinstance(target, six.string_types):
# Fetch target is a string. Assume it is the name of the Tensor or Op and # Fetch target is a string. Assume it is the name of the Tensor or Op and
# will attempt to find it in the Session's graph. # will attempt to find it in the Session's graph.
target_name = target target_name = target
@ -703,16 +705,23 @@ class NodeStepper(object):
The same return value as self.cont() as called on the final fetch. The same return value as self.cont() as called on the final fetch.
""" """
# Restore variable to their previous values. self.restore_variable_values()
for var_name in self._cached_variable_values: return self._sess.run(self._fetches, feed_dict=self._client_feed_dict)
def restore_variable_values(self):
"""Restore variables to the initial values.
"Initial value" refers to the value when this NodeStepper instance was
first constructed.
"""
for var_name in self._dirty_variables:
self._sess.run(self._variable_initializers[var_name], self._sess.run(self._variable_initializers[var_name],
feed_dict={ feed_dict={
self._variable_initial_values[var_name]: self._variable_initial_values[var_name]:
self._cached_variable_values[var_name] self._cached_variable_values[var_name]
}) })
return self._sess.run(self._fetches, feed_dict=self._client_feed_dict)
def handle_names(self): def handle_names(self):
"""Return names of the TensorHandles that the debugger is holding. """Return names of the TensorHandles that the debugger is holding.
@ -800,7 +809,11 @@ class NodeStepper(object):
or through a TensorHandle. or through a TensorHandle.
""" """
if tensor_name in self._override_tensors: if self.is_placeholder(tensor_name):
if ":" not in tensor_name:
tensor_name += ":0"
return self._client_feed_dict[tensor_name]
elif tensor_name in self._override_tensors:
return self._override_tensors[tensor_name] return self._override_tensors[tensor_name]
elif tensor_name in self._tensor_handles: elif tensor_name in self._tensor_handles:
return self._tensor_handles[tensor_name].eval() return self._tensor_handles[tensor_name].eval()

View File

@ -412,6 +412,23 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
def tearDown(self): def tearDown(self):
tf.reset_default_graph() tf.reset_default_graph()
def testGetTensorValueWorksOnPlaceholder(self):
stepper = NodeStepper(
self.sess,
self.y,
feed_dict={
self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
self.ph1: [[-1.0], [0.5]]
})
self.assertAllClose(
[[1.0, 2.0], [-3.0, 5.0]], stepper.get_tensor_value("ph0"))
self.assertAllClose(
[[1.0, 2.0], [-3.0, 5.0]], stepper.get_tensor_value("ph0:0"))
with self.assertRaisesRegexp(
KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"):
stepper.get_tensor_value("ph0:1")
def testIsPlaceholdersShouldGiveCorrectAnswers(self): def testIsPlaceholdersShouldGiveCorrectAnswers(self):
stepper = NodeStepper(self.sess, self.y) stepper = NodeStepper(self.sess, self.y)
@ -694,6 +711,18 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
self.assertAllClose(1.84, self.sess.run(self.b)) self.assertAllClose(1.84, self.sess.run(self.b))
self.assertAllClose(4.0, self.sess.run(self.c)) self.assertAllClose(4.0, self.sess.run(self.c))
def testRestoreVariableValues(self):
"""Test restore_variable_values() restores the old values of variables."""
stepper = NodeStepper(self.sess, "optim")
stepper.cont("optim/update_b/ApplyGradientDescent",
restore_variable_values=True)
self.assertAllClose(1.84, self.sess.run(self.b))
stepper.restore_variable_values()
self.assertAllClose(2.0, self.sess.run(self.b))
def testFinalize(self): def testFinalize(self):
"""Test finalize() to restore variables and run the original fetch.""" """Test finalize() to restore variables and run the original fetch."""

View File

@ -116,6 +116,7 @@ import abc
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.debug import debug_utils from tensorflow.python.debug import debug_utils
from tensorflow.python.debug import stepper
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -155,9 +156,9 @@ class OnSessionInitRequest(object):
class OnSessionInitAction(object): class OnSessionInitAction(object):
"""Enum-like values for possible action to take on session init.""" """Enum-like values for possible action to take on session init."""
# Proceed, without special actions, in the wrapper session initializaton. What # Proceed, without special actions, in the wrapper session initialization.
# action the wrapper session performs next is determined by the caller of the # What action the wrapper session performs next is determined by the caller
# wrapper session. E.g., it can call run(). # of the wrapper session. E.g., it can call run().
PROCEED = "proceed" PROCEED = "proceed"
# Instead of letting the caller of the wrapper session determine what actions # Instead of letting the caller of the wrapper session determine what actions
@ -397,7 +398,13 @@ class BaseDebugWrapperSession(session.SessionInterface):
client_graph_def=self._sess.graph.as_graph_def(), client_graph_def=self._sess.graph.as_graph_def(),
tf_error=tf_error) tf_error=tf_error)
elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN: elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
retvals = self.invoke_node_stepper(
stepper.NodeStepper(self._sess, fetches, feed_dict),
restore_variable_values_on_exit=True)
# Invoke run() method of the wrapped session. # Invoke run() method of the wrapped session.
retvals = self._sess.run( retvals = self._sess.run(
fetches, fetches,
@ -407,10 +414,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
# Prepare arg for the on-run-end callback. # Prepare arg for the on-run-end callback.
run_end_req = OnRunEndRequest(run_start_resp.action) run_end_req = OnRunEndRequest(run_start_resp.action)
elif run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
# TODO(cais): Implement stepper loop.
raise NotImplementedError(
"OnRunStartAction INVOKE_STEPPER has not been implemented.")
else: else:
raise ValueError( raise ValueError(
"Invalid OnRunStartAction value: %s" % run_start_resp.action) "Invalid OnRunStartAction value: %s" % run_start_resp.action)
@ -461,7 +464,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
Returns: Returns:
An instance of OnSessionInitResponse. An instance of OnSessionInitResponse.
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def on_run_start(self, request): def on_run_start(self, request):
@ -482,7 +484,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
with or without debug tensor watching, invoking the stepper.) with or without debug tensor watching, invoking the stepper.)
2) debug URLs used to watch the tensors. 2) debug URLs used to watch the tensors.
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def on_run_end(self, request): def on_run_end(self, request):
@ -499,7 +500,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
Returns: Returns:
An instance of OnRunStartResponse. An instance of OnRunStartResponse.
""" """
pass
def __enter__(self): def __enter__(self):
return self._sess.__enter__() return self._sess.__enter__()
@ -512,3 +512,21 @@ class BaseDebugWrapperSession(session.SessionInterface):
# TODO(cais): Add _node_name_regex_whitelist and # TODO(cais): Add _node_name_regex_whitelist and
# _node_op_type_regex_whitelist. # _node_op_type_regex_whitelist.
@abc.abstractmethod
def invoke_node_stepper(self,
node_stepper,
restore_variable_values_on_exit=True):
"""Callback invoked when the client intends to step through graph nodes.
Args:
node_stepper: (stepper.NodeStepper) An instance of NodeStepper to be used
in this stepping session.
restore_variable_values_on_exit: (bool) Whether any variables whose values
have been altered during this node-stepper invocation should be restored
to their old values when this invocation ends.
Returns:
The same return values as the `Session.run()` call on the same fetches as
the NodeStepper.
"""

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug import debug_utils from tensorflow.python.debug import debug_utils
from tensorflow.python.debug import stepper
from tensorflow.python.debug.wrappers import framework from tensorflow.python.debug.wrappers import framework
from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.training import session_run_hook from tensorflow.python.training import session_run_hook
@ -64,19 +65,25 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
self._decorate_options_for_debug(run_args.options, self._decorate_options_for_debug(run_args.options,
run_context.session.graph) run_context.session.graph)
elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER: elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
raise NotImplementedError( # The _finalized property must be set to False so that the NodeStepper
"OnRunStartAction INVOKE_STEPPER has not been implemented.") # can insert ops for retrieving TensorHandles.
# pylint: disable=protected-access
run_context.session.graph._finalized = False
# pylint: enable=protected-access
self.invoke_node_stepper(
stepper.NodeStepper(run_context.session, run_context.original_args.
fetches, run_context.original_args.feed_dict),
restore_variable_values_on_exit=True)
return run_args return run_args
def after_run(self, run_context, run_values): def after_run(self, run_context, run_values):
# Adapt run_context and run_values to OnRunEndRequest and invoke superclass # Adapt run_context and run_values to OnRunEndRequest and invoke superclass
# on_run_end() # on_run_end()
if self._performed_action == framework.OnRunStartAction.DEBUG_RUN: on_run_end_request = framework.OnRunEndRequest(self._performed_action,
on_run_end_request = framework.OnRunEndRequest(self._performed_action, run_values.run_metadata)
run_values.run_metadata) self.on_run_end(on_run_end_request)
self.on_run_end(on_run_end_request)
def _decorate_options_for_debug(self, options, graph): def _decorate_options_for_debug(self, options, graph):
"""Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging. """Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging.

View File

@ -29,6 +29,7 @@ from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import curses_ui from tensorflow.python.debug.cli import curses_ui
from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.debug.wrappers import framework from tensorflow.python.debug.wrappers import framework
@ -497,3 +498,70 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
fetches, fetches,
feed_dict, feed_dict,
self._tensor_filters) self._tensor_filters)
def invoke_node_stepper(self,
                        node_stepper,
                        restore_variable_values_on_exit=True):
  """Overrides method in base class to implement interactive node stepper.

  Args:
    node_stepper: (stepper.NodeStepper) The underlying NodeStepper API object.
    restore_variable_values_on_exit: (bool) Whether any variables whose values
      have been altered during this node-stepper invocation should be restored
      to their old values when this invocation ends.

  Returns:
    The same return values as the `Session.run()` call on the same fetches as
    the NodeStepper.
  """
  cli = stepper_cli.NodeStepperCLI(node_stepper)

  # When the stepper UI exits, optionally roll back any variable values that
  # were altered during stepping so the graph state is as if the stepping had
  # not happened.
  # TODO(cais): Perhaps some users will want the effect of the interactive
  # stepping and value injection to persist. When that happens, make the call
  # to finalize optional.
  on_exit = (node_stepper.restore_variable_values
             if restore_variable_values_on_exit else None)
  stepper_ui = curses_ui.CursesUI(on_ui_exit=on_exit)

  # (command name, handler, prefix aliases) for each stepper CLI command; the
  # help text comes from the command's registered argument parser.
  command_specs = (
      ("list_sorted_nodes", cli.list_sorted_nodes, ["lt", "lsn"]),
      ("cont", cli.cont, ["ct", "c"]),
      ("step", cli.step, ["st", "s"]),
      ("print_tensor", cli.print_tensor, ["pt"]),
      ("inject_value", cli.inject_value,
       ["inject", "override_value", "override"]),
  )
  for command, handler, aliases in command_specs:
    stepper_ui.register_command_handler(
        command,
        handler,
        cli.arg_parsers[command].format_help(),
        prefix_aliases=aliases)

  # Register the graph's node names as tab-completion candidates for the
  # commands that take a node-name argument.
  # TODO(cais): Tie up register_tab_comp_context to a single alias to shorten
  # calls like this.
  node_names = [str(elem) for elem in node_stepper.sorted_nodes()]
  stepper_ui.register_tab_comp_context([
      "cont", "ct", "c", "pt", "inject_value", "inject", "override_value",
      "override"
  ], node_names)

  return stepper_ui.run_ui(
      init_command="lt",
      title="Node Stepper: " + self._run_description,
      title_color="blue_on_white")