tfdbg CLI: add eval of arbitrary Python / np expressions
RELNOTES: TensorFlow Debugger (tfdbg) command-line interface: Support evaluation of arbitrary Python and numpy (np) expressions with debug tensor names enclosed in pairs of backtics. E.g., tfdbg> eval 'np.sum(`Softmax:0`, axis=1)'. PiperOrigin-RevId: 164217384
This commit is contained in:
parent
57970f8ff5
commit
44113ce5bb
@ -204,6 +204,16 @@ addditional features:
|
||||
* Use the `prev` and `next` commands.
|
||||
* Click underlined `<--` and `-->` links near the top left corner of the
|
||||
screen.
|
||||
* Evaluation of arbitrary expressions (with `numpy` imported as `np`) using
|
||||
debug tensor names enclosed in pairs of backtics. Use the `-a` flag to
|
||||
print large-sized results in its entirety. For example:
|
||||
|
||||
```none
|
||||
tfdbg> eval np.argmax(`Softmax:0`)
|
||||
tfdbg> eval "np.matmul((`output/Identity:0` / `Softmax:0`).T, `Softmax:0`)"
|
||||
tfdbg> eval -a 'np.sum(`Softmax:0`, axis=1)'
|
||||
```
|
||||
|
||||
* Tab completion of commands and some command arguments.
|
||||
* To redirect the screen output to a file instead of the screen, end the
|
||||
command with bash-style redirection. For example, the following command
|
||||
|
@ -159,6 +159,16 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "evaluator",
|
||||
srcs = ["cli/evaluator.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":debug_data",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "analyzer_cli",
|
||||
srcs = ["cli/analyzer_cli.py"],
|
||||
@ -168,6 +178,7 @@ py_library(
|
||||
":command_parser",
|
||||
":debug_data",
|
||||
":debugger_cli_common",
|
||||
":evaluator",
|
||||
":source_utils",
|
||||
":ui_factory",
|
||||
"@six_archive//:six",
|
||||
@ -708,6 +719,22 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "evaluator_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"cli/evaluator_test.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":debug_data",
|
||||
":evaluator",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "analyzer_cli_test",
|
||||
size = "small",
|
||||
|
@ -32,6 +32,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from tensorflow.python.debug.cli import cli_shared
|
||||
from tensorflow.python.debug.cli import command_parser
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.debug.cli import evaluator
|
||||
from tensorflow.python.debug.cli import ui_factory
|
||||
from tensorflow.python.debug.lib import debug_data
|
||||
from tensorflow.python.debug.lib import source_utils
|
||||
@ -138,6 +139,7 @@ class DebugAnalyzer(object):
|
||||
"""
|
||||
|
||||
self._debug_dump = debug_dump
|
||||
self._evaluator = evaluator.ExpressionEvaluator(self._debug_dump)
|
||||
|
||||
# Initialize tensor filters state.
|
||||
self._tensor_filters = {}
|
||||
@ -333,6 +335,48 @@ class DebugAnalyzer(object):
|
||||
help="Regular expression filter for node name.")
|
||||
self._arg_parsers["list_source"] = ap
|
||||
|
||||
# Parser for eval.
|
||||
ap = argparse.ArgumentParser(
|
||||
description="""Evaluate an arbitrary expression. Can use tensor values
|
||||
from the current debug dump. The debug tensor names should be enclosed
|
||||
in pairs of backticks. Expressions with spaces should be enclosed in
|
||||
a pair of double quotes or a pair of single quotes. By default, numpy
|
||||
is imported as np and can be used in the expressions. E.g.,
|
||||
1) eval np.argmax(`Softmax:0`),
|
||||
2) eval 'np.sum(`Softmax:0`, axis=1)',
|
||||
3) eval "np.matmul((`output/Identity:0`/`Softmax:0`).T, `Softmax:0`)".
|
||||
""",
|
||||
usage=argparse.SUPPRESS)
|
||||
ap.add_argument(
|
||||
"expression",
|
||||
type=str,
|
||||
help="""Expression to be evaluated.
|
||||
1) in the simplest case, use <node_name>:<output_slot>, e.g.,
|
||||
hidden_0/MatMul:0.
|
||||
|
||||
2) if the default debug op "DebugIdentity" is to be overridden, use
|
||||
<node_name>:<output_slot>:<debug_op>, e.g.,
|
||||
hidden_0/MatMul:0:DebugNumericSummary.
|
||||
|
||||
3) if the tensor of the same name exists on more than one device, use
|
||||
<device_name>:<node_name>:<output_slot>[:<debug_op>], e.g.,
|
||||
/job:worker/replica:0/task:0/gpu:0:hidden_0/MatMul:0
|
||||
/job:worker/replica:0/task:2/cpu:0:hidden_0/MatMul:0:DebugNanCount.
|
||||
|
||||
4) if the tensor is executed multiple times in a given `Session.run`
|
||||
call, specify the execution index with a 0-based integer enclose in a
|
||||
pair of brackets at the end, e.g.,
|
||||
RNN/tanh:0[0]
|
||||
/job:worker/replica:0/task:0/gpu:0:RNN/tanh:0[0].""")
|
||||
ap.add_argument(
|
||||
"-a",
|
||||
"--all",
|
||||
dest="print_all",
|
||||
action="store_true",
|
||||
help="Print the tensor in its entirety, i.e., do not use ellipses "
|
||||
"(may be slow for large results).")
|
||||
self._arg_parsers["eval"] = ap
|
||||
|
||||
# TODO(cais): Implement list_nodes.
|
||||
|
||||
def add_tensor_filter(self, filter_name, filter_callable):
|
||||
@ -966,6 +1010,20 @@ class DebugAnalyzer(object):
|
||||
|
||||
return output
|
||||
|
||||
def evaluate_expression(self, args, screen_info=None):
|
||||
parsed = self._arg_parsers["eval"].parse_args(args)
|
||||
|
||||
eval_res = self._evaluator.evaluate(parsed.expression)
|
||||
|
||||
np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
|
||||
screen_info)
|
||||
return cli_shared.format_tensor(
|
||||
eval_res,
|
||||
"from eval of expression '%s'" % parsed.expression,
|
||||
np_printoptions,
|
||||
print_all=parsed.print_all,
|
||||
include_numeric_summary=True)
|
||||
|
||||
def _reconstruct_print_source_command(self,
|
||||
parsed,
|
||||
line_begin,
|
||||
@ -1500,6 +1558,11 @@ def create_analyzer_ui(debug_dump,
|
||||
analyzer.list_source,
|
||||
analyzer.get_help("list_source"),
|
||||
prefix_aliases=["ls"])
|
||||
cli.register_command_handler(
|
||||
"eval",
|
||||
analyzer.evaluate_expression,
|
||||
analyzer.get_help("eval"),
|
||||
prefix_aliases=["ev"])
|
||||
|
||||
dumped_tensor_names = []
|
||||
for datum in debug_dump.dumped_tensor_data:
|
||||
|
@ -585,6 +585,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
cls._analyzer.list_source,
|
||||
cls._analyzer.get_help("list_source"),
|
||||
prefix_aliases=["ls"])
|
||||
cls._registry.register_command_handler(
|
||||
"eval",
|
||||
cls._analyzer.evaluate_expression,
|
||||
cls._analyzer.get_help("eval"),
|
||||
prefix_aliases=["ev"])
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@ -1134,6 +1139,29 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
], out.lines)
|
||||
check_main_menu(self, out, list_tensors_enabled=True)
|
||||
|
||||
def testEvalExpression(self):
|
||||
node_name = "simple_mul_add/matmul"
|
||||
tensor_name = node_name + ":0"
|
||||
out = self._registry.dispatch_command(
|
||||
"eval", ["np.matmul(`%s`, `%s`.T)" % (tensor_name, tensor_name)],
|
||||
screen_info={"cols": 80})
|
||||
|
||||
self.assertEqual([
|
||||
"Tensor \"from eval of expression "
|
||||
"'np.matmul(`simple_mul_add/matmul:0`, "
|
||||
"`simple_mul_add/matmul:0`.T)'\":",
|
||||
" dtype: float64",
|
||||
" shape: (2, 2)",
|
||||
"",
|
||||
"Numeric summary:",
|
||||
"| - + | total |",
|
||||
"| 2 2 | 4 |",
|
||||
"| min max mean std |",
|
||||
"| -14.0 49.0 6.25 25.7524270701 |",
|
||||
"",
|
||||
"array([[ 49., -14.],",
|
||||
" [-14., 4.]])"], out.lines)
|
||||
|
||||
def testAddGetTensorFilterLambda(self):
|
||||
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
|
||||
analyzer.add_tensor_filter("foo_filter", lambda x, y: True)
|
||||
|
@ -173,7 +173,8 @@ def format_tensor(tensor,
|
||||
applicable) will be included.
|
||||
|
||||
Returns:
|
||||
(str) Formatted str representing the (potentially sliced) tensor.
|
||||
An instance of `debugger_cli_common.RichTextLines` representing the
|
||||
(potentially sliced) tensor.
|
||||
"""
|
||||
|
||||
if tensor_slicing:
|
||||
|
@ -24,7 +24,7 @@ import sys
|
||||
|
||||
|
||||
_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]")
|
||||
_QUOTES_PATTERN = re.compile(r"\"[^\"]*\"")
|
||||
_QUOTES_PATTERN = re.compile(r"(\"[^\"]*\"|\'[^\']*\')")
|
||||
_WHITESPACE_PATTERN = re.compile(r"\s+")
|
||||
|
||||
_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
|
||||
@ -92,7 +92,8 @@ def parse_command(command):
|
||||
argument = command[idx0:start]
|
||||
|
||||
# Strip leading and trailing double quote if they are paired.
|
||||
if argument.startswith("\"") and argument.endswith("\""):
|
||||
if (argument.startswith("\"") and argument.endswith("\"") or
|
||||
argument.startswith("'") and argument.endswith("'")):
|
||||
argument = argument[1:-1]
|
||||
arguments.append(argument)
|
||||
idx0 = end
|
||||
|
@ -82,6 +82,11 @@ class ParseCommandTest(test_util.TensorFlowTestCase):
|
||||
command_parser.parse_command(command))
|
||||
# The pair of double quotes should have been stripped.
|
||||
|
||||
command = "inject_value foo 'np.zeros([100, 500])'"
|
||||
self.assertEqual(["inject_value", "foo", "np.zeros([100, 500])"],
|
||||
command_parser.parse_command(command))
|
||||
# The pair of single quotes should have been stripped.
|
||||
|
||||
command = "\"command prefix with spaces\" arg1"
|
||||
self.assertEqual(["command prefix with spaces", "arg1"],
|
||||
command_parser.parse_command(command))
|
||||
|
152
tensorflow/python/debug/cli/evaluator.py
Normal file
152
tensorflow/python/debug/cli/evaluator.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Library for arbitrary expression evaluation based on a debugger data dump."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np # pylint: disable=unused-import
|
||||
|
||||
from tensorflow.python.debug.lib import debug_data
|
||||
|
||||
_DUMP_TENSOR_PATTERN = re.compile(r"`.*?`")
|
||||
_DEVICE_NAME_PREFIX_PATTERN = re.compile(
|
||||
r"/job:(\w)+/replica:(\d)+/task:(\d)+/(\w)+:(\d)+:")
|
||||
_EXEC_INDEX_SUFFIX_PATTERN = re.compile(r"\[(\d)*\]$")
|
||||
|
||||
_DEFAULT_DEBUG_OP = "DebugIdentity"
|
||||
|
||||
|
||||
def _parse_debug_tensor_name(debug_tensor_name):
|
||||
# pylint: disable=line-too-long
|
||||
"""Parse a debug tensor name in a to-be-evaluated expression.
|
||||
|
||||
Args:
|
||||
debug_tensor_name: name of the debug tensor, with or without
|
||||
device name as a prefix, with or without debug op, with or
|
||||
without '[<exec_index>]' as a suffix.
|
||||
E.g., without device name prefix, without debug op suffix:
|
||||
"hidden_0/MatMul:0"
|
||||
E.g., with device name prefix:
|
||||
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0"
|
||||
E.g., with debug op suffix:
|
||||
"hidden_0/MatMul:0:DebugNumericSummary"
|
||||
E.g., with device name prefix and debug op suffix:
|
||||
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0:DebugNumericSummary"
|
||||
E.g., with device name prefix, debug op and an exec index:
|
||||
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0:DebugNumericSummary[1]"
|
||||
|
||||
Returns:
|
||||
device_name: If device name prefix exists, the device name; otherwise,
|
||||
`None`.
|
||||
node_name: Name of the node.
|
||||
output_slot: Output slot index as an `int`.
|
||||
debug_op: If the debug op suffix exists, the debug op name; otheriwse,
|
||||
`None`.
|
||||
exec_index: Execution index (applicable to cases in which a debug tensor
|
||||
is computed multiple times in a `tf.Session.run` call, e.g., due to
|
||||
`tf.while_loop`). If the exec_index suffix does not exist, this value
|
||||
defaults to `0`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input `debug_tensor_name` is malformed.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
device_prefix_match = re.match(_DEVICE_NAME_PREFIX_PATTERN, debug_tensor_name)
|
||||
if device_prefix_match:
|
||||
device_name = debug_tensor_name[
|
||||
device_prefix_match.start() : device_prefix_match.end() - 1]
|
||||
debug_tensor_name = debug_tensor_name[device_prefix_match.end():]
|
||||
else:
|
||||
device_name = None
|
||||
|
||||
split_items = debug_tensor_name.split(":")
|
||||
if len(split_items) not in (2, 3):
|
||||
raise ValueError(
|
||||
"The debug tensor name in the to-be-evaluated expression is malformed: "
|
||||
"'%s'" % debug_tensor_name)
|
||||
# TODO(cais): Provide examples of good debug tensor names in the error
|
||||
# message.
|
||||
|
||||
exec_index_match = re.search(_EXEC_INDEX_SUFFIX_PATTERN, split_items[-1])
|
||||
if exec_index_match:
|
||||
exec_index = int(split_items[-1][
|
||||
exec_index_match.start() + 1 : exec_index_match.end() - 1])
|
||||
split_items[-1] = split_items[-1][:exec_index_match.start()]
|
||||
else:
|
||||
exec_index = 0
|
||||
|
||||
if len(split_items) == 2:
|
||||
node_name = split_items[0]
|
||||
output_slot = int(split_items[1])
|
||||
debug_op = _DEFAULT_DEBUG_OP
|
||||
else:
|
||||
split_items = debug_tensor_name.split(":")
|
||||
node_name = split_items[0]
|
||||
output_slot = int(split_items[1])
|
||||
debug_op = split_items[2]
|
||||
|
||||
return device_name, node_name, output_slot, debug_op, exec_index
|
||||
|
||||
|
||||
class ExpressionEvaluator(object):
|
||||
"""Evaluates Python expressions using debug tensor values from a dump."""
|
||||
|
||||
def __init__(self, dump):
|
||||
"""Constructor of ExpressionEvaluator.
|
||||
|
||||
Args:
|
||||
dump: an instance of `DebugDumpDir`.
|
||||
"""
|
||||
self._dump = dump
|
||||
self._cached_tensor_values = dict()
|
||||
|
||||
def evaluate(self, expression):
|
||||
"""Parse an expression.
|
||||
|
||||
Args:
|
||||
expression: the expression to be parsed.
|
||||
|
||||
Returns:
|
||||
The result of the evaluation.
|
||||
|
||||
Raises:
|
||||
ValueError: If the value of one or more of the debug tensors in the
|
||||
expression are not available.
|
||||
"""
|
||||
dump_tensors_iter = re.finditer(_DUMP_TENSOR_PATTERN, expression)
|
||||
rewritten_expression = expression
|
||||
for match in reversed(list(dump_tensors_iter)):
|
||||
tensor_name = match.group(0)[1:-1].strip()
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
_parse_debug_tensor_name(tensor_name))
|
||||
if tensor_name not in self._cached_tensor_values:
|
||||
try:
|
||||
value = self._dump.get_tensors(
|
||||
node_name, output_slot, debug_op,
|
||||
device_name=device_name)[exec_index]
|
||||
except debug_data.WatchKeyDoesNotExistInDebugDumpDirError:
|
||||
raise ValueError(
|
||||
"Eval failed due to the value of %s:%d:DebugIdentity being "
|
||||
"unavailable" % (node_name, output_slot))
|
||||
self._cached_tensor_values[tensor_name] = value
|
||||
rewritten_expression = (
|
||||
rewritten_expression[:match.start(0)] +
|
||||
"self._cached_tensor_values['" + tensor_name + "']" +
|
||||
rewritten_expression[match.end(0):])
|
||||
|
||||
return eval(rewritten_expression) # pylint: disable=eval-used
|
268
tensorflow/python/debug/cli/evaluator_test.py
Normal file
268
tensorflow/python/debug/cli/evaluator_test.py
Normal file
@ -0,0 +1,268 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for arbitrary expression evaluator."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.debug.cli import evaluator
|
||||
from tensorflow.python.debug.lib import debug_data
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ParseDebugTensorNameTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testParseNamesWithoutPrefixOrSuffix(self):
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name("foo:1"))
|
||||
self.assertIsNone(device_name)
|
||||
self.assertEqual("foo", node_name)
|
||||
self.assertEqual(1, output_slot)
|
||||
self.assertEqual("DebugIdentity", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name("hidden_0/Weights:0"))
|
||||
self.assertIsNone(device_name)
|
||||
self.assertEqual("hidden_0/Weights", node_name)
|
||||
self.assertEqual(0, output_slot)
|
||||
self.assertEqual("DebugIdentity", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
def testParseNamesWithoutPrefixWithDebugOpSuffix(self):
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name("foo:1:DebugNanCount"))
|
||||
self.assertIsNone(device_name)
|
||||
self.assertEqual("foo", node_name)
|
||||
self.assertEqual(1, output_slot)
|
||||
self.assertEqual("DebugNanCount", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"hidden_0/Weights:0:DebugNumericSummary"))
|
||||
self.assertIsNone(device_name)
|
||||
self.assertEqual("hidden_0/Weights", node_name)
|
||||
self.assertEqual(0, output_slot)
|
||||
self.assertEqual("DebugNumericSummary", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
def testParseNamesWithDeviceNamePrefixWithoutDebugOpSuffix(self):
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"/job:ps/replica:0/task:2/cpu:0:foo:1"))
|
||||
self.assertEqual("/job:ps/replica:0/task:2/cpu:0", device_name)
|
||||
self.assertEqual("foo", node_name)
|
||||
self.assertEqual(1, output_slot)
|
||||
self.assertEqual("DebugIdentity", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"/job:worker/replica:0/task:3/gpu:0:hidden_0/Weights:0"))
|
||||
self.assertEqual("/job:worker/replica:0/task:3/gpu:0", device_name)
|
||||
self.assertEqual("hidden_0/Weights", node_name)
|
||||
self.assertEqual(0, output_slot)
|
||||
self.assertEqual("DebugIdentity", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
def testParseNamesWithDeviceNamePrefixWithDebugOpSuffix(self):
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"/job:ps/replica:0/task:2/cpu:0:foo:1:DebugNanCount"))
|
||||
self.assertEqual("/job:ps/replica:0/task:2/cpu:0", device_name)
|
||||
self.assertEqual("foo", node_name)
|
||||
self.assertEqual(1, output_slot)
|
||||
self.assertEqual("DebugNanCount", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"/job:worker/replica:0/task:3/gpu:0:"
|
||||
"hidden_0/Weights:0:DebugNumericSummary"))
|
||||
self.assertEqual("/job:worker/replica:0/task:3/gpu:0", device_name)
|
||||
self.assertEqual("hidden_0/Weights", node_name)
|
||||
self.assertEqual(0, output_slot)
|
||||
self.assertEqual("DebugNumericSummary", debug_op)
|
||||
self.assertEqual(0, exec_index)
|
||||
|
||||
def testParseMalformedDebugTensorName(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"The debug tensor name in the to-be-evaluated expression is "
|
||||
r"malformed:"):
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"/job:ps/replica:0/task:2/cpu:0:foo:1:DebugNanCount:1337")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"The debug tensor name in the to-be-evaluated expression is "
|
||||
r"malformed:"):
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"/job:ps/replica:0/cpu:0:foo:1:DebugNanCount")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"foo:1:DebugNanCount[]")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
evaluator._parse_debug_tensor_name(
|
||||
"foo:1[DebugNanCount]")
|
||||
|
||||
def testParseNamesWithExecIndex(self):
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name("foo:1[20]"))
|
||||
self.assertIsNone(device_name)
|
||||
self.assertEqual("foo", node_name)
|
||||
self.assertEqual(1, output_slot)
|
||||
self.assertEqual("DebugIdentity", debug_op)
|
||||
self.assertEqual(20, exec_index)
|
||||
|
||||
device_name, node_name, output_slot, debug_op, exec_index = (
|
||||
evaluator._parse_debug_tensor_name("hidden_0/Weights:0[3]"))
|
||||
self.assertIsNone(device_name)
|
||||
self.assertEqual("hidden_0/Weights", node_name)
|
||||
self.assertEqual(0, output_slot)
|
||||
self.assertEqual("DebugIdentity", debug_op)
|
||||
self.assertEqual(3, exec_index)
|
||||
|
||||
|
||||
class EvaluatorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testEvaluateSingleTensor(self):
|
||||
dump = test.mock.MagicMock()
|
||||
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
|
||||
del node_name, output_slot, debug_op, device_name # Unused.
|
||||
return [np.array([[1.0, 2.0, 3.0]])]
|
||||
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
self.assertEqual(3, ev.evaluate("np.size(`a:0`)"))
|
||||
|
||||
# Whitespace in backticks should be tolerated.
|
||||
self.assertEqual(3, ev.evaluate("np.size(` a:0 `)"))
|
||||
|
||||
def testEvaluateTwoTensors(self):
|
||||
dump = test.mock.MagicMock()
|
||||
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
|
||||
del debug_op, device_name # Unused.
|
||||
if node_name == "a" and output_slot == 0:
|
||||
return [np.array([[1.0, -2.0], [0.0, 1.0]])]
|
||||
elif node_name == "b" and output_slot == 0:
|
||||
return [np.array([[-1.0], [1.0]])]
|
||||
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
self.assertAllClose([[-3.0], [1.0]],
|
||||
ev.evaluate("np.matmul(`a:0`, `b:0`)"))
|
||||
self.assertAllClose(
|
||||
[[-4.0], [2.0]], ev.evaluate("np.matmul(`a:0`, `b:0`) + `b:0`"))
|
||||
|
||||
def testEvaluateNoneExistentTensorGeneratesError(self):
|
||||
dump = test.mock.MagicMock()
|
||||
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
|
||||
del node_name, output_slot, debug_op, device_name # Unused.
|
||||
raise debug_data.WatchKeyDoesNotExistInDebugDumpDirError()
|
||||
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Eval failed due to the value of .* being unavailable"):
|
||||
ev.evaluate("np.matmul(`a:0`, `b:0`)")
|
||||
|
||||
def testEvaluateWithMultipleDevicesContainingTheSameTensorName(self):
|
||||
dump = test.mock.MagicMock()
|
||||
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
|
||||
del output_slot, debug_op # Unused.
|
||||
if node_name == "a" and device_name is None:
|
||||
raise ValueError(
|
||||
"There are multiple (2) devices with nodes named 'a' but "
|
||||
"device_name is not specified")
|
||||
elif (node_name == "a" and
|
||||
device_name == "/job:worker/replica:0/task:0/cpu:0"):
|
||||
return [np.array(10.0)]
|
||||
elif (node_name == "a" and
|
||||
device_name == "/job:worker/replica:0/task:1/cpu:0"):
|
||||
return [np.array(20.0)]
|
||||
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
with self.assertRaisesRegexp(ValueError, r"multiple \(2\) devices"):
|
||||
ev.evaluate("`a:0` + `a:0`")
|
||||
|
||||
self.assertAllClose(
|
||||
30.0,
|
||||
ev.evaluate("`/job:worker/replica:0/task:0/cpu:0:a:0` + "
|
||||
"`/job:worker/replica:0/task:1/cpu:0:a:0`"))
|
||||
|
||||
def testEvaluateWithNonDefaultDebugOp(self):
|
||||
dump = test.mock.MagicMock()
|
||||
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
|
||||
del device_name # Unused.
|
||||
if node_name == "a" and output_slot == 0 and debug_op == "DebugIdentity":
|
||||
return [np.array([[-1.0], [1.0]])]
|
||||
elif node_name == "a" and output_slot == 0 and debug_op == "DebugFoo":
|
||||
return [np.array([[-2.0, 2.0]])]
|
||||
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
self.assertAllClose(
|
||||
[[4.0]],
|
||||
ev.evaluate("np.matmul(`a:0:DebugFoo`, `a:0:DebugIdentity`)"))
|
||||
|
||||
def testEvaluateWithMultipleExecIndexes(self):
|
||||
dump = test.mock.MagicMock()
|
||||
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
|
||||
del debug_op, device_name # Unused.
|
||||
if node_name == "a" and output_slot == 0:
|
||||
return [np.array([[-1.0], [1.0]]), np.array([[-2.0], [2.0]])]
|
||||
|
||||
with test.mock.patch.object(
|
||||
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
self.assertAllClose(
|
||||
[[4.0]], ev.evaluate("np.matmul(`a:0[1]`.T, `a:0[0]`)"))
|
||||
|
||||
def testEvaluateExpressionWithUnmatchedBacktick(self):
|
||||
dump = test.mock.MagicMock()
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
with self.assertRaises(SyntaxError):
|
||||
ev.evaluate("np.matmul(`a:0`, `b:0`) + `b:0")
|
||||
|
||||
def testEvaluateExpressionWithInvalidDebugTensorName(self):
|
||||
dump = test.mock.MagicMock()
|
||||
ev = evaluator.ExpressionEvaluator(dump)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".* tensor name .* expression .* malformed"):
|
||||
ev.evaluate("np.matmul(`a`, `b`)")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r".* tensor name .* expression .* malformed"):
|
||||
ev.evaluate("np.matmul(`a:0:DebugIdentity:0`, `b:1:DebugNanCount:2`)")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
ev.evaluate("np.matmul(`a:0[]`, `b:0[]`)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
x
Reference in New Issue
Block a user