tfdbg: initial check-in of stepper CLI
This check-in implements the following commands (and aliases) in the stepper CLI:

* cont (ct, c)
* step (st, s)
* list_sorted_nodes (lt)
* print_tensor (pt)
* inject_value (inject, override)

To activate the stepper CLI, enter the command "step" or "s" at the run-start CLI.

Not implemented yet:

* Documentation in g3doc/how_tos/debugger will be submitted in a later CL.
* Commands for inspecting node information and graph structure in the stepper CLI, including ni, li and lo.

Change: 141518037
parent 4a29f9f6a0
commit f37c468252
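For orientation, here is a minimal sketch of how a user reaches the new stepper CLI from client code. The wrapper class name and import path are assumptions inferred from the local_cli_wrapper target touched below; only the "step"/"s" command at the run-start CLI is stated by this commit.

import tensorflow as tf
from tensorflow.python import debug as tf_debug  # assumed import path

sess = tf.Session()
# Wrap the session with the local CLI debug wrapper (assumed class name).
sess = tf_debug.LocalCLIDebugWrapperSession(sess)

# On the next sess.run(...), the run-start CLI appears; typing "step" or "s"
# there (per the commit message) enters the stepper CLI, where the new
# cont / step / list_sorted_nodes / print_tensor / inject_value commands live.
sess.run(fetches)  # "fetches" stands in for whatever the client is running.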
@@ -48,6 +48,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":debug_data",
+        "//tensorflow/python:data_flow_ops",
     ],
 )
 
@@ -57,6 +58,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":debug_utils",
+        ":stepper",
         "//tensorflow/python:session",
     ],
 )
@@ -85,7 +87,9 @@ py_library(
     srcs = ["cli/cli_shared.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":command_parser",
         ":debugger_cli_common",
+        ":tensor_format",
         "//tensorflow/python:framework",
         "//tensorflow/python:variables",
     ],
@@ -96,6 +100,7 @@ py_library(
     srcs = ["cli/analyzer_cli.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":cli_shared",
         ":command_parser",
         ":debug_data",
         ":debugger_cli_common",
@@ -103,6 +108,19 @@ py_library(
     ],
 )
 
+py_library(
+    name = "stepper_cli",
+    srcs = ["cli/stepper_cli.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":cli_shared",
+        ":command_parser",
+        ":debugger_cli_common",
+        ":stepper",
+        ":tensor_format",
+    ],
+)
+
 py_library(
     name = "curses_ui",
     srcs = ["cli/curses_ui.py"],
@@ -125,6 +143,7 @@ py_library(
         ":debug_data",
         ":debugger_cli_common",
         ":framework",
+        ":stepper_cli",
         "//tensorflow/python:session",
     ],
 )
@@ -135,6 +154,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":local_cli_wrapper",
+        ":stepper",
         "//tensorflow/python:session",
     ],
 )
@@ -223,6 +243,7 @@ py_test(
     deps = [
         ":debug_data",
         ":framework",
+        ":stepper",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
@@ -341,6 +362,22 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "stepper_cli_test",
+    size = "small",
+    srcs = [
+        "cli/stepper_cli_test.py",
+    ],
+    additional_deps = [
+        ":stepper",
+        ":stepper_cli",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:session",
+    ],
+)
+
 py_test(
     name = "local_cli_wrapper_test",
     size = "small",
@@ -27,13 +27,12 @@ import argparse
 import copy
 import re
 
-import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.debug import debug_data
+from tensorflow.python.debug.cli import cli_shared
 from tensorflow.python.debug.cli import command_parser
 from tensorflow.python.debug.cli import debugger_cli_common
-from tensorflow.python.debug.cli import tensor_format
 
 
 # String constants for the depth-dependent hanging indent at the beginning
@@ -69,10 +68,6 @@ class DebugAnalyzer(object):
     # Argument parsers for command handlers.
     self._arg_parsers = {}
 
-    # Default threshold number of elements above which ellipses will be used
-    # when printing the value of the tensor.
-    self.default_ndarray_display_threshold = 2000
-
     # Parser for list_tensors.
     ap = argparse.ArgumentParser(
         description="List dumped intermediate tensors.",
@@ -222,11 +217,6 @@ class DebugAnalyzer(object):
 
     # TODO(cais): Implement list_nodes.
 
-  def _error(self, msg):
-    full_msg = "ERROR: " + msg
-    return debugger_cli_common.RichTextLines(
-        [full_msg], font_attr_segs={0: [(0, len(full_msg), "red")]})
-
   def add_tensor_filter(self, filter_name, filter_callable):
     """Add a tensor filter.
 
@@ -329,7 +319,7 @@ class DebugAnalyzer(object):
     try:
       filter_callable = self.get_tensor_filter(parsed.tensor_filter)
     except ValueError:
-      return self._error(
+      return cli_shared.error(
           "There is no tensor filter named \"%s\"." % parsed.tensor_filter)
 
     data_to_show = self._debug_dump.find(filter_callable)
@@ -392,7 +382,7 @@ class DebugAnalyzer(object):
         parsed.node_name)
 
     if not self._debug_dump.node_exists(node_name):
-      return self._error(
+      return cli_shared.error(
           "There is no node named \"%s\" in the partition graphs" % node_name)
 
     # TODO(cais): Provide UI glossary feature to explain to users what the
@@ -487,24 +477,19 @@ class DebugAnalyzer(object):
       np_printoptions = {}
 
     # Determine if any range-highlighting is required.
-    highlight_options = self._parse_ranges_highlight(parsed.ranges)
+    highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)
 
-    # Determine if there parsed.tensor_name contains any indexing (slicing).
-    if parsed.tensor_name.count("[") == 1 and parsed.tensor_name.endswith("]"):
-      tensor_name = parsed.tensor_name[:parsed.tensor_name.index("[")]
-      tensor_slicing = parsed.tensor_name[parsed.tensor_name.index("["):]
-    else:
-      tensor_name = parsed.tensor_name
-      tensor_slicing = ""
+    tensor_name, tensor_slicing = (
+        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
 
     node_name, output_slot = debug_data.parse_node_or_tensor_name(tensor_name)
     if output_slot is None:
-      return self._error("\"%s\" is not a valid tensor name" %
+      return cli_shared.error("\"%s\" is not a valid tensor name" %
                               parsed.tensor_name)
 
     if (self._debug_dump.loaded_partition_graphs and
         not self._debug_dump.node_exists(node_name)):
-      return self._error(
+      return cli_shared.error(
           "Node \"%s\" does not exist in partition graphs" % node_name)
 
     watch_keys = self._debug_dump.debug_watch_keys(node_name)
@@ -520,12 +505,12 @@ class DebugAnalyzer(object):
 
     if not matching_data:
       # No dump for this tensor.
-      return self._error(
+      return cli_shared.error(
           "Tensor \"%s\" did not generate any dumps." % parsed.tensor_name)
     elif len(matching_data) == 1:
       # There is only one dump for this tensor.
       if parsed.number <= 0:
-        return self._format_tensor(
+        return cli_shared.format_tensor(
             matching_data[0].get_tensor(),
             matching_data[0].watch_key,
             np_printoptions,
@@ -533,7 +518,7 @@ class DebugAnalyzer(object):
             tensor_slicing=tensor_slicing,
             highlight_options=highlight_options)
       else:
-        return self._error(
+        return cli_shared.error(
            "Invalid number (%d) for tensor %s, which generated one dump." %
            (parsed.number, parsed.tensor_name))
     else:
@@ -556,12 +541,12 @@ class DebugAnalyzer(object):
 
       return debugger_cli_common.RichTextLines(lines)
     elif parsed.number >= len(matching_data):
-      return self._error(
+      return cli_shared.error(
           "Specified number (%d) exceeds the number of available dumps "
           "(%d) for tensor %s" %
           (parsed.number, len(matching_data), parsed.tensor_name))
     else:
-      return self._format_tensor(
+      return cli_shared.format_tensor(
           matching_data[parsed.number].get_tensor(),
           matching_data[parsed.number].watch_key + " (dump #%d)" %
           parsed.number,
@@ -570,90 +555,6 @@ class DebugAnalyzer(object):
           tensor_slicing=tensor_slicing,
           highlight_options=highlight_options)
 
-  def _parse_ranges_highlight(self, ranges_string):
-    """Process ranges highlight string.
-
-    Args:
-      ranges_string: (str) A string representing a numerical range of a list of
-        numerical ranges. See the help info of the -r flag of the print_tensor
-        command for more details.
-
-    Returns:
-      An instance of tensor_format.HighlightOptions, if range_string is a valid
-        representation of a range or a list of ranges.
-    """
-
-    ranges = None
-
-    def ranges_filter(x):
-      r = np.zeros(x.shape, dtype=bool)
-      for rng_start, rng_end in ranges:
-        r = np.logical_or(r, np.logical_and(x >= rng_start, x <= rng_end))
-
-      return r
-
-    if ranges_string:
-      ranges = command_parser.parse_ranges(ranges_string)
-      return tensor_format.HighlightOptions(
-          ranges_filter, description=ranges_string)
-    else:
-      return None
-
-  def _format_tensor(self,
-                     tensor,
-                     watch_key,
-                     np_printoptions,
-                     print_all=False,
-                     tensor_slicing=None,
-                     highlight_options=None):
-    """Generate formatted str to represent a tensor or its slices.
-
-    Args:
-      tensor: (numpy ndarray) The tensor value.
-      watch_key: (str) Tensor debug watch key.
-      np_printoptions: (dict) Numpy tensor formatting options.
-      print_all: (bool) Whether the tensor is to be displayed in its entirety,
-        instead of printing ellipses, even if its number of elements exceeds
-        the default numpy display threshold.
-        (Note: Even if this is set to true, the screen output can still be cut
-        off by the UI frontend if it consist of more lines than the frontend
-        can handle.)
-      tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
-        None, no slicing will be performed on the tensor.
-      highlight_options: (tensor_format.HighlightOptions) options to highlight
-        elements of the tensor. See the doc of tensor_format.format_tensor()
-        for more details.
-
-    Returns:
-      (str) Formatted str representing the (potentially sliced) tensor.
-
-    Raises:
-      ValueError: If tehsor_slicing is not a valid numpy ndarray slicing str.
-    """
-
-    if tensor_slicing:
-      # Validate the indexing.
-      if not command_parser.validate_slicing_string(tensor_slicing):
-        raise ValueError("Invalid tensor-slicing string.")
-
-      value = eval("tensor" + tensor_slicing)  # pylint: disable=eval-used
-      sliced_name = watch_key + tensor_slicing
-    else:
-      value = tensor
-      sliced_name = watch_key
-
-    if print_all:
-      np_printoptions["threshold"] = value.size
-    else:
-      np_printoptions["threshold"] = self.default_ndarray_display_threshold
-
-    return tensor_format.format_tensor(
-        value,
-        sliced_name,
-        include_metadata=True,
-        np_printoptions=np_printoptions,
-        highlight_options=highlight_options)
-
   def list_outputs(self, args, screen_info=None):
     """Command handler for inputs.
 
@@ -729,7 +630,7 @@ class DebugAnalyzer(object):
 
     # Check if node exists.
     if not self._debug_dump.node_exists(node_name):
-      return self._error(
+      return cli_shared.error(
           "There is no node named \"%s\" in the partition graphs" % node_name)
 
     if recursive:
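To make the refactored slicing path above concrete, here is a hedged sketch of the name-with-slicing form that command_parser.parse_tensor_name_with_slicing now splits for print_tensor; the tensor name is made up for illustration.

from tensorflow.python.debug.cli import command_parser

# "hidden1/Wx_plus_b/MatMul:0[1, :]" is split into the tensor name and the
# slicing suffix; a name without brackets yields an empty slicing string.
name, slicing = command_parser.parse_tensor_name_with_slicing(
    "hidden1/Wx_plus_b/MatMul:0[1, :]")
# name == "hidden1/Wx_plus_b/MatMul:0", slicing == "[1, :]"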
@@ -17,13 +17,116 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import six
 
+from tensorflow.python.debug.cli import command_parser
 from tensorflow.python.debug.cli import debugger_cli_common
+from tensorflow.python.debug.cli import tensor_format
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import variables
 
 
+# Default threshold number of elements above which ellipses will be used
+# when printing the value of the tensor.
+DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000
+
+
+def parse_ranges_highlight(ranges_string):
+  """Process ranges highlight string.
+
+  Args:
+    ranges_string: (str) A string representing a numerical range of a list of
+      numerical ranges. See the help info of the -r flag of the print_tensor
+      command for more details.
+
+  Returns:
+    An instance of tensor_format.HighlightOptions, if range_string is a valid
+      representation of a range or a list of ranges.
+  """
+
+  ranges = None
+
+  def ranges_filter(x):
+    r = np.zeros(x.shape, dtype=bool)
+    for range_start, range_end in ranges:
+      r = np.logical_or(r, np.logical_and(x >= range_start, x <= range_end))
+
+    return r
+
+  if ranges_string:
+    ranges = command_parser.parse_ranges(ranges_string)
+    return tensor_format.HighlightOptions(
+        ranges_filter, description=ranges_string)
+  else:
+    return None
+
+
+def format_tensor(tensor,
+                  tensor_name,
+                  np_printoptions,
+                  print_all=False,
+                  tensor_slicing=None,
+                  highlight_options=None):
+  """Generate formatted str to represent a tensor or its slices.
+
+  Args:
+    tensor: (numpy ndarray) The tensor value.
+    tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
+    np_printoptions: (dict) Numpy tensor formatting options.
+    print_all: (bool) Whether the tensor is to be displayed in its entirety,
+      instead of printing ellipses, even if its number of elements exceeds
+      the default numpy display threshold.
+      (Note: Even if this is set to true, the screen output can still be cut
+      off by the UI frontend if it consist of more lines than the frontend
+      can handle.)
+    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
+      None, no slicing will be performed on the tensor.
+    highlight_options: (tensor_format.HighlightOptions) options to highlight
+      elements of the tensor. See the doc of tensor_format.format_tensor()
+      for more details.
+
+  Returns:
+    (str) Formatted str representing the (potentially sliced) tensor.
+  """
+
+  if tensor_slicing:
+    # Validate the indexing.
+    value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
+    sliced_name = tensor_name + tensor_slicing
+  else:
+    value = tensor
+    sliced_name = tensor_name
+
+  if print_all:
+    np_printoptions["threshold"] = value.size
+  else:
+    np_printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD
+
+  return tensor_format.format_tensor(
+      value,
+      sliced_name,
+      include_metadata=True,
+      np_printoptions=np_printoptions,
+      highlight_options=highlight_options)
+
+
+def error(msg):
+  """Generate a RichTextLines output for error.
+
+  Args:
+    msg: (str) The error message.
+
+  Returns:
+    (debugger_cli_common.RichTextLines) A representation of the error message
+      for screen output.
+  """
+
+  full_msg = "ERROR: " + msg
+  return debugger_cli_common.RichTextLines(
+      [full_msg], font_attr_segs={0: [(0, len(full_msg), "red")]})
+
+
 def _get_fetch_name(fetch):
   """Obtain the name or string representation of a fetch.
 
@@ -159,6 +262,12 @@ def get_run_start_intro(run_call_count,
           "run -f <filter_name>",
           "Keep executing run() calls until a dumped tensor passes a given, "
           "registered filter (conditional breakpoint mode)"))
+  out.extend(
+      _recommend_command(
+          "invoke_stepper",
+          "Use the node-stepper interface, which allows you to interactively "
+          "step through nodes involved in the graph run() call and "
+          "inspect/modify their values"))
 
   more_font_attr_segs = {}
   more_lines = [" Registered filter(s):"]
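For context, a minimal sketch of how the shared helpers moved into cli_shared above are combined by the CLI backends; the watch-key string is made up, and the return value is the debugger's rich-text screen output rather than a plain string.

import numpy as np

from tensorflow.python.debug.cli import cli_shared

tensor = np.linspace(-1.0, 1.0, 20).reshape([4, 5])

# Build highlight options from a ranges string in the same format the -r flag
# of print_tensor accepts.
highlight = cli_shared.parse_ranges_highlight("[[-inf, -0.5], [0.5, inf]]")

# Format a slice of the tensor with the highlighted ranges; "dense/kernel:0"
# is a hypothetical watch key used only for this sketch.
out = cli_shared.format_tensor(
    tensor,
    "dense/kernel:0",
    {},  # np_printoptions; format_tensor fills in the display threshold.
    tensor_slicing="[1, :]",
    highlight_options=highlight)
print("\n".join(out.lines))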
@@ -188,8 +188,8 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
 
     # Verify the listed names of the tensor filters.
     filter_names = set()
-    filter_names.add(run_start_intro.lines[20].split(" ")[-1])
-    filter_names.add(run_start_intro.lines[21].split(" ")[-1])
+    filter_names.add(run_start_intro.lines[22].split(" ")[-1])
+    filter_names.add(run_start_intro.lines[23].split(" ")[-1])
 
     self.assertEqual({"filter_a", "filter_b"}, filter_names)
 
@@ -176,3 +176,26 @@ def parse_ranges(range_string):
                       type(item[0]))
 
   return ranges
+
+
+def evaluate_tensor_slice(tensor, tensor_slicing):
+  """Call eval on the slicing of a tensor, with validation.
+
+  Args:
+    tensor: (numpy ndarray) The tensor value.
+    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
+      None, no slicing will be performed on the tensor.
+
+  Returns:
+    (numpy ndarray) The sliced tensor.
+
+  Raises:
+    ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
+  """
+
+  _ = tensor
+
+  if not validate_slicing_string(tensor_slicing):
+    raise ValueError("Invalid tensor-slicing string.")
+
+  return eval("tensor" + tensor_slicing)  # pylint: disable=eval-used
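A short usage sketch of the new evaluate_tensor_slice helper, assuming a plain numpy array as input.

import numpy as np

from tensorflow.python.debug.cli import command_parser

x = np.arange(12).reshape([3, 4])
# Returns column 1 of x after validating the slicing string.
sliced = command_parser.evaluate_tensor_slice(x, "[:, 1]")
# A string that is not a well-formed slicing expression is expected to fail
# validation and raise ValueError instead of reaching eval().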
@@ -62,6 +62,7 @@ class CursesUI(object):
       "green": curses.COLOR_GREEN,
       "yellow": curses.COLOR_YELLOW,
       "blue": curses.COLOR_BLUE,
+      "cyan": curses.COLOR_CYAN,
       "magenta": curses.COLOR_MAGENTA,
       "black": curses.COLOR_BLACK,
   }
@@ -76,7 +77,13 @@ class CursesUI(object):
   _ERROR_TOAST_COLOR_PAIR = "red_on_white"
   _STATUS_BAR_COLOR_PAIR = "black_on_white"
 
-  def __init__(self):
+  def __init__(self, on_ui_exit=None):
+    """Constructor of CursesUI.
+
+    Args:
+      on_ui_exit: (Callable) Callback invoked when the UI exits.
+    """
+
     self._screen_init()
     self._screen_refresh_size()
     # TODO(cais): Error out if the size of the screen is too small.
@@ -126,6 +133,9 @@ class CursesUI(object):
     # Register signal handler for SIGINT.
     signal.signal(signal.SIGINT, self._interrupt_handler)
 
+    # Configurable callbacks.
+    self._on_ui_exit = on_ui_exit
+
   def _init_layout(self):
     """Initialize the layout of UI components.
 
@@ -272,6 +282,9 @@ class CursesUI(object):
     # CLI main loop.
     exit_token = self._ui_loop()
 
+    if self._on_ui_exit:
+      self._on_ui_exit()
+
     self._screen_terminate()
 
     return exit_token
@@ -1031,7 +1044,8 @@ class CursesUI(object):
     # Examine whether the index information is available for the specified line
    # number.
     pointer = self._output_pad_row + line_index
-    if pointer in self._curr_wrapped_output.annotations:
+    if (pointer in self._curr_wrapped_output.annotations and
+        "i0" in self._curr_wrapped_output.annotations[pointer]):
       indices = self._curr_wrapped_output.annotations[pointer]["i0"]
 
       array_indices_str = self._format_indices(indices)
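A hedged sketch of the new on_ui_exit hook added to CursesUI above; constructing the UI initializes the curses screen, so this is illustration only, and the callback shown is hypothetical.

from tensorflow.python.debug.cli import curses_ui

def _cleanup():
  # Hypothetical callback: e.g., restore overridden tensor values or release
  # resources once the interactive UI is left.
  pass

# The callback is stored by __init__ and invoked after the UI's main loop
# returns, just before the curses screen is terminated.
ui = curses_ui.CursesUI(on_ui_exit=_cleanup)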
@@ -188,6 +188,10 @@ class RichTextLines(object):
     else:
       self._annotations[key] = other.annotations[key]
 
+  # TODO(cais): Add method append of the signature:
+  # def append_line(line, line_font_attr_segs)
+  # and refactor usage in stepper_cli.py.
+
 def regex_find(orig_screen_output, regex, font_attr):
   """Perform regex match in rich text lines.
 
tensorflow/python/debug/cli/stepper_cli.py (new file, 593 lines)
@@ -0,0 +1,593 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CLI Backend for the Node Stepper Part of the Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import numpy as np  # pylint: disable=unused-import
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.debug import stepper
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format


class NodeStepperCLI(object):
  """Command-line-interface backend of Node Stepper."""

  # Possible states of an element in the transitive closure of the stepper's
  # fetch(es).
  # State where the element is already continued-to and a TensorHandle is
  # available for the tensor.
  STATE_CONT = "H"

  # State where an intermediate dump of the tensor is available.
  STATE_INTERMEDIATE = "I"

  # State where the element is already overridden.
  STATE_OVERRIDDEN = "O"

  # State where the element is a placeholder (and hence cannot be continued to)
  STATE_IS_PLACEHOLDER = "P"

  # State where a variable's value has been updated during the lifetime of
  # this NodeStepperCLI instance.
  STATE_DIRTY_VARIABLE = "D"

  NEXT_NODE_POINTER_STR = "-->"

  _MESSAGE_TEMPLATES = {
      "NOT_IN_CLOSURE":
          "%s is not in the transitive closure of this stepper instance.",
      "MULTIPLE_TENSORS":
          "Node %s has more than one output tensor. "
          "Please use full tensor name.",
  }

  def __init__(self, node_stepper):
    self._node_stepper = node_stepper

    # Command parsers for the stepper.
    self.arg_parsers = {}

    # Parser for "list_sorted_nodes".
    ap = argparse.ArgumentParser(
        description="List the state of the sorted transitive closure of the "
        "stepper.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-l",
        "--lower_bound",
        dest="lower_bound",
        type=int,
        default=-1,
        help="Lower-bound index (0-based)")
    ap.add_argument(
        "-u",
        "--upper_bound",
        dest="upper_bound",
        type=int,
        default=-1,
        help="Upper-bound index (0-based)")
    self.arg_parsers["list_sorted_nodes"] = ap

    # Parser for "cont".
    ap = argparse.ArgumentParser(
        description="Continue to a tensor or op.", usage=argparse.SUPPRESS)
    ap.add_argument(
        "target_name",
        type=str,
        help="Name of the Tensor or Op to continue to.")
    ap.add_argument(
        "-r",
        "--restore_variable_values",
        dest="restore_variable_values",
        action="store_true",
        help="Restore all variables in the transitive closure of the cont "
        "target to their initial values (i.e., values when this stepper "
        "instance was created.")
    self.arg_parsers["cont"] = ap

    # Parser for "step".
    ap = argparse.ArgumentParser(
        description="Step to the next tensor or op in the sorted transitive "
        "closure of the stepper's fetch(es).",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-t",
        "--num_times",
        dest="num_times",
        type=int,
        default=1,
        help="Number of times to step (>=1)")
    self.arg_parsers["step"] = ap

    # Parser for "print_tensor".
    ap = argparse.ArgumentParser(
        description="Print the value of a tensor, from cached TensorHandle or "
        "client-provided overrides.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "tensor_name",
        type=str,
        help="Name of the tensor, followed by any slicing indices, "
        "e.g., hidden1/Wx_plus_b/MatMul:0, "
        "hidden1/Wx_plus_b/MatMul:0[1, :]")
    ap.add_argument(
        "-r",
        "--ranges",
        dest="ranges",
        type=str,
        default="",
        help="Numerical ranges to highlight tensor elements in. "
        "Examples: -r 0,1e-8, -r [-0.1,0.1], "
        "-r \"[[-inf, -0.1], [0.1, inf]]\"")
    ap.add_argument(
        "-a",
        "--all",
        dest="print_all",
        action="store_true",
        help="Print the tensor in its entirety, i.e., do not use ellipses.")
    self.arg_parsers["print_tensor"] = ap

    # Parser for inject_value.
    ap = argparse.ArgumentParser(
        description="Inject (override) the value of a Tensor.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "tensor_name",
        type=str,
        help="Name of the Tensor of which the value is to be overridden.")
    ap.add_argument(
        "tensor_value_str",
        type=str,
        help="A string representing the value of the tensor, without any "
        "whitespaces, e.g., np.zeros([10,100])")
    self.arg_parsers["inject_value"] = ap

    self._initialize_state()

  def _initialize_state(self):
    """Initialize the state of this stepper CLI."""

    # Get the elements in the sorted transitive closure, as a list of str.
    self._sorted_nodes = self._node_stepper.sorted_nodes()
    self._closure_elements = self._node_stepper.closure_elements()
    self._placeholders = self._node_stepper.placeholders()
    self._completed_nodes = set()

    self._calculate_next()

  def _calculate_next(self):
    """Calculate the next target for "step" action based on current state."""

    override_names = self._node_stepper.override_names()

    next_i = -1
    for i in xrange(len(self._sorted_nodes)):
      if (i > next_i and (self._sorted_nodes[i] in self._completed_nodes) or
          (self._sorted_nodes[i] in override_names)):
        next_i = i

    next_i += 1
    self._next = next_i

  def list_sorted_nodes(self, args, screen_info=None):
    """List the sorted transitive closure of the stepper's fetches."""

    # TODO(cais): Use pattern such as del args, del screen_info python/debug.
    _ = args
    _ = screen_info

    parsed = self.arg_parsers["list_sorted_nodes"].parse_args(args)

    if parsed.lower_bound != -1 and parsed.upper_bound != -1:
      index_range = [
          max(0, parsed.lower_bound),
          min(len(self._sorted_nodes), parsed.upper_bound)
      ]
      verbose = False
    else:
      index_range = [0, len(self._sorted_nodes)]
      verbose = True

    handle_node_names = self._node_stepper.handle_node_names()
    override_names = self._node_stepper.override_names()
    dirty_variable_names = [
        dirty_variable.split(":")[0]
        for dirty_variable in self._node_stepper.dirty_variables()
    ]

    lines = []
    font_attr_segs = {}
    if verbose:
      lines.extend(
          ["Topologically-sorted transitive input(s) and fetch(es):", ""])

    line_counter = len(lines)
    for i, element_name in enumerate(self._sorted_nodes):
      if i < index_range[0] or i >= index_range[1]:
        continue

      font_attr_segs[line_counter] = []

      # TODO(cais): Use fixed-width text to show node index.
      node_prefix = "(%d / %d)" % (i + 1, len(self._sorted_nodes))
      if i == self._next:
        node_prefix = " " + self.NEXT_NODE_POINTER_STR + node_prefix
        font_attr_segs[line_counter].append((0, 3, "bold"))
      else:
        node_prefix = " " + node_prefix

      node_prefix += " ["
      labels, label_font_attr_segs = self._get_status_labels(
          element_name,
          handle_node_names,
          override_names,
          dirty_variable_names,
          len(node_prefix))
      node_prefix += labels
      font_attr_segs[line_counter].extend(label_font_attr_segs)

      lines.append(node_prefix + "] " + element_name)
      line_counter += 1

    output = debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)

    if verbose:
      output.extend(self._node_status_label_legend())

    return output

  def _get_status_labels(self,
                         element_name,
                         handle_node_names,
                         override_names,
                         dirty_variable_names,
                         offset):
    """Get a string of status labels for a graph element.

    A status label indicates that a node has a certain state in this
    node-stepper CLI invocation. For example, 1) that the node has been
    continued-to and a handle to its output tensor is available to the node
    stepper; 2) the node is a Variable and its value has been altered, e.g.,
    by continuing to a variable-updating node, since the beginning of this
    node-stepper invocation (i.e., "dirty variable").

    Args:
      element_name: (str) name of the graph element.
      handle_node_names: (list of str) Names of the nodes of which the output
        tensors' handles are available.
      override_names: (list of str) Names of the tensors of which the values
        are overridden.
      dirty_variable_names: (list of str) Names of the dirty variables.
      offset: (int) Initial offset of the font attribute segments.

    Returns:
      (str) The string made of status labels that currently apply to the graph
        element.
      (list of tuples) The font attribute segments, with offset applied.
    """

    stat_string = ""
    font_attr_segs = []
    position = offset

    node_name = element_name.split(":")[0]
    if node_name in self._placeholders:
      stat_string += "P"
      font_attr_segs.append((position, position + 1, "cyan"))
    else:
      stat_string += " "
    position += 1

    if self._node_stepper.is_feedable(str(element_name)):
      stat_string += " "
    else:
      stat_string += "U"
      font_attr_segs.append((position, position + 1, "red"))
    position += 1

    if element_name in handle_node_names:
      stat_string += "H"
      font_attr_segs.append((position, position + 1, "green"))
    else:
      stat_string += " "
    position += 1

    slots = self._node_stepper.output_slots_in_closure(element_name)
    has_override = False
    for slot in slots:
      if element_name + ":%d" % slot in override_names:
        has_override = True
        break

    if has_override:
      stat_string += "O"
      font_attr_segs.append((position, position + 1, "yellow"))
    else:
      stat_string += " "
    position += 1

    if element_name in dirty_variable_names:
      stat_string += self.STATE_DIRTY_VARIABLE
      font_attr_segs.append((position, position + 1, "magenta"))
    else:
      stat_string += " "
    position += 1

    return stat_string, font_attr_segs

  def _node_status_label_legend(self):
    """Get legend for node-status labels.

    Returns:
      (debugger_cli_common.RichTextLines) Legend text.
    """

    lines = []
    font_attr_segs = {}

    line_counter = 0
    lines.append("")
    lines.append("Legend:")
    line_counter += 2

    lines.append(" P - Placeholder")
    font_attr_segs[line_counter] = [(2, 3, "cyan")]
    line_counter += 1

    lines.append(" U - Unfeedable")
    font_attr_segs[line_counter] = [(2, 3, "red")]
    line_counter += 1

    lines.append(
        " H - Already continued-to; Tensor handle available from output "
        "slot(s)"
    )
    font_attr_segs[line_counter] = [(2, 3, "green")]
    line_counter += 1

    lines.append(" O - Has overriding (injected) tensor value")
    font_attr_segs[line_counter] = [(2, 3, "yellow")]
    line_counter += 1

    lines.append(
        " D - Dirty variable: Variable already updated this node stepper.")
    font_attr_segs[line_counter] = [(2, 3, "magenta")]
    line_counter += 1

    return debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)

  def cont(self, args, screen_info=None):
    """Continue-to action on the graph."""

    _ = screen_info

    parsed = self.arg_parsers["cont"].parse_args(args)

    # Determine which node is being continued to, so the _next pointer can be
    # set properly.
    node_name = parsed.target_name.split(":")[0]
    if node_name not in self._sorted_nodes:
      return cli_shared.error(self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] %
                              parsed.target_name)
    self._next = self._sorted_nodes.index(node_name)

    cont_result = self._node_stepper.cont(
        parsed.target_name,
        restore_variable_values=parsed.restore_variable_values)
    self._completed_nodes.add(parsed.target_name.split(":")[0])

    feed_types = self._node_stepper.last_feed_types()

    lines = ["Continued to %s:" % parsed.target_name, ""]
    font_attr_segs = {}
    lines.append("Stepper used feeds:")
    line_counter = len(lines)

    if feed_types:
      for feed_name in feed_types:
        feed_info_line = " %s : %s" % (feed_name, feed_types[feed_name])
        lines.append(feed_info_line)
        if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE:
          font_attr_segs[line_counter] = [
              (len(feed_name) + 2, len(feed_info_line), "green")
          ]
        elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE:
          font_attr_segs[line_counter] = [
              (len(feed_name) + 2, len(feed_info_line), "yellow")
          ]
        line_counter += 1
    else:
      lines.append(" (No feeds)")
    lines.append("")

    screen_output = debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)

    tensor_output = tensor_format.format_tensor(
        cont_result, parsed.target_name,
        include_metadata=True)
    screen_output.extend(tensor_output)

    # Generate windowed view of the sorted transitive closure on which the
    # stepping is occurring.
    lower_bound = max(0, self._next - 2)
    upper_bound = min(len(self._sorted_nodes), self._next + 3)

    final_output = self.list_sorted_nodes(
        ["-l", str(lower_bound), "-u", str(upper_bound)])
    final_output.extend(debugger_cli_common.RichTextLines([""]))
    final_output.extend(screen_output)

    # Re-calculate the target of the next "step" action.
    self._calculate_next()

    return final_output

  def step(self, args, screen_info=None):
    """Step once.

    Args:
      args: (list of str) command-line arguments for the "step" command.
      screen_info: Information about screen.

    Returns:
      (RichTextLines) Screen output for the result of the stepping action.
    """

    parsed = self.arg_parsers["step"].parse_args(args)

    if parsed.num_times < 0:
      return debugger_cli_common.RichTextLines(
          "ERROR: Invalid number of times to step: %d" % parsed.num_times)

    for _ in xrange(parsed.num_times):
      if self._next >= len(self._sorted_nodes):
        return debugger_cli_common.RichTextLines(
            "ERROR: Cannot step any further because the end of the sorted "
            "transitive closure has been reached.")
      else:
        screen_output = self.cont([self._sorted_nodes[self._next]], screen_info)

    return screen_output

  def print_tensor(self, args, screen_info=None):
    """Print the value of a tensor that the stepper has access to."""

    parsed = self.arg_parsers["print_tensor"].parse_args(args)

    if screen_info and "cols" in screen_info:
      np_printoptions = {"linewidth": screen_info["cols"]}
    else:
      np_printoptions = {}

    # Determine if any range-highlighting is required.
    highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)

    tensor_name, tensor_slicing = (
        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

    tensor_names = self._resolve_tensor_names(tensor_name)
    if not tensor_names:
      return cli_shared.error(
          self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] % tensor_name)
    elif len(tensor_names) > 1:
      return cli_shared.error(
          self._MESSAGE_TEMPLATES["MULTIPLE_TENSORS"] % tensor_name)
    else:
      tensor_name = tensor_names[0]

    try:
      tensor_value = self._node_stepper.get_tensor_value(tensor_name)
    except ValueError as e:
      return debugger_cli_common.RichTextLines([str(e)])

    return cli_shared.format_tensor(
        tensor_value,
        tensor_name,
        np_printoptions,
        print_all=parsed.print_all,
        tensor_slicing=tensor_slicing,
        highlight_options=highlight_options)

  def inject_value(self, args, screen_info=None):
    """Inject value to a given tensor.

    Args:
      args: (list of str) command-line arguments for the "step" command.
      screen_info: Information about screen.

    Returns:
      (RichTextLines) Screen output for the result of the stepping action.
    """

    _ = screen_info  # Currently unused.

    if screen_info and "cols" in screen_info:
      np_printoptions = {"linewidth": screen_info["cols"]}
    else:
      np_printoptions = {}

    parsed = self.arg_parsers["inject_value"].parse_args(args)

    tensor_names = self._resolve_tensor_names(parsed.tensor_name)
    if not tensor_names:
      return cli_shared.error(
          self._MESSAGE_TEMPLATES["NOT_IN_CLOSURE"] % parsed.tensor_name)
    elif len(tensor_names) > 1:
      return cli_shared.error(
          self._MESSAGE_TEMPLATES["MULTIPLE_TENSORS"] % parsed.tensor_name)
    else:
      tensor_name = tensor_names[0]

    tensor_value = eval(parsed.tensor_value_str)  # pylint: disable=eval-used

    try:
      self._node_stepper.override_tensor(tensor_name, tensor_value)
      lines = [
          "Injected value \"%s\"" % parsed.tensor_value_str,
          " to tensor \"%s\":" % tensor_name, ""
      ]

      tensor_lines = tensor_format.format_tensor(
          tensor_value,
          tensor_name,
          include_metadata=True,
          np_printoptions=np_printoptions).lines
      lines.extend(tensor_lines)

    except ValueError:
      lines = [
          "ERROR: Failed to inject value to tensor %s" % parsed.tensor_name
      ]

    return debugger_cli_common.RichTextLines(lines)

  # TODO(cais): Implement list_inputs
  # TODO(cais): Implement list_outputs
  # TODO(cais): Implement node_info

  def _resolve_tensor_names(self, element_name):
    """Resolve tensor name from graph element name.

    Args:
      element_name: (str) Name of the graph element to resolve.

    Returns:
      (list) Name of the tensor(s). If element_name is the name of a tensor in
      the transitive closure, return [element_name]. If element_name is the
      name of a node in the transitive closure, return the list of output
      tensors from the node that are in the transitive closure. Otherwise,
      return empty list.
    """

    if element_name in self._closure_elements and ":" in element_name:
      return [element_name]
    if (element_name in self._sorted_nodes or
        (element_name in self._closure_elements and ":" not in element_name)):
      slots = self._node_stepper.output_slots_in_closure(element_name)
      return [(element_name + ":%d" % slot) for slot in slots]
    else:
      return []
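For a sense of how this backend is driven programmatically (mirroring the test below), here is a minimal sketch; the tiny graph is made up, and the TF 0.x op names (tf.add, tf.Session) follow the era of this commit.

import tensorflow as tf

from tensorflow.python.debug import stepper
from tensorflow.python.debug.cli import stepper_cli

a = tf.Variable(10.0, name="a")
b = tf.Variable(20.0, name="b")
c = tf.add(a, b, name="c")

sess = tf.Session()
sess.run(a.initializer)
sess.run(b.initializer)

# Wrap a NodeStepper targeting fetch c in the CLI backend.
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(sess, c))

# Each command handler takes argv-style arguments and returns RichTextLines.
print("\n".join(cli.list_sorted_nodes([]).lines))
print("\n".join(cli.cont(["c"]).lines))          # continue to node c
print("\n".join(cli.print_tensor(["c:0"]).lines))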
tensorflow/python/debug/cli/stepper_cli_test.py (new file, 444 lines; listing truncated below where the source cuts off)
@@ -0,0 +1,444 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests of the Stepper CLI Backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.python.debug import stepper
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest

# Regex pattern for a node line in the stepper CLI output.
NODE_LINE_PATTERN = re.compile(r".*\(.*\).*\[.*\].*")


def _parse_sorted_nodes_list(lines):
  """Parsed a list of lines to extract the node list.

  Args:
    lines: (list of str) Lines from which the node list and associated
      information will be extracted.

  Returns:
    (list of str) The list of node names.
    (list of str) The list of status labels.
    (int) 0-based index among the nodes for the node pointed by the next-node
      pointer. If no such node exists, -1.
  """

  node_names = []
  status_labels = []
  node_pointer = -1

  node_line_counter = 0
  for line in lines:
    if NODE_LINE_PATTERN.match(line):
      node_names.append(line.split(" ")[-1])

      idx_left_bracket = line.index("[")
      idx_right_bracket = line.index("]")
      status_labels.append(line[idx_left_bracket + 1:idx_right_bracket])
      if line.strip().startswith(
          stepper_cli.NodeStepperCLI.NEXT_NODE_POINTER_STR):
        node_pointer = node_line_counter

      node_line_counter += 1

  return node_names, status_labels, node_pointer


def _parsed_used_feeds(lines):
  feed_types = {}

  begin_line = -1
  for i, line in enumerate(lines):
    if line.startswith("Stepper used feeds:"):
      begin_line = i + 1
      break

  if begin_line == -1:
    return feed_types

  for line in lines[begin_line:]:
    line = line.strip()
    if not line:
      return feed_types
    else:
      feed_name = line.split(" : ")[0].strip()
      feed_type = line.split(" : ")[1].strip()
      feed_types[feed_name] = feed_type


class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self.a = tf.Variable(10.0, name="a")
    self.b = tf.Variable(20.0, name="b")

    self.c = tf.add(self.a, self.b, name="c")  # Should be 30.0.
    self.d = tf.sub(self.a, self.c, name="d")  # Should be -20.0.
    self.e = tf.mul(self.c, self.d, name="e")  # Should be -600.0.

    self.ph = tf.placeholder(tf.float32, shape=(2, 2), name="ph")
    self.f = tf.mul(self.e, self.ph, name="f")

    self.opt = tf.train.GradientDescentOptimizer(0.1).minimize(
        self.e, name="opt")

    self.sess = tf.Session()

    self.sess.run(self.a.initializer)
    self.sess.run(self.b.initializer)

  def tearDown(self):
    tf.reset_default_graph()

  def _assert_nodes_topologically_sorted_with_target_e(self, node_names):
    """Check the topologically sorted order of the node names."""

    self.assertGreaterEqual(len(node_names), 7)
    self.assertLess(node_names.index("a"), node_names.index("a/read"))
    self.assertLess(node_names.index("b"), node_names.index("b/read"))
    self.assertLess(node_names.index("a/read"), node_names.index("c"))
    self.assertLess(node_names.index("b/read"), node_names.index("c"))
    self.assertLess(node_names.index("a/read"), node_names.index("d"))
    self.assertLess(node_names.index("c"), node_names.index("d"))
    self.assertLess(node_names.index("c"), node_names.index("e"))
    self.assertLess(node_names.index("d"), node_names.index("e"))

  def _assert_nodes_topologically_sorted_with_target_f(self, node_names):
    self._assert_nodes_topologically_sorted_with_target_e(node_names)

    self.assertGreaterEqual(len(node_names), 9)
    self.assertLess(node_names.index("ph"), node_names.index("f"))
    self.assertLess(node_names.index("e"), node_names.index("f"))

  def testListingSortedNodesPresentsTransitveClosure(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self._assert_nodes_topologically_sorted_with_target_e(node_names)
    self.assertEqual(len(node_names), len(stat_labels))
    for stat_label in stat_labels:
      self.assertEqual(" ", stat_label)
    self.assertEqual(0, node_pointer)

  def testListingSortedNodesLabelsPlaceholders(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self._assert_nodes_topologically_sorted_with_target_f(node_names)

    index_ph = node_names.index("ph")
    self.assertEqual(len(node_names), len(stat_labels))
    for i in xrange(len(stat_labels)):
      if index_ph == i:
        self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
                      stat_labels[i])
      else:
        self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
                         stat_labels[i])

    self.assertEqual(0, node_pointer)

  def testContToNonexistentNodeShouldError(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))

    output = cli.cont(["foobar"])
    self.assertEqual(
        ["ERROR: foobar is not in the transitive closure of this stepper "
         "instance."], output.lines)

  def testContToNodeOutsideTransitiveClosureShouldError(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.cont(["f"])
    self.assertEqual(
        ["ERROR: f is not in the transitive closure of this stepper "
         "instance."], output.lines)

  def testContToValidNodeShouldUpdateStatus(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    index_c = node_names.index("c")
    self.assertEqual(" ", stat_labels[index_c])
    self.assertEqual(0, node_pointer)

    output = cli.cont("c")
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("c", node_names)
    index_c = node_names.index("c")
    self.assertEqual(index_c, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

    output = cli.cont("d")
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    used_feed_types = _parsed_used_feeds(output.lines)
    self.assertEqual({"c:0": "handle"}, used_feed_types)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("d", node_names)
    index_d = node_names.index("d")
    self.assertEqual(index_d, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])

  def testSteppingOneStepAtATimeShouldUpdateStatus(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
    self.assertEqual(0, node_pointer)

    for i in xrange(len(orig_node_names)):
|
||||||
|
output = cli.step([])
|
||||||
|
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
|
||||||
|
next_node_name = node_names[node_pointer]
|
||||||
|
self.assertEqual(orig_node_names[i], next_node_name)
|
||||||
|
|
||||||
|
self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
|
||||||
|
stat_labels[node_pointer])
|
||||||
|
|
||||||
|
# The order in which the nodes are listed should not change as the
|
||||||
|
# stepping happens.
|
||||||
|
output = cli.list_sorted_nodes([])
|
||||||
|
node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
|
||||||
|
self.assertEqual(orig_node_names, node_names)
|
||||||
|
|
||||||
|
if i < len(orig_node_names) - 1:
|
||||||
|
self.assertEqual(i + 1, node_pointer)
|
||||||
|
else:
|
||||||
|
# Stepped over the limit. Pointer should be at -1.
|
||||||
|
self.assertEqual(-1, node_pointer)
|
||||||
|
|
||||||
|
# Attempt to step once more after the end has been reached should error out.
|
||||||
|
output = cli.step([])
|
||||||
|
self.assertEqual([
|
||||||
|
"ERROR: Cannot step any further because the end of the sorted "
|
||||||
|
"transitive closure has been reached."
|
||||||
|
], output.lines)
|
||||||
|
|
||||||
|
def testSteppingMultipleStepsUpdatesStatus(self):
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
|
||||||
|
|
||||||
|
output = cli.list_sorted_nodes([])
|
||||||
|
orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines)
|
||||||
|
|
||||||
|
output = cli.step(["-t", "3"])
|
||||||
|
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
|
||||||
|
self.assertEqual(orig_node_names[2], node_names[node_pointer])
|
||||||
|
|
||||||
|
for i in xrange(node_pointer):
|
||||||
|
self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
|
||||||
|
|
||||||
|
for i in xrange(node_pointer + 1, len(stat_labels)):
|
||||||
|
self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
|
||||||
|
|
||||||
|
def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
|
||||||
|
node_stepper = stepper.NodeStepper(self.sess, self.opt)
|
||||||
|
|
||||||
|
sorted_nodes = node_stepper.sorted_nodes()
|
||||||
|
closure_elements = node_stepper.closure_elements()
|
||||||
|
|
||||||
|
# Find a node which is in the list of sorted nodes, but whose output tensor
|
||||||
|
# is not in the transitive closure.
|
||||||
|
no_output_node = None
|
||||||
|
for node in sorted_nodes:
|
||||||
|
if (node + ":0" not in closure_elements and
|
||||||
|
node + ":1" not in closure_elements):
|
||||||
|
no_output_node = node
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertIsNotNone(no_output_node)
|
||||||
|
|
||||||
|
cli = stepper_cli.NodeStepperCLI(node_stepper)
|
||||||
|
output = cli.cont([no_output_node])
|
||||||
|
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
|
||||||
|
self.assertEqual(no_output_node, node_names[node_pointer])
|
||||||
|
self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
|
||||||
|
stat_labels[node_pointer])
|
||||||
|
|
||||||
|
def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
|
||||||
|
output = cli.cont(["opt/update_b/ApplyGradientDescent"])
|
||||||
|
|
||||||
|
output = cli.list_sorted_nodes([])
|
||||||
|
node_names, stat_labels, _ = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
|
||||||
|
stat_labels[node_names.index("b")])
|
||||||
|
self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
|
||||||
|
stat_labels[node_names.index("a")])
|
||||||
|
|
||||||
|
def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
|
||||||
|
output = cli.cont(["opt/update_a/ApplyGradientDescent"])
|
||||||
|
|
||||||
|
# After cont() call on .../update_a/..., Variable a should have been marked
|
||||||
|
# as dirty, whereas b should not have.
|
||||||
|
output = cli.list_sorted_nodes([])
|
||||||
|
node_names, stat_labels, _ = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
|
||||||
|
stat_labels[node_names.index("a")])
|
||||||
|
self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
|
||||||
|
stat_labels[node_names.index("b")])
|
||||||
|
|
||||||
|
output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])
|
||||||
|
|
||||||
|
# After cont() call on .../update_b/... with the -r flag, Variable b should
|
||||||
|
# have been marked as dirty, whereas Variable a should not be because it
|
||||||
|
# should have been restored.
|
||||||
|
output = cli.list_sorted_nodes([])
|
||||||
|
node_names, stat_labels, _ = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
|
||||||
|
stat_labels[node_names.index("b")])
|
||||||
|
self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
|
||||||
|
stat_labels[node_names.index("a")])
|
||||||
|
|
||||||
|
def testPrintTensorShouldWorkWithTensorName(self):
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
|
||||||
|
|
||||||
|
cli.cont("d")
|
||||||
|
output = cli.print_tensor(["d:0"])
|
||||||
|
|
||||||
|
self.assertEqual("Tensor \"d:0\":", output.lines[0])
|
||||||
|
self.assertEqual("-20.0", output.lines[-1])
|
||||||
|
|
||||||
|
def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
|
||||||
|
|
||||||
|
cli.cont("d")
|
||||||
|
output = cli.print_tensor(["d"])
|
||||||
|
|
||||||
|
self.assertEqual("Tensor \"d:0\":", output.lines[0])
|
||||||
|
self.assertEqual("-20.0", output.lines[-1])
|
||||||
|
|
||||||
|
def testPrintTensorShouldWorkSlicingString(self):
|
||||||
|
ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(
|
||||||
|
self.sess,
|
||||||
|
self.f,
|
||||||
|
feed_dict={self.ph: ph_value}))
|
||||||
|
|
||||||
|
output = cli.print_tensor(["ph:0[:, 1]"])
|
||||||
|
self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
|
||||||
|
self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
|
||||||
|
|
||||||
|
output = cli.print_tensor(["ph[:, 1]"])
|
||||||
|
self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
|
||||||
|
self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
|
||||||
|
|
||||||
|
def testPrintTensorWithNonexistentTensorShouldError(self):
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
|
||||||
|
|
||||||
|
output = cli.print_tensor(["foobar"])
|
||||||
|
self.assertEqual(
|
||||||
|
["ERROR: foobar is not in the transitive closure of this stepper "
|
||||||
|
"instance."], output.lines)
|
||||||
|
|
||||||
|
def testPrintTensorWithNoHandleShouldError(self):
|
||||||
|
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
|
||||||
|
|
||||||
|
output = cli.print_tensor("e")
|
||||||
|
self.assertEqual(
|
||||||
|
["This stepper instance does not have access to the value of tensor "
|
||||||
|
"\"e:0\""], output.lines)
|
||||||
|
|
||||||
|
def testInjectTensorValueByTensorNameShouldBeReflected(self):
|
||||||
|
node_stepper = stepper.NodeStepper(self.sess, self.e)
|
||||||
|
cli = stepper_cli.NodeStepperCLI(node_stepper)
|
||||||
|
|
||||||
|
output = cli.cont(["d"])
|
||||||
|
node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
|
||||||
|
self.assertEqual("d", node_names[node_pointer])
|
||||||
|
|
||||||
|
output = cli.list_sorted_nodes([])
|
||||||
|
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
|
||||||
|
index_d = node_names.index("d")
|
||||||
|
self.assertIn(
|
||||||
|
stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
|
||||||
|
self.assertNotIn(
|
||||||
|
stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, stat_labels[index_d])
|
||||||
|
|
||||||
|
self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))
|
||||||
|
|
||||||
|
output = cli.inject_value(["d:0", "20.0"])
|
||||||
|
|
||||||
|
# Verify that the override is available.
|
||||||
|
self.assertEqual(["d:0"], node_stepper.override_names())
|
||||||
|
|
||||||
|
# Verify that the list of sorted nodes reflects the existence of the value
|
||||||
|
# override (i.e., injection).
|
||||||
|
output = cli.list_sorted_nodes([])
|
||||||
|
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
|
||||||
|
output.lines)
|
||||||
|
|
||||||
|
index_d = node_names.index("d")
|
||||||
|
self.assertNotIn(
|
||||||
|
stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
|
||||||
|
self.assertIn(
|
||||||
|
stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, stat_labels[index_d])
|
||||||
|
|
||||||
|
def testInjectTensorValueByNodeNameShouldBeReflected(self):
|
||||||
|
node_stepper = stepper.NodeStepper(self.sess, self.e)
|
||||||
|
cli = stepper_cli.NodeStepperCLI(node_stepper)
|
||||||
|
|
||||||
|
cli.inject_value(["d", "20.0"])
|
||||||
|
self.assertEqual(["d:0"], node_stepper.override_names())
|
||||||
|
|
||||||
|
def testInjectToNonexistentTensorShouldError(self):
|
||||||
|
node_stepper = stepper.NodeStepper(self.sess, self.e)
|
||||||
|
cli = stepper_cli.NodeStepperCLI(node_stepper)
|
||||||
|
|
||||||
|
output = cli.inject_value(["foobar:0", "20.0"])
|
||||||
|
self.assertEqual(
|
||||||
|
["ERROR: foobar:0 is not in the transitive closure of this stepper "
|
||||||
|
"instance."], output.lines)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
googletest.main()
|
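For orientation, a minimal sketch of how the NodeStepperCLI calls exercised by these tests fit together outside the test harness; the toy graph mirrors setUp() above and the snippet assumes the TF-0.12-era ops (tf.sub, tf.mul) used throughout this change:

import tensorflow as tf

from tensorflow.python.debug import stepper
from tensorflow.python.debug.cli import stepper_cli

# Same small graph as NodeStepperSimpleGraphTest.setUp().
a = tf.Variable(10.0, name="a")
b = tf.Variable(20.0, name="b")
c = tf.add(a, b, name="c")
d = tf.sub(a, c, name="d")
e = tf.mul(c, d, name="e")

sess = tf.Session()
sess.run(a.initializer)
sess.run(b.initializer)

# Drive the stepper CLI programmatically, the way the tests above do.
cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(sess, e))
print(cli.list_sorted_nodes([]).lines)  # Topologically sorted closure.
cli.cont(["d"])                         # Continue to node "d".
print(cli.print_tensor(["d:0"]).lines)  # Last line should be "-20.0".
cli.inject_value(["d:0", "20.0"])       # Override the value of d:0.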
@ -19,6 +19,8 @@ from __future__ import print_function
 
 import copy
 
+import six
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.debug import debug_data
 from tensorflow.python.framework import ops
@ -174,7 +176,7 @@ class NodeStepper(object):
     fetch_names = []
     fetch_list = []
     for fetch in flattened_fetches:
-      if isinstance(fetch, str):
+      if isinstance(fetch, six.string_types):
         fetch_names.append(fetch)
         fetch_list.append(self._sess.graph.as_graph_element(fetch))
       else:
@ -345,7 +347,7 @@ class NodeStepper(object):
       (bool) whether the graph element is feedable.
     """
 
-    if not isinstance(name, str):
+    if not isinstance(name, six.string_types):
      raise TypeError("Expected type str; got type %s" % type(name))
 
    elem = self._sess.graph.as_graph_element(name)
@ -363,7 +365,7 @@ class NodeStepper(object):
         tree to the fetched graph element of this stepper instance.
     """
 
-    if not isinstance(tensor_name, str):
+    if not isinstance(tensor_name, six.string_types):
      raise TypeError("Expected type str; got type %s" % type(tensor_name))
 
    node_name = self._get_node_name(tensor_name)
@ -442,7 +444,7 @@ class NodeStepper(object):
     # The feeds to be used in the Session.run() call.
     feeds = {}
 
-    if isinstance(target, str):
+    if isinstance(target, six.string_types):
       # Fetch target is a string. Assume it is the name of the Tensor or Op and
       # will attempt to find it in the Session's graph.
       target_name = target
@ -703,16 +705,23 @@ class NodeStepper(object):
       The same return value as self.cont() as called on the final fetch.
     """
 
-    # Restore variable to their previous values.
-    for var_name in self._cached_variable_values:
+    self.restore_variable_values()
+    return self._sess.run(self._fetches, feed_dict=self._client_feed_dict)
+
+  def restore_variable_values(self):
+    """Restore variables to the initial values.
+
+    "Initial value" refers to the value when this NodeStepper instance was
+    first constructed.
+    """
+
+    for var_name in self._dirty_variables:
       self._sess.run(self._variable_initializers[var_name],
                      feed_dict={
                          self._variable_initial_values[var_name]:
                              self._cached_variable_values[var_name]
                      })
 
-    return self._sess.run(self._fetches, feed_dict=self._client_feed_dict)
-
   def handle_names(self):
     """Return names of the TensorHandles that the debugger is holding.
 
@ -800,7 +809,11 @@ class NodeStepper(object):
         or through a TensorHandle.
     """
 
-    if tensor_name in self._override_tensors:
+    if self.is_placeholder(tensor_name):
+      if ":" not in tensor_name:
+        tensor_name += ":0"
+      return self._client_feed_dict[tensor_name]
+    elif tensor_name in self._override_tensors:
       return self._override_tensors[tensor_name]
     elif tensor_name in self._tensor_handles:
       return self._tensor_handles[tensor_name].eval()
@ -412,6 +412,23 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
   def tearDown(self):
     tf.reset_default_graph()
 
+  def testGetTensorValueWorksOnPlaceholder(self):
+    stepper = NodeStepper(
+        self.sess,
+        self.y,
+        feed_dict={
+            self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
+            self.ph1: [[-1.0], [0.5]]
+        })
+
+    self.assertAllClose(
+        [[1.0, 2.0], [-3.0, 5.0]], stepper.get_tensor_value("ph0"))
+    self.assertAllClose(
+        [[1.0, 2.0], [-3.0, 5.0]], stepper.get_tensor_value("ph0:0"))
+    with self.assertRaisesRegexp(
+        KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"):
+      stepper.get_tensor_value("ph0:1")
+
   def testIsPlaceholdersShouldGiveCorrectAnswers(self):
     stepper = NodeStepper(self.sess, self.y)
 
@ -694,6 +711,18 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
     self.assertAllClose(1.84, self.sess.run(self.b))
     self.assertAllClose(4.0, self.sess.run(self.c))
 
+  def testRestoreVariableValues(self):
+    """Test restore_variable_values() restores the old values of variables."""
+
+    stepper = NodeStepper(self.sess, "optim")
+
+    stepper.cont("optim/update_b/ApplyGradientDescent",
+                 restore_variable_values=True)
+    self.assertAllClose(1.84, self.sess.run(self.b))
+
+    stepper.restore_variable_values()
+    self.assertAllClose(2.0, self.sess.run(self.b))
+
   def testFinalize(self):
     """Test finalize() to restore variables and run the original fetch."""
 
@ -116,6 +116,7 @@ import abc
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.debug import debug_utils
+from tensorflow.python.debug import stepper
 from tensorflow.python.framework import errors
 
 
@ -155,9 +156,9 @@ class OnSessionInitRequest(object):
 class OnSessionInitAction(object):
   """Enum-like values for possible action to take on session init."""
 
-  # Proceed, without special actions, in the wrapper session initializaton. What
-  # action the wrapper session performs next is determined by the caller of the
-  # wrapper session. E.g., it can call run().
+  # Proceed, without special actions, in the wrapper session initialization.
+  # What action the wrapper session performs next is determined by the caller
+  # of the wrapper session. E.g., it can call run().
   PROCEED = "proceed"
 
   # Instead of letting the caller of the wrapper session determine what actions
@ -397,7 +398,13 @@ class BaseDebugWrapperSession(session.SessionInterface):
           client_graph_def=self._sess.graph.as_graph_def(),
           tf_error=tf_error)
 
-    elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN:
+    elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
+          run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
+      if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
+        retvals = self.invoke_node_stepper(
+            stepper.NodeStepper(self._sess, fetches, feed_dict),
+            restore_variable_values_on_exit=True)
+
       # Invoke run() method of the wrapped session.
       retvals = self._sess.run(
           fetches,
@ -407,10 +414,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
 
       # Prepare arg for the on-run-end callback.
       run_end_req = OnRunEndRequest(run_start_resp.action)
-    elif run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
-      # TODO(cais): Implement stepper loop.
-      raise NotImplementedError(
-          "OnRunStartAction INVOKE_STEPPER has not been implemented.")
     else:
       raise ValueError(
           "Invalid OnRunStartAction value: %s" % run_start_resp.action)
@ -461,7 +464,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
     Returns:
       An instance of OnSessionInitResponse.
     """
-    pass
 
   @abc.abstractmethod
   def on_run_start(self, request):
@ -482,7 +484,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
         with or without debug tensor watching, invoking the stepper.)
       2) debug URLs used to watch the tensors.
     """
-    pass
 
   @abc.abstractmethod
   def on_run_end(self, request):
@ -499,7 +500,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
     Returns:
       An instance of OnRunStartResponse.
     """
-    pass
 
   def __enter__(self):
     return self._sess.__enter__()
@ -512,3 +512,21 @@ class BaseDebugWrapperSession(session.SessionInterface):
 
   # TODO(cais): Add _node_name_regex_whitelist and
   # _node_op_type_regex_whitelist.
+
+  @abc.abstractmethod
+  def invoke_node_stepper(self,
+                          node_stepper,
+                          restore_variable_values_on_exit=True):
+    """Callback invoked when the client intends to step through graph nodes.
+
+    Args:
+      node_stepper: (stepper.NodeStepper) An instance of NodeStepper to be used
+        in this stepping session.
+      restore_variable_values_on_exit: (bool) Whether any variables whose values
+        have been altered during this node-stepper invocation should be restored
+        to their old values when this invocation ends.
+
+    Returns:
+      The same return values as the `Session.run()` call on the same fetches as
+      the NodeStepper.
+    """
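To make the contract of the new abstract invoke_node_stepper() concrete, here is a rough, non-interactive sketch that follows the documented Args/Returns; run_stepper_non_interactively is a hypothetical helper, not part of this change (the interactive implementation lands in LocalCLIDebugWrapperSession further below):

def run_stepper_non_interactively(node_stepper,
                                  restore_variable_values_on_exit=True):
  """Hypothetical, non-interactive analog of invoke_node_stepper()."""
  retvals = None
  for node_name in node_stepper.sorted_nodes():
    # Continue to each node of the transitive closure in topological order.
    retvals = node_stepper.cont(node_name)
  if restore_variable_values_on_exit:
    # Undo any Variable mutations caused by the stepping above.
    node_stepper.restore_variable_values()
  return retvals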
@ -20,6 +20,7 @@ from __future__ import print_function
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.debug import debug_utils
+from tensorflow.python.debug import stepper
 from tensorflow.python.debug.wrappers import framework
 from tensorflow.python.debug.wrappers import local_cli_wrapper
 from tensorflow.python.training import session_run_hook
@ -64,19 +65,25 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
       self._decorate_options_for_debug(run_args.options,
                                        run_context.session.graph)
     elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
-      raise NotImplementedError(
-          "OnRunStartAction INVOKE_STEPPER has not been implemented.")
+      # The _finalized property must be set to False so that the NodeStepper
+      # can insert ops for retrieving TensorHandles.
+      # pylint: disable=protected-access
+      run_context.session.graph._finalized = False
+      # pylint: enable=protected-access
+
+      self.invoke_node_stepper(
+          stepper.NodeStepper(run_context.session, run_context.original_args.
+                              fetches, run_context.original_args.feed_dict),
+          restore_variable_values_on_exit=True)
 
     return run_args
 
   def after_run(self, run_context, run_values):
     # Adapt run_context and run_values to OnRunEndRequest and invoke superclass
     # on_run_end()
-    if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
-      on_run_end_request = framework.OnRunEndRequest(self._performed_action,
-                                                     run_values.run_metadata)
-
-      self.on_run_end(on_run_end_request)
+    on_run_end_request = framework.OnRunEndRequest(self._performed_action,
+                                                   run_values.run_metadata)
+    self.on_run_end(on_run_end_request)
 
   def _decorate_options_for_debug(self, options, graph):
     """Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging.
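For context on where the modified before_run() branch fires, a sketch of attaching the hook to a monitored session; the toy model, train_op, and the import path used for LocalCLIDebugHook are assumptions for illustration only:

import tensorflow as tf

from tensorflow.python.debug.wrappers.hooks import LocalCLIDebugHook

v = tf.Variable(1.0, name="v")
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(tf.square(v))

# Choosing the stepper at the run-start CLI routes execution through the
# OnRunStartAction.INVOKE_STEPPER branch added above.
with tf.train.MonitoredSession(hooks=[LocalCLIDebugHook()]) as mon_sess:
  mon_sess.run(train_op)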
@ -29,6 +29,7 @@ from tensorflow.python.debug.cli import analyzer_cli
 from tensorflow.python.debug.cli import cli_shared
 from tensorflow.python.debug.cli import curses_ui
 from tensorflow.python.debug.cli import debugger_cli_common
+from tensorflow.python.debug.cli import stepper_cli
 from tensorflow.python.debug.wrappers import framework
 
 
@ -497,3 +498,70 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
         fetches,
         feed_dict,
         self._tensor_filters)
+
+  def invoke_node_stepper(self,
+                          node_stepper,
+                          restore_variable_values_on_exit=True):
+    """Overrides method in base class to implement interactive node stepper.
+
+    Args:
+      node_stepper: (stepper.NodeStepper) The underlying NodeStepper API object.
+      restore_variable_values_on_exit: (bool) Whether any variables whose values
+        have been altered during this node-stepper invocation should be restored
+        to their old values when this invocation ends.
+
+    Returns:
+      The same return values as the `Session.run()` call on the same fetches as
+      the NodeStepper.
+    """
+
+    stepper = stepper_cli.NodeStepperCLI(node_stepper)
+
+    # On exiting the node-stepper CLI, the finalize method of the node_stepper
+    # object will be called, ensuring that the state of the graph will be the
+    # same as if the stepping did not happen.
+    # TODO(cais): Perhaps some users will want the effect of the interactive
+    # stepping and value injection to persist. When that happens, make the call
+    # to finalize optional.
+    stepper_ui = curses_ui.CursesUI(
+        on_ui_exit=(node_stepper.restore_variable_values
+                    if restore_variable_values_on_exit else None))
+
+    stepper_ui.register_command_handler(
+        "list_sorted_nodes",
+        stepper.list_sorted_nodes,
+        stepper.arg_parsers["list_sorted_nodes"].format_help(),
+        prefix_aliases=["lt", "lsn"])
+    stepper_ui.register_command_handler(
+        "cont",
+        stepper.cont,
+        stepper.arg_parsers["cont"].format_help(),
+        prefix_aliases=["ct", "c"])
+    stepper_ui.register_command_handler(
+        "step",
+        stepper.step,
+        stepper.arg_parsers["step"].format_help(),
+        prefix_aliases=["st", "s"])
+    stepper_ui.register_command_handler(
+        "print_tensor",
+        stepper.print_tensor,
+        stepper.arg_parsers["print_tensor"].format_help(),
+        prefix_aliases=["pt"])
+    stepper_ui.register_command_handler(
+        "inject_value",
+        stepper.inject_value,
+        stepper.arg_parsers["inject_value"].format_help(),
+        prefix_aliases=["inject", "override_value", "override"])
+
+    # Register tab completion candidates.
+    stepper_ui.register_tab_comp_context([
+        "cont", "ct", "c", "pt", "inject_value", "inject", "override_value",
+        "override"
+    ], [str(elem) for elem in node_stepper.sorted_nodes()])
+    # TODO(cais): Tie up register_tab_comp_context to a single alias to shorten
+    # calls like this.
+
+    return stepper_ui.run_ui(
+        init_command="lt",
+        title="Node Stepper: " + self._run_description,
+        title_color="blue_on_white")