tfdbg CLI: add eval of arbitrary Python / np expressions

RELNOTES: TensorFlow Debugger (tfdbg) command-line interface: Support evaluation of arbitrary Python and numpy (np) expressions with debug tensor names enclosed in pairs of backtics. E.g., tfdbg> eval 'np.sum(`Softmax:0`, axis=1)'.
PiperOrigin-RevId: 164217384
This commit is contained in:
Shanqing Cai 2017-08-03 19:36:51 -07:00 committed by TensorFlower Gardener
parent 57970f8ff5
commit 44113ce5bb
9 changed files with 558 additions and 3 deletions

View File

@ -204,6 +204,16 @@ addditional features:
* Use the `prev` and `next` commands.
* Click underlined `<--` and `-->` links near the top left corner of the
screen.
* Evaluation of arbitrary expressions (with `numpy` imported as `np`) using
debug tensor names enclosed in pairs of backtics. Use the `-a` flag to
print large-sized results in its entirety. For example:
```none
tfdbg> eval np.argmax(`Softmax:0`)
tfdbg> eval "np.matmul((`output/Identity:0` / `Softmax:0`).T, `Softmax:0`)"
tfdbg> eval -a 'np.sum(`Softmax:0`, axis=1)'
```
* Tab completion of commands and some command arguments.
* To redirect the screen output to a file instead of the screen, end the
command with bash-style redirection. For example, the following command

View File

@ -159,6 +159,16 @@ py_library(
],
)
py_library(
name = "evaluator",
srcs = ["cli/evaluator.py"],
srcs_version = "PY2AND3",
deps = [
":debug_data",
"//third_party/py/numpy",
],
)
py_library(
name = "analyzer_cli",
srcs = ["cli/analyzer_cli.py"],
@ -168,6 +178,7 @@ py_library(
":command_parser",
":debug_data",
":debugger_cli_common",
":evaluator",
":source_utils",
":ui_factory",
"@six_archive//:six",
@ -708,6 +719,22 @@ py_test(
],
)
py_test(
name = "evaluator_test",
size = "small",
srcs = [
"cli/evaluator_test.py",
],
srcs_version = "PY2AND3",
deps = [
":debug_data",
":evaluator",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "analyzer_cli_test",
size = "small",

View File

@ -32,6 +32,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import evaluator
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import source_utils
@ -138,6 +139,7 @@ class DebugAnalyzer(object):
"""
self._debug_dump = debug_dump
self._evaluator = evaluator.ExpressionEvaluator(self._debug_dump)
# Initialize tensor filters state.
self._tensor_filters = {}
@ -333,6 +335,48 @@ class DebugAnalyzer(object):
help="Regular expression filter for node name.")
self._arg_parsers["list_source"] = ap
# Parser for eval.
ap = argparse.ArgumentParser(
description="""Evaluate an arbitrary expression. Can use tensor values
from the current debug dump. The debug tensor names should be enclosed
in pairs of backticks. Expressions with spaces should be enclosed in
a pair of double quotes or a pair of single quotes. By default, numpy
is imported as np and can be used in the expressions. E.g.,
1) eval np.argmax(`Softmax:0`),
2) eval 'np.sum(`Softmax:0`, axis=1)',
3) eval "np.matmul((`output/Identity:0`/`Softmax:0`).T, `Softmax:0`)".
""",
usage=argparse.SUPPRESS)
ap.add_argument(
"expression",
type=str,
help="""Expression to be evaluated.
1) in the simplest case, use <node_name>:<output_slot>, e.g.,
hidden_0/MatMul:0.
2) if the default debug op "DebugIdentity" is to be overridden, use
<node_name>:<output_slot>:<debug_op>, e.g.,
hidden_0/MatMul:0:DebugNumericSummary.
3) if the tensor of the same name exists on more than one device, use
<device_name>:<node_name>:<output_slot>[:<debug_op>], e.g.,
/job:worker/replica:0/task:0/gpu:0:hidden_0/MatMul:0
/job:worker/replica:0/task:2/cpu:0:hidden_0/MatMul:0:DebugNanCount.
4) if the tensor is executed multiple times in a given `Session.run`
call, specify the execution index with a 0-based integer enclose in a
pair of brackets at the end, e.g.,
RNN/tanh:0[0]
/job:worker/replica:0/task:0/gpu:0:RNN/tanh:0[0].""")
ap.add_argument(
"-a",
"--all",
dest="print_all",
action="store_true",
help="Print the tensor in its entirety, i.e., do not use ellipses "
"(may be slow for large results).")
self._arg_parsers["eval"] = ap
# TODO(cais): Implement list_nodes.
def add_tensor_filter(self, filter_name, filter_callable):
@ -966,6 +1010,20 @@ class DebugAnalyzer(object):
return output
def evaluate_expression(self, args, screen_info=None):
parsed = self._arg_parsers["eval"].parse_args(args)
eval_res = self._evaluator.evaluate(parsed.expression)
np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
screen_info)
return cli_shared.format_tensor(
eval_res,
"from eval of expression '%s'" % parsed.expression,
np_printoptions,
print_all=parsed.print_all,
include_numeric_summary=True)
def _reconstruct_print_source_command(self,
parsed,
line_begin,
@ -1500,6 +1558,11 @@ def create_analyzer_ui(debug_dump,
analyzer.list_source,
analyzer.get_help("list_source"),
prefix_aliases=["ls"])
cli.register_command_handler(
"eval",
analyzer.evaluate_expression,
analyzer.get_help("eval"),
prefix_aliases=["ev"])
dumped_tensor_names = []
for datum in debug_dump.dumped_tensor_data:

View File

@ -585,6 +585,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
cls._analyzer.list_source,
cls._analyzer.get_help("list_source"),
prefix_aliases=["ls"])
cls._registry.register_command_handler(
"eval",
cls._analyzer.evaluate_expression,
cls._analyzer.get_help("eval"),
prefix_aliases=["ev"])
@classmethod
def tearDownClass(cls):
@ -1134,6 +1139,29 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
], out.lines)
check_main_menu(self, out, list_tensors_enabled=True)
def testEvalExpression(self):
node_name = "simple_mul_add/matmul"
tensor_name = node_name + ":0"
out = self._registry.dispatch_command(
"eval", ["np.matmul(`%s`, `%s`.T)" % (tensor_name, tensor_name)],
screen_info={"cols": 80})
self.assertEqual([
"Tensor \"from eval of expression "
"'np.matmul(`simple_mul_add/matmul:0`, "
"`simple_mul_add/matmul:0`.T)'\":",
" dtype: float64",
" shape: (2, 2)",
"",
"Numeric summary:",
"| - + | total |",
"| 2 2 | 4 |",
"| min max mean std |",
"| -14.0 49.0 6.25 25.7524270701 |",
"",
"array([[ 49., -14.],",
" [-14., 4.]])"], out.lines)
def testAddGetTensorFilterLambda(self):
analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump)
analyzer.add_tensor_filter("foo_filter", lambda x, y: True)

View File

@ -173,7 +173,8 @@ def format_tensor(tensor,
applicable) will be included.
Returns:
(str) Formatted str representing the (potentially sliced) tensor.
An instance of `debugger_cli_common.RichTextLines` representing the
(potentially sliced) tensor.
"""
if tensor_slicing:

View File

@ -24,7 +24,7 @@ import sys
_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]")
_QUOTES_PATTERN = re.compile(r"\"[^\"]*\"")
_QUOTES_PATTERN = re.compile(r"(\"[^\"]*\"|\'[^\']*\')")
_WHITESPACE_PATTERN = re.compile(r"\s+")
_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
@ -92,7 +92,8 @@ def parse_command(command):
argument = command[idx0:start]
# Strip leading and trailing double quote if they are paired.
if argument.startswith("\"") and argument.endswith("\""):
if (argument.startswith("\"") and argument.endswith("\"") or
argument.startswith("'") and argument.endswith("'")):
argument = argument[1:-1]
arguments.append(argument)
idx0 = end

View File

@ -82,6 +82,11 @@ class ParseCommandTest(test_util.TensorFlowTestCase):
command_parser.parse_command(command))
# The pair of double quotes should have been stripped.
command = "inject_value foo 'np.zeros([100, 500])'"
self.assertEqual(["inject_value", "foo", "np.zeros([100, 500])"],
command_parser.parse_command(command))
# The pair of single quotes should have been stripped.
command = "\"command prefix with spaces\" arg1"
self.assertEqual(["command prefix with spaces", "arg1"],
command_parser.parse_command(command))

View File

@ -0,0 +1,152 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for arbitrary expression evaluation based on a debugger data dump."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import numpy as np # pylint: disable=unused-import
from tensorflow.python.debug.lib import debug_data
_DUMP_TENSOR_PATTERN = re.compile(r"`.*?`")
_DEVICE_NAME_PREFIX_PATTERN = re.compile(
r"/job:(\w)+/replica:(\d)+/task:(\d)+/(\w)+:(\d)+:")
_EXEC_INDEX_SUFFIX_PATTERN = re.compile(r"\[(\d)*\]$")
_DEFAULT_DEBUG_OP = "DebugIdentity"
def _parse_debug_tensor_name(debug_tensor_name):
# pylint: disable=line-too-long
"""Parse a debug tensor name in a to-be-evaluated expression.
Args:
debug_tensor_name: name of the debug tensor, with or without
device name as a prefix, with or without debug op, with or
without '[<exec_index>]' as a suffix.
E.g., without device name prefix, without debug op suffix:
"hidden_0/MatMul:0"
E.g., with device name prefix:
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0"
E.g., with debug op suffix:
"hidden_0/MatMul:0:DebugNumericSummary"
E.g., with device name prefix and debug op suffix:
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0:DebugNumericSummary"
E.g., with device name prefix, debug op and an exec index:
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0:DebugNumericSummary[1]"
Returns:
device_name: If device name prefix exists, the device name; otherwise,
`None`.
node_name: Name of the node.
output_slot: Output slot index as an `int`.
debug_op: If the debug op suffix exists, the debug op name; otheriwse,
`None`.
exec_index: Execution index (applicable to cases in which a debug tensor
is computed multiple times in a `tf.Session.run` call, e.g., due to
`tf.while_loop`). If the exec_index suffix does not exist, this value
defaults to `0`.
Raises:
ValueError: If the input `debug_tensor_name` is malformed.
"""
# pylint: enable=line-too-long
device_prefix_match = re.match(_DEVICE_NAME_PREFIX_PATTERN, debug_tensor_name)
if device_prefix_match:
device_name = debug_tensor_name[
device_prefix_match.start() : device_prefix_match.end() - 1]
debug_tensor_name = debug_tensor_name[device_prefix_match.end():]
else:
device_name = None
split_items = debug_tensor_name.split(":")
if len(split_items) not in (2, 3):
raise ValueError(
"The debug tensor name in the to-be-evaluated expression is malformed: "
"'%s'" % debug_tensor_name)
# TODO(cais): Provide examples of good debug tensor names in the error
# message.
exec_index_match = re.search(_EXEC_INDEX_SUFFIX_PATTERN, split_items[-1])
if exec_index_match:
exec_index = int(split_items[-1][
exec_index_match.start() + 1 : exec_index_match.end() - 1])
split_items[-1] = split_items[-1][:exec_index_match.start()]
else:
exec_index = 0
if len(split_items) == 2:
node_name = split_items[0]
output_slot = int(split_items[1])
debug_op = _DEFAULT_DEBUG_OP
else:
split_items = debug_tensor_name.split(":")
node_name = split_items[0]
output_slot = int(split_items[1])
debug_op = split_items[2]
return device_name, node_name, output_slot, debug_op, exec_index
class ExpressionEvaluator(object):
"""Evaluates Python expressions using debug tensor values from a dump."""
def __init__(self, dump):
"""Constructor of ExpressionEvaluator.
Args:
dump: an instance of `DebugDumpDir`.
"""
self._dump = dump
self._cached_tensor_values = dict()
def evaluate(self, expression):
"""Parse an expression.
Args:
expression: the expression to be parsed.
Returns:
The result of the evaluation.
Raises:
ValueError: If the value of one or more of the debug tensors in the
expression are not available.
"""
dump_tensors_iter = re.finditer(_DUMP_TENSOR_PATTERN, expression)
rewritten_expression = expression
for match in reversed(list(dump_tensors_iter)):
tensor_name = match.group(0)[1:-1].strip()
device_name, node_name, output_slot, debug_op, exec_index = (
_parse_debug_tensor_name(tensor_name))
if tensor_name not in self._cached_tensor_values:
try:
value = self._dump.get_tensors(
node_name, output_slot, debug_op,
device_name=device_name)[exec_index]
except debug_data.WatchKeyDoesNotExistInDebugDumpDirError:
raise ValueError(
"Eval failed due to the value of %s:%d:DebugIdentity being "
"unavailable" % (node_name, output_slot))
self._cached_tensor_values[tensor_name] = value
rewritten_expression = (
rewritten_expression[:match.start(0)] +
"self._cached_tensor_values['" + tensor_name + "']" +
rewritten_expression[match.end(0):])
return eval(rewritten_expression) # pylint: disable=eval-used

