tfdbg: add persistent config

* Add two persistent UI configurations backed by a file at ~/.tfdbg_config by default.
  * graph_recursion_depth, which controls the recursive output of li/lo commands.
  * mouse_mode, which controls the mouse state of the CursesUI.
* Add `config` command to set and inspect the persistent configuration. E.g.,
  * config show
  * config set graph_recursion_depth 3
  * config set mouse_mode False

Fixes: #13449
PiperOrigin-RevId: 172270804
This commit is contained in:
Shanqing Cai 2017-10-15 18:15:53 -07:00 committed by TensorFlower Gardener
parent d8b4b00de8
commit 1cf9f7ab2f
14 changed files with 451 additions and 29 deletions

View File

@ -186,6 +186,9 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
| | `--tensor_dtype_filter <pattern>` | Execute the next `Session.run`, dumping only Tensors with data types (`dtype`s) matching the given regular-expression pattern. | `run --tensor_dtype_filter int.*` |
| | `-p` | Execute the next `Session.run` call in profiling mode. | `run -p` |
| **`ri`** | | **Display information about the run the current run, including fetches and feeds.** | `ri` |
| **`config`** | | **Set or show persistent TFDBG UI configuration.** | |
| | `set` | Set the value of a config item: {`graph_recursion_depth`, `mouse_mode`}. | `config set graph_recursion_depth 3` |
| | `show` | Show current persistent UI configuration. | `config show` |
| **`help`** | | **Print general help information** | `help` |
| | `help <command>` | Print help for given command. | `help lt` |

View File

