tfdbg: add persistent config
* Add two persistent UI configurations backed by a file at ~/.tfdbg_config by default. * graph_recursion_depth, which controls the recursive output of li/lo commands. * mouse_mode, which controls the mouse state of the CursesUI. * Add `config` command to set and inspect the persistent configuration. E.g., * config show * config set graph_recursion_depth 3 * config set mouse_mode False Fixes: #13449 PiperOrigin-RevId: 172270804
This commit is contained in:
parent
d8b4b00de8
commit
1cf9f7ab2f
@ -186,6 +186,9 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
|
||||
| | `--tensor_dtype_filter <pattern>` | Execute the next `Session.run`, dumping only Tensors with data types (`dtype`s) matching the given regular-expression pattern. | `run --tensor_dtype_filter int.*` |
|
||||
| | `-p` | Execute the next `Session.run` call in profiling mode. | `run -p` |
|
||||
| **`ri`** | | **Display information about the run the current run, including fetches and feeds.** | `ri` |
|
||||
| **`config`** | | **Set or show persistent TFDBG UI configuration.** | |
|
||||
| | `set` | Set the value of a config item: {`graph_recursion_depth`, `mouse_mode`}. | `config set graph_recursion_depth 3` |
|
||||
| | `show` | Show current persistent UI configuration. | `config show` |
|
||||
| **`help`** | | **Print general help information** | `help` |
|
||||
| | `help <command>` | Print help for given command. | `help lt` |
|
||||
|
||||
|
@ -150,6 +150,13 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cli_config",
|
||||
srcs = ["cli/cli_config.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":debugger_cli_common"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "command_parser",
|
||||
srcs = ["cli/command_parser.py"],
|
||||
@ -197,6 +204,7 @@ py_library(
|
||||
srcs = ["cli/analyzer_cli.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cli_config",
|
||||
":cli_shared",
|
||||
":command_parser",
|
||||
":debug_graphs",
|
||||
@ -249,6 +257,7 @@ py_library(
|
||||
srcs = ["cli/base_ui.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cli_config",
|
||||
":command_parser",
|
||||
":debugger_cli_common",
|
||||
],
|
||||
@ -583,6 +592,7 @@ py_test(
|
||||
srcs = ["cli/readline_ui_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cli_config",
|
||||
":debugger_cli_common",
|
||||
":readline_ui",
|
||||
":ui_factory",
|
||||
@ -724,6 +734,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cli_config_test",
|
||||
size = "small",
|
||||
srcs = ["cli/cli_config_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cli_config",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "command_parser_test",
|
||||
size = "small",
|
||||
@ -791,6 +814,7 @@ cuda_py_test(
|
||||
srcs = ["cli/analyzer_cli_test.py"],
|
||||
additional_deps = [
|
||||
":analyzer_cli",
|
||||
":cli_config",
|
||||
":command_parser",
|
||||
":debug_data",
|
||||
":debug_utils",
|
||||
|
@ -29,6 +29,7 @@ import re
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.debug.cli import cli_config
|
||||
from tensorflow.python.debug.cli import cli_shared
|
||||
from tensorflow.python.debug.cli import command_parser
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
@ -140,11 +141,13 @@ class DebugAnalyzer(object):
|
||||
_GRAPH_STRUCT_OP_TYPE_BLACKLIST = (
|
||||
"_Send", "_Recv", "_HostSend", "_HostRecv", "_Retval")
|
||||
|
||||
def __init__(self, debug_dump):
|
||||
def __init__(self, debug_dump, config):
|
||||
"""DebugAnalyzer constructor.
|
||||
|
||||
Args:
|
||||
debug_dump: A DebugDumpDir object.
|
||||
config: A `cli_config.CLIConfig` object that carries user-facing
|
||||
configurations.
|
||||
"""
|
||||
|
||||
self._debug_dump = debug_dump
|
||||
@ -153,6 +156,21 @@ class DebugAnalyzer(object):
|
||||
# Initialize tensor filters state.
|
||||
self._tensor_filters = {}
|
||||
|
||||
self._build_argument_parsers(config)
|
||||
config.set_callback("graph_recursion_depth",
|
||||
self._build_argument_parsers)
|
||||
|
||||
# TODO(cais): Implement list_nodes.
|
||||
|
||||
def _build_argument_parsers(self, config):
|
||||
"""Build argument parsers for DebugAnalayzer.
|
||||
|
||||
Args:
|
||||
config: A `cli_config.CLIConfig` object.
|
||||
|
||||
Returns:
|
||||
A dict mapping command handler name to `ArgumentParser` instance.
|
||||
"""
|
||||
# Argument parsers for command handlers.
|
||||
self._arg_parsers = {}
|
||||
|
||||
@ -242,7 +260,7 @@ class DebugAnalyzer(object):
|
||||
"--depth",
|
||||
dest="depth",
|
||||
type=int,
|
||||
default=20,
|
||||
default=config.get("graph_recursion_depth"),
|
||||
help="Maximum depth of recursion used when showing the input tree.")
|
||||
ap.add_argument(
|
||||
"-r",
|
||||
@ -273,7 +291,7 @@ class DebugAnalyzer(object):
|
||||
"--depth",
|
||||
dest="depth",
|
||||
type=int,
|
||||
default=20,
|
||||
default=config.get("graph_recursion_depth"),
|
||||
help="Maximum depth of recursion used when showing the output tree.")
|
||||
ap.add_argument(
|
||||
"-r",
|
||||
@ -386,8 +404,6 @@ class DebugAnalyzer(object):
|
||||
"(may be slow for large results).")
|
||||
self._arg_parsers["eval"] = ap
|
||||
|
||||
# TODO(cais): Implement list_nodes.
|
||||
|
||||
def add_tensor_filter(self, filter_name, filter_callable):
|
||||
"""Add a tensor filter.
|
||||
|
||||
@ -1540,7 +1556,8 @@ class DebugAnalyzer(object):
|
||||
def create_analyzer_ui(debug_dump,
|
||||
tensor_filters=None,
|
||||
ui_type="curses",
|
||||
on_ui_exit=None):
|
||||
on_ui_exit=None,
|
||||
config=None):
|
||||
"""Create an instance of CursesUI based on a DebugDumpDir object.
|
||||
|
||||
Args:
|
||||
@ -1549,19 +1566,22 @@ def create_analyzer_ui(debug_dump,
|
||||
filter (Callable).
|
||||
ui_type: (str) requested UI type, e.g., "curses", "readline".
|
||||
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
|
||||
config: A `cli_config.CLIConfig` object.
|
||||
|
||||
Returns:
|
||||
(base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
|
||||
commands and tab-completions registered.
|
||||
"""
|
||||
if config is None:
|
||||
config = cli_config.CLIConfig()
|
||||
|
||||
analyzer = DebugAnalyzer(debug_dump)
|
||||
analyzer = DebugAnalyzer(debug_dump, config=config)
|
||||
if tensor_filters:
|
||||
for tensor_filter_name in tensor_filters:
|
||||
analyzer.add_tensor_filter(
|
||||
tensor_filter_name, tensor_filters[tensor_filter_name])
|
||||
|
||||
cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)
|
||||
cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit, config=config)
|
||||
cli.register_command_handler(
|
||||
"list_tensors",
|
||||
analyzer.list_tensors,
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.debug.cli import analyzer_cli
|
||||
from tensorflow.python.debug.cli import cli_config
|
||||
from tensorflow.python.debug.cli import cli_shared
|
||||
from tensorflow.python.debug.cli import command_parser
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
@ -45,6 +46,11 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
def _cli_config_from_temp_file():
|
||||
return cli_config.CLIConfig(
|
||||
config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config"))
|
||||
|
||||
|
||||
def no_rewrite_session_config():
|
||||
rewriter_config = rewriter_config_pb2.RewriterConfig(
|
||||
disable_model_pruning=True,
|
||||
@ -512,7 +518,7 @@ def create_analyzer_cli(dump):
|
||||
and has the common tfdbg commands, e.g., lt, ni, li, lo, registered.
|
||||
"""
|
||||
# Construct the analyzer.
|
||||
analyzer = analyzer_cli.DebugAnalyzer(dump)
|
||||
analyzer = analyzer_cli.DebugAnalyzer(dump, _cli_config_from_temp_file())
|
||||
|
||||
# Construct the handler registry.
|
||||
registry = debugger_cli_common.CommandHandlerRegistry()
|
||||
@ -1216,12 +1222,14 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
" [-14., 4.]])"], out.lines)
|
||||
|
||||
def testAddGetTensorFilterLambda(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
analyzer.add_tensor_filter("foo_filter", lambda x, y: True)
|
||||
self.assertTrue(analyzer.get_tensor_filter("foo_filter")(None, None))
|
||||
|
||||
def testAddGetTensorFilterNestedFunction(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
def foo_filter(unused_arg_0, unused_arg_1):
|
||||
return True
|
||||
@ -1230,14 +1238,16 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(analyzer.get_tensor_filter("foo_filter")(None, None))
|
||||
|
||||
def testAddTensorFilterEmptyName(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Input argument filter_name cannot be empty."):
|
||||
analyzer.add_tensor_filter("", lambda datum, tensor: True)
|
||||
|
||||
def testAddTensorFilterNonStrName(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
@ -1245,7 +1255,8 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
analyzer.add_tensor_filter(1, lambda datum, tensor: True)
|
||||
|
||||
def testAddGetTensorFilterNonCallable(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, "Input argument filter_callable is expected to be callable, "
|
||||
@ -1253,7 +1264,8 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
analyzer.add_tensor_filter("foo_filter", "bar")
|
||||
|
||||
def testGetNonexistentTensorFilter(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
|
||||
_cli_config_from_temp_file())
|
||||
|
||||
analyzer.add_tensor_filter("foo_filter", lambda datum, tensor: True)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
|
@ -17,6 +17,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
from tensorflow.python.debug.cli import cli_config
|
||||
from tensorflow.python.debug.cli import command_parser
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
|
||||
@ -29,11 +32,13 @@ class BaseUI(object):
|
||||
ERROR_MESSAGE_PREFIX = "ERROR: "
|
||||
INFO_MESSAGE_PREFIX = "INFO: "
|
||||
|
||||
def __init__(self, on_ui_exit=None):
|
||||
def __init__(self, on_ui_exit=None, config=None):
|
||||
"""Constructor of the base class.
|
||||
|
||||
Args:
|
||||
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
|
||||
config: An instance of `cli_config.CLIConfig()` carrying user-facing
|
||||
configurations.
|
||||
"""
|
||||
|
||||
self._on_ui_exit = on_ui_exit
|
||||
@ -50,6 +55,20 @@ class BaseUI(object):
|
||||
[debugger_cli_common.CommandHandlerRegistry.HELP_COMMAND] +
|
||||
debugger_cli_common.CommandHandlerRegistry.HELP_COMMAND_ALIASES)
|
||||
|
||||
self._config = config or cli_config.CLIConfig()
|
||||
self._config_argparser = argparse.ArgumentParser(
|
||||
description="config command", usage=argparse.SUPPRESS)
|
||||
subparsers = self._config_argparser.add_subparsers()
|
||||
set_parser = subparsers.add_parser("set")
|
||||
set_parser.add_argument("property_name", type=str)
|
||||
set_parser.add_argument("property_value", type=str)
|
||||
set_parser = subparsers.add_parser("show")
|
||||
self.register_command_handler(
|
||||
"config",
|
||||
self._config_command_handler,
|
||||
self._config_argparser.format_help(),
|
||||
prefix_aliases=["cfg"])
|
||||
|
||||
def set_help_intro(self, help_intro):
|
||||
"""Set an introductory message to the help output of the command registry.
|
||||
|
||||
@ -176,3 +195,21 @@ class BaseUI(object):
|
||||
except_last_word = " ".join(items[:-1]) + " "
|
||||
|
||||
return context, prefix, except_last_word
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
"""Obtain the CLIConfig of this `BaseUI` instance."""
|
||||
return self._config
|
||||
|
||||
def _config_command_handler(self, args, screen_info=None):
|
||||
"""Command handler for the "config" command."""
|
||||
del screen_info # Currently unused.
|
||||
|
||||
parsed = self._config_argparser.parse_args(args)
|
||||
if hasattr(parsed, "property_name") and hasattr(parsed, "property_value"):
|
||||
# set.
|
||||
self._config.set(parsed.property_name, parsed.property_value)
|
||||
return self._config.summarize(highlight=parsed.property_name)
|
||||
else:
|
||||
# show.
|
||||
return self._config.summarize()
|
||||
|
160
tensorflow/python/debug/cli/cli_config.py
Normal file
160
tensorflow/python/debug/cli/cli_config.py
Normal file
@ -0,0 +1,160 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Configurations for TensorFlow Debugger (TFDBG) command-line interfaces."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
RL = debugger_cli_common.RichLine
|
||||
|
||||
|
||||
class CLIConfig(object):
|
||||
"""Client-facing configurations for TFDBG command-line interfaces."""
|
||||
|
||||
_CONFIG_FILE_NAME = ".tfdbg_config"
|
||||
|
||||
_DEFAULT_CONFIG = [
|
||||
("graph_recursion_depth", 20),
|
||||
("mouse_mode", True),
|
||||
]
|
||||
|
||||
def __init__(self, config_file_path=None):
|
||||
self._config_file_path = (config_file_path or
|
||||
self._default_config_file_path())
|
||||
self._config = collections.OrderedDict(self._DEFAULT_CONFIG)
|
||||
if gfile.Exists(self._config_file_path):
|
||||
config = self._load_from_file()
|
||||
for key, value in config.items():
|
||||
self._config[key] = value
|
||||
self._save_to_file()
|
||||
|
||||
self._set_callbacks = dict()
|
||||
|
||||
def get(self, property_name):
|
||||
if property_name not in self._config:
|
||||
raise KeyError("%s is not a valid property name." % property_name)
|
||||
return self._config[property_name]
|
||||
|
||||
def set(self, property_name, property_val):
|
||||
"""Set the value of a property.
|
||||
|
||||
Supports limitd property value types: `bool`, `int` and `str`.
|
||||
|
||||
Args:
|
||||
property_name: Name of the property.
|
||||
property_val: Value of the property. If the property has `bool` type and
|
||||
this argument has `str` type, the `str` value will be parsed as a `bool`
|
||||
|
||||
Raises:
|
||||
ValueError: if a `str` property_value fails to be parsed as a `bool`.
|
||||
KeyError: if `property_name` is an invalid property name.
|
||||
"""
|
||||
if property_name not in self._config:
|
||||
raise KeyError("%s is not a valid property name." % property_name)
|
||||
|
||||
orig_val = self._config[property_name]
|
||||
if isinstance(orig_val, bool):
|
||||
if isinstance(property_val, str):
|
||||
if property_val.lower() in ("1", "true", "t", "yes", "y", "on"):
|
||||
property_val = True
|
||||
elif property_val.lower() in ("0", "false", "f", "no", "n", "off"):
|
||||
property_val = False
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid string value for bool type: %s" % property_val)
|
||||
else:
|
||||
property_val = bool(property_val)
|
||||
elif isinstance(orig_val, int):
|
||||
property_val = int(property_val)
|
||||
elif isinstance(orig_val, str):
|
||||
property_val = str(property_val)
|
||||
else:
|
||||
raise TypeError("Unsupported property type: %s" % type(orig_val))
|
||||
self._config[property_name] = property_val
|
||||
self._save_to_file()
|
||||
|
||||
# Invoke set-callback.
|
||||
if property_name in self._set_callbacks:
|
||||
self._set_callbacks[property_name](self._config)
|
||||
|
||||
def set_callback(self, property_name, callback):
|
||||
"""Set a set-callback for given property.
|
||||
|
||||
Args:
|
||||
property_name: Name of the property.
|
||||
callback: The callback as a `callable` of signature:
|
||||
def cbk(config):
|
||||
where config is the config after it is set to the new value.
|
||||
The callback is invoked each time the set() method is called with the
|
||||
matching property_name.
|
||||
|
||||
Raises:
|
||||
KeyError: If property_name does not exist.
|
||||
TypeError: If `callback` is not callable.
|
||||
"""
|
||||
if property_name not in self._config:
|
||||
raise KeyError("%s is not a valid property name." % property_name)
|
||||
if not callable(callback):
|
||||
raise TypeError("The callback object provided is not callable.")
|
||||
self._set_callbacks[property_name] = callback
|
||||
|
||||
def _default_config_file_path(self):
|
||||
return os.path.join(os.path.expanduser("~"), self._CONFIG_FILE_NAME)
|
||||
|
||||
def _save_to_file(self):
|
||||
try:
|
||||
with gfile.Open(self._config_file_path, "w") as config_file:
|
||||
json.dump(self._config, config_file)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def summarize(self, highlight=None):
|
||||
"""Get a text summary of the config.
|
||||
|
||||
Args:
|
||||
highlight: A property name to highlight in the output.
|
||||
|
||||
Returns:
|
||||
A `RichTextLines` output.
|
||||
"""
|
||||
lines = [RL("Command-line configuration:", "bold"), RL("")]
|
||||
for name, val in self._config.items():
|
||||
highlight_attr = "bold" if name == highlight else None
|
||||
line = RL(" ")
|
||||
line += RL(name, ["underline", highlight_attr])
|
||||
line += RL(": ")
|
||||
line += RL(str(val), font_attr=highlight_attr)
|
||||
lines.append(line)
|
||||
return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
|
||||
|
||||
def _load_from_file(self):
|
||||
try:
|
||||
with gfile.Open(self._config_file_path, "r") as config_file:
|
||||
config_dict = json.load(config_file)
|
||||
config = collections.OrderedDict()
|
||||
for key in sorted(config_dict.keys()):
|
||||
config[key] = config_dict[key]
|
||||
return config
|
||||
except (IOError, ValueError):
|
||||
# The reading of the config file may fail due to IO issues or file
|
||||
# corruption. We do not want tfdbg to error out just because of that.
|
||||
return dict()
|
137
tensorflow/python/debug/cli/cli_config_test.py
Normal file
137
tensorflow/python/debug/cli/cli_config_test.py
Normal file
@ -0,0 +1,137 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for cli_config."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from tensorflow.python.debug.cli import cli_config
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class CLIConfigTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._tmp_dir = tempfile.mkdtemp()
|
||||
self._tmp_config_path = os.path.join(self._tmp_dir, ".tfdbg_config")
|
||||
self.assertFalse(gfile.Exists(self._tmp_config_path))
|
||||
super(CLIConfigTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self._tmp_dir)
|
||||
super(CLIConfigTest, self).tearDown()
|
||||
|
||||
def testConstructCLIConfigWithoutFile(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
self.assertEqual(20, config.get("graph_recursion_depth"))
|
||||
self.assertEqual(True, config.get("mouse_mode"))
|
||||
with self.assertRaises(KeyError):
|
||||
config.get("property_that_should_not_exist")
|
||||
self.assertTrue(gfile.Exists(self._tmp_config_path))
|
||||
|
||||
def testCLIConfigForwardCompatibilityTest(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
with open(self._tmp_config_path, "rt") as f:
|
||||
config_json = json.load(f)
|
||||
# Remove a field to simulate forward compatibility test.
|
||||
del config_json["graph_recursion_depth"]
|
||||
with open(self._tmp_config_path, "wt") as f:
|
||||
json.dump(config_json, f)
|
||||
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
self.assertEqual(20, config.get("graph_recursion_depth"))
|
||||
|
||||
def testModifyConfigValue(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
config.set("graph_recursion_depth", 9)
|
||||
config.set("mouse_mode", False)
|
||||
self.assertEqual(9, config.get("graph_recursion_depth"))
|
||||
self.assertEqual(False, config.get("mouse_mode"))
|
||||
|
||||
def testModifyConfigValueWithTypeCasting(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
config.set("graph_recursion_depth", "18")
|
||||
config.set("mouse_mode", "false")
|
||||
self.assertEqual(18, config.get("graph_recursion_depth"))
|
||||
self.assertEqual(False, config.get("mouse_mode"))
|
||||
|
||||
def testModifyConfigValueWithTypeCastingFailure(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
with self.assertRaises(ValueError):
|
||||
config.set("mouse_mode", "maybe")
|
||||
|
||||
def testLoadFromModifiedConfigFile(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
config.set("graph_recursion_depth", 9)
|
||||
config.set("mouse_mode", False)
|
||||
config2 = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
self.assertEqual(9, config2.get("graph_recursion_depth"))
|
||||
self.assertEqual(False, config2.get("mouse_mode"))
|
||||
|
||||
def testSummarizeFromConfig(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
output = config.summarize()
|
||||
self.assertEqual(
|
||||
["Command-line configuration:",
|
||||
"",
|
||||
" graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
|
||||
" mouse_mode: %s" % config.get("mouse_mode")], output.lines)
|
||||
|
||||
def testSummarizeFromConfigWithHighlight(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
output = config.summarize(highlight="mouse_mode")
|
||||
self.assertEqual(
|
||||
["Command-line configuration:",
|
||||
"",
|
||||
" graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
|
||||
" mouse_mode: %s" % config.get("mouse_mode")], output.lines)
|
||||
self.assertEqual((2, 12, ["underline", "bold"]),
|
||||
output.font_attr_segs[3][0])
|
||||
self.assertEqual((14, 18, "bold"), output.font_attr_segs[3][1])
|
||||
|
||||
def testSetCallback(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
|
||||
test_value = {"graph_recursion_depth": -1}
|
||||
def callback(config):
|
||||
test_value["graph_recursion_depth"] = config.get("graph_recursion_depth")
|
||||
config.set_callback("graph_recursion_depth", callback)
|
||||
|
||||
config.set("graph_recursion_depth", config.get("graph_recursion_depth") - 1)
|
||||
self.assertEqual(test_value["graph_recursion_depth"],
|
||||
config.get("graph_recursion_depth"))
|
||||
|
||||
def testSetCallbackInvalidPropertyName(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
config.set_callback("nonexistent_property_name", print)
|
||||
|
||||
def testSetCallbackNotCallable(self):
|
||||
config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
config.set_callback("graph_recursion_depth", 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -273,14 +273,16 @@ class CursesUI(base_ui.BaseUI):
|
||||
|
||||
_single_instance_lock = threading.Lock()
|
||||
|
||||
def __init__(self, on_ui_exit=None):
|
||||
def __init__(self, on_ui_exit=None, config=None):
|
||||
"""Constructor of CursesUI.
|
||||
|
||||
Args:
|
||||
on_ui_exit: (Callable) Callback invoked when the UI exits.
|
||||
config: An instance of `cli_config.CLIConfig()` carrying user-facing
|
||||
configurations.
|
||||
"""
|
||||
|
||||
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit)
|
||||
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit, config=config)
|
||||
|
||||
self._screen_init()
|
||||
self._screen_refresh_size()
|
||||
@ -445,8 +447,11 @@ class CursesUI(base_ui.BaseUI):
|
||||
curses.cbreak()
|
||||
self._stdscr.keypad(1)
|
||||
|
||||
self._mouse_enabled = enable_mouse_on_start
|
||||
self._mouse_enabled = self.config.get("mouse_mode")
|
||||
self._screen_set_mousemask()
|
||||
self.config.set_callback(
|
||||
"mouse_mode",
|
||||
lambda cfg: self._set_mouse_enabled(cfg.get("mouse_mode")))
|
||||
|
||||
self._screen_create_command_window()
|
||||
|
||||
|
@ -704,8 +704,8 @@ class CursesTest(test_util.TensorFlowTestCase):
|
||||
# The manually registered command, along with the automatically registered
|
||||
# exit commands should appear in the candidates.
|
||||
self.assertEqual(
|
||||
[["a", "babble", "exit", "h", "help", "m", "mouse", "quit"]],
|
||||
ui.candidates_lists)
|
||||
[["a", "babble", "cfg", "config", "exit", "h", "help", "m", "mouse",
|
||||
"quit"]], ui.candidates_lists)
|
||||
|
||||
# The two candidates have no common prefix. So no command should have been
|
||||
# issued.
|
||||
|
@ -768,7 +768,8 @@ class ProfileAnalyzer(object):
|
||||
def create_profiler_ui(graph,
|
||||
run_metadata,
|
||||
ui_type="curses",
|
||||
on_ui_exit=None):
|
||||
on_ui_exit=None,
|
||||
config=None):
|
||||
"""Create an instance of CursesUI based on a `tf.Graph` and `RunMetadata`.
|
||||
|
||||
Args:
|
||||
@ -776,11 +777,13 @@ def create_profiler_ui(graph,
|
||||
run_metadata: A `RunMetadata` protobuf object.
|
||||
ui_type: (str) requested UI type, e.g., "curses", "readline".
|
||||
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
|
||||
config: An instance of `cli_config.CLIConfig`.
|
||||
|
||||
Returns:
|
||||
(base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
|
||||
commands and tab-completions registered.
|
||||
"""
|
||||
del config # Currently unused.
|
||||
|
||||
analyzer = ProfileAnalyzer(graph, run_metadata)
|
||||
|
||||
|
@ -26,8 +26,8 @@ from tensorflow.python.debug.cli import debugger_cli_common
|
||||
class ReadlineUI(base_ui.BaseUI):
|
||||
"""Readline-based Command-line UI."""
|
||||
|
||||
def __init__(self, on_ui_exit=None):
|
||||
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit)
|
||||
def __init__(self, on_ui_exit=None, config=None):
|
||||
base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit, config=config)
|
||||
self._init_input()
|
||||
|
||||
def _init_input(self):
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import tempfile
|
||||
|
||||
from tensorflow.python.debug.cli import cli_config
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.debug.cli import readline_ui
|
||||
from tensorflow.python.debug.cli import ui_factory
|
||||
@ -32,7 +33,9 @@ class MockReadlineUI(readline_ui.ReadlineUI):
|
||||
"""Test subclass of ReadlineUI that bypasses terminal manipulations."""
|
||||
|
||||
def __init__(self, on_ui_exit=None, command_sequence=None):
|
||||
readline_ui.ReadlineUI.__init__(self, on_ui_exit=on_ui_exit)
|
||||
readline_ui.ReadlineUI.__init__(
|
||||
self, on_ui_exit=on_ui_exit,
|
||||
config=cli_config.CLIConfig(config_file_path=tempfile.mktemp()))
|
||||
|
||||
self._command_sequence = command_sequence
|
||||
self._command_counter = 0
|
||||
@ -161,6 +164,18 @@ class CursesTest(test_util.TensorFlowTestCase):
|
||||
with gfile.Open(output_path, "r") as f:
|
||||
self.assertEqual("bar\nbar\n", f.read())
|
||||
|
||||
def testConfigSetAndShow(self):
|
||||
"""Run UI with an initial command specified."""
|
||||
|
||||
ui = MockReadlineUI(command_sequence=[
|
||||
"config set graph_recursion_depth 5", "config show", "exit"])
|
||||
ui.run_ui()
|
||||
outputs = ui.observers["screen_outputs"]
|
||||
self.assertEqual(
|
||||
["Command-line configuration:",
|
||||
"",
|
||||
" graph_recursion_depth: 5"], outputs[1].lines[:3])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -23,7 +23,10 @@ import copy
|
||||
SUPPORTED_UI_TYPES = ["curses", "readline"]
|
||||
|
||||
|
||||
def get_ui(ui_type, on_ui_exit=None, available_ui_types=None):
|
||||
def get_ui(ui_type,
|
||||
on_ui_exit=None,
|
||||
available_ui_types=None,
|
||||
config=None):
|
||||
"""Create a `base_ui.BaseUI` subtype.
|
||||
|
||||
This factory method attempts to fallback to other available ui_types on
|
||||
@ -36,6 +39,8 @@ def get_ui(ui_type, on_ui_exit=None, available_ui_types=None):
|
||||
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
|
||||
available_ui_types: (`None` or `list` of `str`) Manually-set available
|
||||
ui_types.
|
||||
config: An instance of `cli_config.CLIConfig()` carrying user-facing
|
||||
configurations.
|
||||
|
||||
Returns:
|
||||
A `base_ui.BaseUI` subtype object.
|
||||
@ -53,10 +58,10 @@ def get_ui(ui_type, on_ui_exit=None, available_ui_types=None):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if not ui_type or ui_type == "curses":
|
||||
from tensorflow.python.debug.cli import curses_ui
|
||||
return curses_ui.CursesUI(on_ui_exit=on_ui_exit)
|
||||
return curses_ui.CursesUI(on_ui_exit=on_ui_exit, config=config)
|
||||
elif ui_type == "readline":
|
||||
from tensorflow.python.debug.cli import readline_ui
|
||||
return readline_ui.ReadlineUI(on_ui_exit=on_ui_exit)
|
||||
return readline_ui.ReadlineUI(on_ui_exit=on_ui_exit, config=config)
|
||||
# pylint: enable=g-import-not-at-top
|
||||
except ImportError:
|
||||
available_ui_types.remove(ui_type)
|
||||
|
@ -414,7 +414,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
|
||||
self._init_command = "lp"
|
||||
self._run_cli = profile_analyzer_cli.create_profiler_ui(
|
||||
py_graph, run_metadata, ui_type=self._ui_type)
|
||||
py_graph, run_metadata, ui_type=self._ui_type,
|
||||
config=self._run_cli.config)
|
||||
self._title = "run-end (profiler mode): " + self._run_description
|
||||
|
||||
def _launch_cli(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user