View File

@ -0,0 +1,268 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for arbitrary expression evaluator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.debug.cli import evaluator
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class ParseDebugTensorNameTest(test_util.TensorFlowTestCase):
def testParseNamesWithoutPrefixOrSuffix(self):
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name("foo:1"))
self.assertIsNone(device_name)
self.assertEqual("foo", node_name)
self.assertEqual(1, output_slot)
self.assertEqual("DebugIdentity", debug_op)
self.assertEqual(0, exec_index)
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name("hidden_0/Weights:0"))
self.assertIsNone(device_name)
self.assertEqual("hidden_0/Weights", node_name)
self.assertEqual(0, output_slot)
self.assertEqual("DebugIdentity", debug_op)
self.assertEqual(0, exec_index)
def testParseNamesWithoutPrefixWithDebugOpSuffix(self):
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name("foo:1:DebugNanCount"))
self.assertIsNone(device_name)
self.assertEqual("foo", node_name)
self.assertEqual(1, output_slot)
self.assertEqual("DebugNanCount", debug_op)
self.assertEqual(0, exec_index)
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name(
"hidden_0/Weights:0:DebugNumericSummary"))
self.assertIsNone(device_name)
self.assertEqual("hidden_0/Weights", node_name)
self.assertEqual(0, output_slot)
self.assertEqual("DebugNumericSummary", debug_op)
self.assertEqual(0, exec_index)
def testParseNamesWithDeviceNamePrefixWithoutDebugOpSuffix(self):
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name(
"/job:ps/replica:0/task:2/cpu:0:foo:1"))
self.assertEqual("/job:ps/replica:0/task:2/cpu:0", device_name)
self.assertEqual("foo", node_name)
self.assertEqual(1, output_slot)
self.assertEqual("DebugIdentity", debug_op)
self.assertEqual(0, exec_index)
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name(
"/job:worker/replica:0/task:3/gpu:0:hidden_0/Weights:0"))
self.assertEqual("/job:worker/replica:0/task:3/gpu:0", device_name)
self.assertEqual("hidden_0/Weights", node_name)
self.assertEqual(0, output_slot)
self.assertEqual("DebugIdentity", debug_op)
self.assertEqual(0, exec_index)
def testParseNamesWithDeviceNamePrefixWithDebugOpSuffix(self):
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name(
"/job:ps/replica:0/task:2/cpu:0:foo:1:DebugNanCount"))
self.assertEqual("/job:ps/replica:0/task:2/cpu:0", device_name)
self.assertEqual("foo", node_name)
self.assertEqual(1, output_slot)
self.assertEqual("DebugNanCount", debug_op)
self.assertEqual(0, exec_index)
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name(
"/job:worker/replica:0/task:3/gpu:0:"
"hidden_0/Weights:0:DebugNumericSummary"))
self.assertEqual("/job:worker/replica:0/task:3/gpu:0", device_name)
self.assertEqual("hidden_0/Weights", node_name)
self.assertEqual(0, output_slot)
self.assertEqual("DebugNumericSummary", debug_op)
self.assertEqual(0, exec_index)
def testParseMalformedDebugTensorName(self):
with self.assertRaisesRegexp(
ValueError,
r"The debug tensor name in the to-be-evaluated expression is "
r"malformed:"):
evaluator._parse_debug_tensor_name(
"/job:ps/replica:0/task:2/cpu:0:foo:1:DebugNanCount:1337")
with self.assertRaisesRegexp(
ValueError,
r"The debug tensor name in the to-be-evaluated expression is "
r"malformed:"):
evaluator._parse_debug_tensor_name(
"/job:ps/replica:0/cpu:0:foo:1:DebugNanCount")
with self.assertRaises(ValueError):
evaluator._parse_debug_tensor_name(
"foo:1:DebugNanCount[]")
with self.assertRaises(ValueError):
evaluator._parse_debug_tensor_name(
"foo:1[DebugNanCount]")
def testParseNamesWithExecIndex(self):
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name("foo:1[20]"))
self.assertIsNone(device_name)
self.assertEqual("foo", node_name)
self.assertEqual(1, output_slot)
self.assertEqual("DebugIdentity", debug_op)
self.assertEqual(20, exec_index)
device_name, node_name, output_slot, debug_op, exec_index = (
evaluator._parse_debug_tensor_name("hidden_0/Weights:0[3]"))
self.assertIsNone(device_name)
self.assertEqual("hidden_0/Weights", node_name)
self.assertEqual(0, output_slot)
self.assertEqual("DebugIdentity", debug_op)
self.assertEqual(3, exec_index)
class EvaluatorTest(test_util.TensorFlowTestCase):
def testEvaluateSingleTensor(self):
dump = test.mock.MagicMock()
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
del node_name, output_slot, debug_op, device_name # Unused.
return [np.array([[1.0, 2.0, 3.0]])]
with test.mock.patch.object(
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
ev = evaluator.ExpressionEvaluator(dump)
self.assertEqual(3, ev.evaluate("np.size(`a:0`)"))
# Whitespace in backticks should be tolerated.
self.assertEqual(3, ev.evaluate("np.size(` a:0 `)"))
def testEvaluateTwoTensors(self):
dump = test.mock.MagicMock()
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
del debug_op, device_name # Unused.
if node_name == "a" and output_slot == 0:
return [np.array([[1.0, -2.0], [0.0, 1.0]])]
elif node_name == "b" and output_slot == 0:
return [np.array([[-1.0], [1.0]])]
with test.mock.patch.object(
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
ev = evaluator.ExpressionEvaluator(dump)
self.assertAllClose([[-3.0], [1.0]],
ev.evaluate("np.matmul(`a:0`, `b:0`)"))
self.assertAllClose(
[[-4.0], [2.0]], ev.evaluate("np.matmul(`a:0`, `b:0`) + `b:0`"))
def testEvaluateNoneExistentTensorGeneratesError(self):
dump = test.mock.MagicMock()
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
del node_name, output_slot, debug_op, device_name # Unused.
raise debug_data.WatchKeyDoesNotExistInDebugDumpDirError()
with test.mock.patch.object(
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
ev = evaluator.ExpressionEvaluator(dump)
with self.assertRaisesRegexp(
ValueError, "Eval failed due to the value of .* being unavailable"):
ev.evaluate("np.matmul(`a:0`, `b:0`)")
def testEvaluateWithMultipleDevicesContainingTheSameTensorName(self):
dump = test.mock.MagicMock()
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
del output_slot, debug_op # Unused.
if node_name == "a" and device_name is None:
raise ValueError(
"There are multiple (2) devices with nodes named 'a' but "
"device_name is not specified")
elif (node_name == "a" and
device_name == "/job:worker/replica:0/task:0/cpu:0"):
return [np.array(10.0)]
elif (node_name == "a" and
device_name == "/job:worker/replica:0/task:1/cpu:0"):
return [np.array(20.0)]
with test.mock.patch.object(
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
ev = evaluator.ExpressionEvaluator(dump)
with self.assertRaisesRegexp(ValueError, r"multiple \(2\) devices"):
ev.evaluate("`a:0` + `a:0`")
self.assertAllClose(
30.0,
ev.evaluate("`/job:worker/replica:0/task:0/cpu:0:a:0` + "
"`/job:worker/replica:0/task:1/cpu:0:a:0`"))
def testEvaluateWithNonDefaultDebugOp(self):
dump = test.mock.MagicMock()
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
del device_name # Unused.
if node_name == "a" and output_slot == 0 and debug_op == "DebugIdentity":
return [np.array([[-1.0], [1.0]])]
elif node_name == "a" and output_slot == 0 and debug_op == "DebugFoo":
return [np.array([[-2.0, 2.0]])]
with test.mock.patch.object(
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
ev = evaluator.ExpressionEvaluator(dump)
self.assertAllClose(
[[4.0]],
ev.evaluate("np.matmul(`a:0:DebugFoo`, `a:0:DebugIdentity`)"))
def testEvaluateWithMultipleExecIndexes(self):
dump = test.mock.MagicMock()
def fake_get_tensors(node_name, output_slot, debug_op, device_name=None):
del debug_op, device_name # Unused.
if node_name == "a" and output_slot == 0:
return [np.array([[-1.0], [1.0]]), np.array([[-2.0], [2.0]])]
with test.mock.patch.object(
dump, "get_tensors", side_effect=fake_get_tensors, autospec=True):
ev = evaluator.ExpressionEvaluator(dump)
self.assertAllClose(
[[4.0]], ev.evaluate("np.matmul(`a:0[1]`.T, `a:0[0]`)"))
def testEvaluateExpressionWithUnmatchedBacktick(self):
dump = test.mock.MagicMock()
ev = evaluator.ExpressionEvaluator(dump)
with self.assertRaises(SyntaxError):
ev.evaluate("np.matmul(`a:0`, `b:0`) + `b:0")
def testEvaluateExpressionWithInvalidDebugTensorName(self):
dump = test.mock.MagicMock()
ev = evaluator.ExpressionEvaluator(dump)
with self.assertRaisesRegexp(
ValueError, r".* tensor name .* expression .* malformed"):
ev.evaluate("np.matmul(`a`, `b`)")
with self.assertRaisesRegexp(
ValueError, r".* tensor name .* expression .* malformed"):
ev.evaluate("np.matmul(`a:0:DebugIdentity:0`, `b:1:DebugNanCount:2`)")
with self.assertRaises(ValueError):
ev.evaluate("np.matmul(`a:0[]`, `b:0[]`)")
if __name__ == "__main__":
test.main()