@ -150,6 +150,13 @@ py_library(
],
)
py_library(
name = "cli_config",
srcs = ["cli/cli_config.py"],
srcs_version = "PY2AND3",
deps = [":debugger_cli_common"],
)
py_library(
name = "command_parser",
srcs = ["cli/command_parser.py"],
@ -197,6 +204,7 @@ py_library(
srcs = ["cli/analyzer_cli.py"],
srcs_version = "PY2AND3",
deps = [
":cli_config",
":cli_shared",
":command_parser",
":debug_graphs",
@ -249,6 +257,7 @@ py_library(
srcs = ["cli/base_ui.py"],
srcs_version = "PY2AND3",
deps = [
":cli_config",
":command_parser",
":debugger_cli_common",
],
@ -583,6 +592,7 @@ py_test(
srcs = ["cli/readline_ui_test.py"],
srcs_version = "PY2AND3",
deps = [
":cli_config",
":debugger_cli_common",
":readline_ui",
":ui_factory",
@ -724,6 +734,19 @@ py_test(
],
)
py_test(
name = "cli_config_test",
size = "small",
srcs = ["cli/cli_config_test.py"],
srcs_version = "PY2AND3",
deps = [
":cli_config",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "command_parser_test",
size = "small",
@ -791,6 +814,7 @@ cuda_py_test(
srcs = ["cli/analyzer_cli_test.py"],
additional_deps = [
":analyzer_cli",
":cli_config",
":command_parser",
":debug_data",
":debug_utils",

View File

@ -29,6 +29,7 @@ import re
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
@ -140,11 +141,13 @@ class DebugAnalyzer(object):
_GRAPH_STRUCT_OP_TYPE_BLACKLIST = (
"_Send", "_Recv", "_HostSend", "_HostRecv", "_Retval")
def __init__(self, debug_dump):
def __init__(self, debug_dump, config):
"""DebugAnalyzer constructor.
Args:
debug_dump: A DebugDumpDir object.
config: A `cli_config.CLIConfig` object that carries user-facing
configurations.
"""
self._debug_dump = debug_dump
@ -153,6 +156,21 @@ class DebugAnalyzer(object):
# Initialize tensor filters state.
self._tensor_filters = {}
self._build_argument_parsers(config)
config.set_callback("graph_recursion_depth",
self._build_argument_parsers)
# TODO(cais): Implement list_nodes.
def _build_argument_parsers(self, config):
"""Build argument parsers for DebugAnalayzer.
Args:
config: A `cli_config.CLIConfig` object.
Returns:
A dict mapping command handler name to `ArgumentParser` instance.
"""
# Argument parsers for command handlers.
self._arg_parsers = {}
@ -242,7 +260,7 @@ class DebugAnalyzer(object):
"--depth",
dest="depth",
type=int,
default=20,
default=config.get("graph_recursion_depth"),
help="Maximum depth of recursion used when showing the input tree.")
ap.add_argument(
"-r",
@ -273,7 +291,7 @@ class DebugAnalyzer(object):
"--depth",
dest="depth",
type=int,
default=20,
default=config.get("graph_recursion_depth"),
help="Maximum depth of recursion used when showing the output tree.")
ap.add_argument(
"-r",
@ -386,8 +404,6 @@ class DebugAnalyzer(object):
"(may be slow for large results).")
self._arg_parsers["eval"] = ap
# TODO(cais): Implement list_nodes.
def add_tensor_filter(self, filter_name, filter_callable):
"""Add a tensor filter.
@ -1540,7 +1556,8 @@ class DebugAnalyzer(object):
def create_analyzer_ui(debug_dump,
tensor_filters=None,
ui_type="curses",
on_ui_exit=None):
on_ui_exit=None,
config=None):
"""Create an instance of CursesUI based on a DebugDumpDir object.
Args:
@ -1549,19 +1566,22 @@ def create_analyzer_ui(debug_dump,
filter (Callable).
ui_type: (str) requested UI type, e.g., "curses", "readline".
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
config: A `cli_config.CLIConfig` object.
Returns:
(base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
commands and tab-completions registered.
"""
if config is None:
config = cli_config.CLIConfig()
analyzer = DebugAnalyzer(debug_dump)
analyzer = DebugAnalyzer(debug_dump, config=config)
if tensor_filters:
for tensor_filter_name in tensor_filters:
analyzer.add_tensor_filter(
tensor_filter_name, tensor_filters[tensor_filter_name])
cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)
cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit, config=config)
cli.register_command_handler(
"list_tensors",
analyzer.list_tensors,

View File

@ -28,6 +28,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
@ -45,6 +46,11 @@ from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
def _cli_config_from_temp_file():
return cli_config.CLIConfig(
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config"))
def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
@ -512,7 +518,7 @@ def create_analyzer_cli(dump):
and has the common tfdbg commands, e.g., lt, ni, li, lo, registered.
"""
# Construct the analyzer.
analyzer = analyzer_cli.DebugAnalyzer(dump)
analyzer = analyzer_cli.DebugAnalyzer(dump, _cli_config_from_temp_file())
# Construct the handler registry.
registry = debugger_cli_common.CommandHandlerRegistry()
@ -1216,12 +1222,14 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
" [-14., 4.]])"], out.lines)
def testAddGetTensorFilterLambda(self):
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
_cli_config_from_temp_file())
analyzer.add_tensor_filter("foo_filter", lambda x, y: True)
self.assertTrue(analyzer.get_tensor_filter("foo_filter")(None, None))
def testAddGetTensorFilterNestedFunction(self):
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
_cli_config_from_temp_file())
def foo_filter(unused_arg_0, unused_arg_1):
return True
@ -1230,14 +1238,16 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
self.assertTrue(analyzer.get_tensor_filter("foo_filter")(None, None))
def testAddTensorFilterEmptyName(self):
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
_cli_config_from_temp_file())
with self.assertRaisesRegexp(ValueError,
"Input argument filter_name cannot be empty."):
analyzer.add_tensor_filter("", lambda datum, tensor: True)
def testAddTensorFilterNonStrName(self):
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
_cli_config_from_temp_file())
with self.assertRaisesRegexp(
TypeError,
@ -1245,7 +1255,8 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
analyzer.add_tensor_filter(1, lambda datum, tensor: True)
def testAddGetTensorFilterNonCallable(self):
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
_cli_config_from_temp_file())
with self.assertRaisesRegexp(
TypeError, "Input argument filter_callable is expected to be callable, "
@ -1253,7 +1264,8 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
analyzer.add_tensor_filter("foo_filter", "bar")
def testGetNonexistentTensorFilter(self):
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
_cli_config_from_temp_file())
analyzer.add_tensor_filter("foo_filter", lambda datum, tensor: True)
with self.assertRaisesRegexp(ValueError,

View File

@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
@ -29,11 +32,13 @@ class BaseUI(object):
ERROR_MESSAGE_PREFIX = "ERROR: "
INFO_MESSAGE_PREFIX = "INFO: "
def __init__(self, on_ui_exit=None):
def __init__(self, on_ui_exit=None, config=None):
"""Constructor of the base class.
Args:
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
config: An instance of `cli_config.CLIConfig()` carrying user-facing
configurations.
"""
self._on_ui_exit = on_ui_exit
@ -50,6 +55,20 @@ class BaseUI(object):
[debugger_cli_common.CommandHandlerRegistry.HELP_COMMAND] +
debugger_cli_common.CommandHandlerRegistry.HELP_COMMAND_ALIASES)
self._config = config or cli_config.CLIConfig()
self._config_argparser = argparse.ArgumentParser(
description="config command", usage=argparse.SUPPRESS)
subparsers = self._config_argparser.add_subparsers()
set_parser = subparsers.add_parser("set")
set_parser.add_argument("property_name", type=str)
set_parser.add_argument("property_value", type=str)
set_parser = subparsers.add_parser("show")
self.register_command_handler(
"config",
self._config_command_handler,
self._config_argparser.format_help(),
prefix_aliases=["cfg"])
def set_help_intro(self, help_intro):
"""Set an introductory message to the help output of the command registry.
@ -176,3 +195,21 @@ class BaseUI(object):
except_last_word = " ".join(items[:-1]) + " "
return context, prefix, except_last_word
@property
def config(self):
"""Obtain the CLIConfig of this `BaseUI` instance."""
return self._config
def _config_command_handler(self, args, screen_info=None):
"""Command handler for the "config" command."""
del screen_info # Currently unused.
parsed = self._config_argparser.parse_args(args)
if hasattr(parsed, "property_name") and hasattr(parsed, "property_value"):
# set.
self._config.set(parsed.property_name, parsed.property_value)
return self._config.summarize(highlight=parsed.property_name)
else:
# show.
return self._config.summarize()

View File

@ -0,0 +1,160 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configurations for TensorFlow Debugger (TFDBG) command-line interfaces."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.platform import gfile
RL = debugger_cli_common.RichLine
class CLIConfig(object):
"""Client-facing configurations for TFDBG command-line interfaces."""
_CONFIG_FILE_NAME = ".tfdbg_config"
_DEFAULT_CONFIG = [
("graph_recursion_depth", 20),
("mouse_mode", True),
]
def __init__(self, config_file_path=None):
self._config_file_path = (config_file_path or
self._default_config_file_path())
self._config = collections.OrderedDict(self._DEFAULT_CONFIG)
if gfile.Exists(self._config_file_path):
config = self._load_from_file()
for key, value in config.items():
self._config[key] = value
self._save_to_file()
self._set_callbacks = dict()
def get(self, property_name):
if property_name not in self._config:
raise KeyError("%s is not a valid property name." % property_name)
return self._config[property_name]
def set(self, property_name, property_val):
"""Set the value of a property.
Supports limitd property value types: `bool`, `int` and `str`.
Args:
property_name: Name of the property.
property_val: Value of the property. If the property has `bool` type and
this argument has `str` type, the `str` value will be parsed as a `bool`
Raises:
ValueError: if a `str` property_value fails to be parsed as a `bool`.
KeyError: if `property_name` is an invalid property name.
"""
if property_name not in self._config:
raise KeyError("%s is not a valid property name." % property_name)
orig_val = self._config[property_name]
if isinstance(orig_val, bool):
if isinstance(property_val, str):
if property_val.lower() in ("1", "true", "t", "yes", "y", "on"):
property_val = True
elif property_val.lower() in ("0", "false", "f", "no", "n", "off"):
property_val = False
else:
raise ValueError(
"Invalid string value for bool type: %s" % property_val)
else:
property_val = bool(property_val)
elif isinstance(orig_val, int):
property_val = int(property_val)
elif isinstance(orig_val, str):
property_val = str(property_val)
else:
raise TypeError("Unsupported property type: %s" % type(orig_val))
self._config[property_name] = property_val
self._save_to_file()
# Invoke set-callback.
if property_name in self._set_callbacks:
self._set_callbacks[property_name](self._config)
def set_callback(self, property_name, callback):
"""Set a set-callback for given property.
Args:
property_name: Name of the property.
callback: The callback as a `callable` of signature:
def cbk(config):
where config is the config after it is set to the new value.
The callback is invoked each time the set() method is called with the
matching property_name.
Raises:
KeyError: If property_name does not exist.
TypeError: If `callback` is not callable.
"""
if property_name not in self._config:
raise KeyError("%s is not a valid property name." % property_name)
if not callable(callback):
raise TypeError("The callback object provided is not callable.")
self._set_callbacks[property_name] = callback
def _default_config_file_path(self):
return os.path.join(os.path.expanduser("~"), self._CONFIG_FILE_NAME)
def _save_to_file(self):
try:
with gfile.Open(self._config_file_path, "w") as config_file:
json.dump(self._config, config_file)
except IOError:
pass
def summarize(self, highlight=None):
"""Get a text summary of the config.
Args:
highlight: A property name to highlight in the output.
Returns:
A `RichTextLines` output.
"""
lines = [RL("Command-line configuration:", "bold"), RL("")]
for name, val in self._config.items():
highlight_attr = "bold" if name == highlight else None
line = RL(" ")
line += RL(name, ["underline", highlight_attr])
line += RL(": ")
line += RL(str(val), font_attr=highlight_attr)
lines.append(line)
return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
def _load_from_file(self):
try:
with gfile.Open(self._config_file_path, "r") as config_file:
config_dict = json.load(config_file)
config = collections.OrderedDict()
for key in sorted(config_dict.keys()):
config[key] = config_dict[key]
return config
except (IOError, ValueError):
# The reading of the config file may fail due to IO issues or file
# corruption. We do not want tfdbg to error out just because of that.
return dict()

View File

@ -0,0 +1,137 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cli_config."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import shutil
import tempfile
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
class CLIConfigTest(test_util.TensorFlowTestCase):
def setUp(self):
self._tmp_dir = tempfile.mkdtemp()
self._tmp_config_path = os.path.join(self._tmp_dir, ".tfdbg_config")
self.assertFalse(gfile.Exists(self._tmp_config_path))
super(CLIConfigTest, self).setUp()
def tearDown(self):
shutil.rmtree(self._tmp_dir)
super(CLIConfigTest, self).tearDown()
def testConstructCLIConfigWithoutFile(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
self.assertEqual(20, config.get("graph_recursion_depth"))
self.assertEqual(True, config.get("mouse_mode"))
with self.assertRaises(KeyError):
config.get("property_that_should_not_exist")
self.assertTrue(gfile.Exists(self._tmp_config_path))
def testCLIConfigForwardCompatibilityTest(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
with open(self._tmp_config_path, "rt") as f:
config_json = json.load(f)
# Remove a field to simulate forward compatibility test.
del config_json["graph_recursion_depth"]
with open(self._tmp_config_path, "wt") as f:
json.dump(config_json, f)
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
self.assertEqual(20, config.get("graph_recursion_depth"))
def testModifyConfigValue(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
config.set("graph_recursion_depth", 9)
config.set("mouse_mode", False)
self.assertEqual(9, config.get("graph_recursion_depth"))
self.assertEqual(False, config.get("mouse_mode"))
def testModifyConfigValueWithTypeCasting(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
config.set("graph_recursion_depth", "18")
config.set("mouse_mode", "false")
self.assertEqual(18, config.get("graph_recursion_depth"))
self.assertEqual(False, config.get("mouse_mode"))
def testModifyConfigValueWithTypeCastingFailure(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
with self.assertRaises(ValueError):
config.set("mouse_mode", "maybe")
def testLoadFromModifiedConfigFile(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
config.set("graph_recursion_depth", 9)
config.set("mouse_mode", False)
config2 = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
self.assertEqual(9, config2.get("graph_recursion_depth"))
self.assertEqual(False, config2.get("mouse_mode"))
def testSummarizeFromConfig(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
output = config.summarize()
self.assertEqual(
["Command-line configuration:",
"",
" graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
" mouse_mode: %s" % config.get("mouse_mode")], output.lines)
def testSummarizeFromConfigWithHighlight(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
output = config.summarize(highlight="mouse_mode")
self.assertEqual(
["Command-line configuration:",
"",
" graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
" mouse_mode: %s" % config.get("mouse_mode")], output.lines)
self.assertEqual((2, 12, ["underline", "bold"]),
output.font_attr_segs[3][0])
self.assertEqual((14, 18, "bold"), output.font_attr_segs[3][1])
def testSetCallback(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
test_value = {"graph_recursion_depth": -1}
def callback(config):
test_value["graph_recursion_depth"] = config.get("graph_recursion_depth")
config.set_callback("graph_recursion_depth", callback)
config.set("graph_recursion_depth", config.get("graph_recursion_depth") - 1)
self.assertEqual(test_value["graph_recursion_depth"],
config.get("graph_recursion_depth"))
def testSetCallbackInvalidPropertyName(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
with self.assertRaises(KeyError):
config.set_callback("nonexistent_property_name", print)
def testSetCallbackNotCallable(self):
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
with self.assertRaises(TypeError):
config.set_callback("graph_recursion_depth", 1)
if __name__ == "__main__":
googletest.main()

View File

@ -273,14 +273,16 @@ class CursesUI(base_ui.BaseUI):
_single_instance_lock = threading.Lock()
def __init__(self, on_ui_exit=None):
def __init__(self, on_ui_exit=None, config=None):
"""Constructor of CursesUI.
Args:
on_ui_exit: (Callable) Callback invoked when the UI exits.
config: An instance of `cli_config.CLIConfig()` carrying user-facing
configurations.
"""
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit)
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit, config=config)
self._screen_init()
self._screen_refresh_size()
@ -445,8 +447,11 @@ class CursesUI(base_ui.BaseUI):
curses.cbreak()
self._stdscr.keypad(1)
self._mouse_enabled = enable_mouse_on_start
self._mouse_enabled = self.config.get("mouse_mode")
self._screen_set_mousemask()
self.config.set_callback(
"mouse_mode",
lambda cfg: self._set_mouse_enabled(cfg.get("mouse_mode")))
self._screen_create_command_window()

View File

@ -704,8 +704,8 @@ class CursesTest(test_util.TensorFlowTestCase):
# The manually registered command, along with the automatically registered
# exit commands should appear in the candidates.
self.assertEqual(
[["a", "babble", "exit", "h", "help", "m", "mouse", "quit"]],
ui.candidates_lists)
[["a", "babble", "cfg", "config", "exit", "h", "help", "m", "mouse",
"quit"]], ui.candidates_lists)
# The two candidates have no common prefix. So no command should have been
# issued.

View File

@ -768,7 +768,8 @@ class ProfileAnalyzer(object):
def create_profiler_ui(graph,
run_metadata,
ui_type="curses",
on_ui_exit=None):
on_ui_exit=None,
config=None):
"""Create an instance of CursesUI based on a `tf.Graph` and `RunMetadata`.
Args:
@ -776,11 +777,13 @@ def create_profiler_ui(graph,
run_metadata: A `RunMetadata` protobuf object.
ui_type: (str) requested UI type, e.g., "curses", "readline".
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
config: An instance of `cli_config.CLIConfig`.
Returns:
(base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
commands and tab-completions registered.
"""
del config # Currently unused.
analyzer = ProfileAnalyzer(graph, run_metadata)

View File

@ -26,8 +26,8 @@ from tensorflow.python.debug.cli import debugger_cli_common
class ReadlineUI(base_ui.BaseUI):
"""Readline-based Command-line UI."""
def __init__(self, on_ui_exit=None):
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit)
def __init__(self, on_ui_exit=None, config=None):
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit, config=config)
self._init_input()
def _init_input(self):

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import argparse
import tempfile
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import readline_ui
from tensorflow.python.debug.cli import ui_factory
@ -32,7 +33,9 @@ class MockReadlineUI(readline_ui.ReadlineUI):
"""Test subclass of ReadlineUI that bypasses terminal manipulations."""
def __init__(self, on_ui_exit=None, command_sequence=None):
readline_ui.ReadlineUI.__init__(self, on_ui_exit=on_ui_exit)
readline_ui.ReadlineUI.__init__(
self, on_ui_exit=on_ui_exit,
config=cli_config.CLIConfig(config_file_path=tempfile.mktemp()))
self._command_sequence = command_sequence
self._command_counter = 0
@ -161,6 +164,18 @@ class CursesTest(test_util.TensorFlowTestCase):
with gfile.Open(output_path, "r") as f:
self.assertEqual("bar\nbar\n", f.read())
def testConfigSetAndShow(self):
"""Run UI with an initial command specified."""
ui = MockReadlineUI(command_sequence=[
"config set graph_recursion_depth 5", "config show", "exit"])
ui.run_ui()
outputs = ui.observers["screen_outputs"]
self.assertEqual(
["Command-line configuration:",
"",
" graph_recursion_depth: 5"], outputs[1].lines[:3])
if __name__ == "__main__":
googletest.main()

View File

@ -23,7 +23,10 @@ import copy
SUPPORTED_UI_TYPES = ["curses", "readline"]
def get_ui(ui_type, on_ui_exit=None, available_ui_types=None):
def get_ui(ui_type,
on_ui_exit=None,
available_ui_types=None,
config=None):
"""Create a `base_ui.BaseUI` subtype.
This factory method attempts to fallback to other available ui_types on
@ -36,6 +39,8 @@ def get_ui(ui_type, on_ui_exit=None, available_ui_types=None):
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
available_ui_types: (`None` or `list` of `str`) Manually-set available
ui_types.
config: An instance of `cli_config.CLIConfig()` carrying user-facing
configurations.
Returns:
A `base_ui.BaseUI` subtype object.
@ -53,10 +58,10 @@ def get_ui(ui_type, on_ui_exit=None, available_ui_types=None):
# pylint: disable=g-import-not-at-top
if not ui_type or ui_type == "curses":
from tensorflow.python.debug.cli import curses_ui
return curses_ui.CursesUI(on_ui_exit=on_ui_exit)
return curses_ui.CursesUI(on_ui_exit=on_ui_exit, config=config)
elif ui_type == "readline":
from tensorflow.python.debug.cli import readline_ui
return readline_ui.ReadlineUI(on_ui_exit=on_ui_exit)
return readline_ui.ReadlineUI(on_ui_exit=on_ui_exit, config=config)
# pylint: enable=g-import-not-at-top
except ImportError:
available_ui_types.remove(ui_type)

View File

@ -414,7 +414,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
self._init_command = "lp"
self._run_cli = profile_analyzer_cli.create_profiler_ui(
py_graph, run_metadata, ui_type=self._ui_type)
py_graph, run_metadata, ui_type=self._ui_type,
config=self._run_cli.config)
self._title = "run-end (profiler mode): " + self._run_description
def _launch_cli(self):