153 lines
5.6 KiB
Python
153 lines
5.6 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Library for arbitrary expression evaluation based on a debugger data dump."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import re
|
|
|
|
import numpy as np # pylint: disable=unused-import
|
|
|
|
from tensorflow.python.debug.lib import debug_data
|
|
|
|
_DUMP_TENSOR_PATTERN = re.compile(r"`.*?`")
|
|
_DEVICE_NAME_PREFIX_PATTERN = re.compile(
|
|
r"/job:(\w)+/replica:(\d)+/task:(\d)+/(\w)+:(\d)+:")
|
|
_EXEC_INDEX_SUFFIX_PATTERN = re.compile(r"\[(\d)*\]$")
|
|
|
|
_DEFAULT_DEBUG_OP = "DebugIdentity"
|
|
|
|
|
|
def _parse_debug_tensor_name(debug_tensor_name):
|
|
# pylint: disable=line-too-long
|
|
"""Parse a debug tensor name in a to-be-evaluated expression.
|
|
|
|
Args:
|
|
debug_tensor_name: name of the debug tensor, with or without
|
|
device name as a prefix, with or without debug op, with or
|
|
without '[<exec_index>]' as a suffix.
|
|
E.g., without device name prefix, without debug op suffix:
|
|
"hidden_0/MatMul:0"
|
|
E.g., with device name prefix:
|
|
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0"
|
|
E.g., with debug op suffix:
|
|
"hidden_0/MatMul:0:DebugNumericSummary"
|
|
E.g., with device name prefix and debug op suffix:
|
|
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0:DebugNumericSummary"
|
|
E.g., with device name prefix, debug op and an exec index:
|
|
"/job:worker/replica:0/task:1/gpu:0:hidden_0/MatMul:0:DebugNumericSummary[1]"
|
|
|
|
Returns:
|
|
device_name: If device name prefix exists, the device name; otherwise,
|
|
`None`.
|
|
node_name: Name of the node.
|
|
output_slot: Output slot index as an `int`.
|
|
debug_op: If the debug op suffix exists, the debug op name; otherwise,
|
|
`None`.
|
|
exec_index: Execution index (applicable to cases in which a debug tensor
|
|
is computed multiple times in a `tf.Session.run` call, e.g., due to
|
|
`tf.while_loop`). If the exec_index suffix does not exist, this value
|
|
defaults to `0`.
|
|
|
|
Raises:
|
|
ValueError: If the input `debug_tensor_name` is malformed.
|
|
"""
|
|
# pylint: enable=line-too-long
|
|
device_prefix_match = re.match(_DEVICE_NAME_PREFIX_PATTERN, debug_tensor_name)
|
|
if device_prefix_match:
|
|
device_name = debug_tensor_name[
|
|
device_prefix_match.start() : device_prefix_match.end() - 1]
|
|
debug_tensor_name = debug_tensor_name[device_prefix_match.end():]
|
|
else:
|
|
device_name = None
|
|
|
|
split_items = debug_tensor_name.split(":")
|
|
if len(split_items) not in (2, 3):
|
|
raise ValueError(
|
|
"The debug tensor name in the to-be-evaluated expression is malformed: "
|
|
"'%s'" % debug_tensor_name)
|
|
# TODO(cais): Provide examples of good debug tensor names in the error
|
|
# message.
|
|
|
|
exec_index_match = re.search(_EXEC_INDEX_SUFFIX_PATTERN, split_items[-1])
|
|
if exec_index_match:
|
|
exec_index = int(split_items[-1][
|
|
exec_index_match.start() + 1 : exec_index_match.end() - 1])
|
|
split_items[-1] = split_items[-1][:exec_index_match.start()]
|
|
else:
|
|
exec_index = 0
|
|
|
|
if len(split_items) == 2:
|
|
node_name = split_items[0]
|
|
output_slot = int(split_items[1])
|
|
debug_op = _DEFAULT_DEBUG_OP
|
|
else:
|
|
split_items = debug_tensor_name.split(":")
|
|
node_name = split_items[0]
|
|
output_slot = int(split_items[1])
|
|
debug_op = split_items[2]
|
|
|
|
return device_name, node_name, output_slot, debug_op, exec_index
|
|
|
|
|
|
class ExpressionEvaluator(object):
|
|
"""Evaluates Python expressions using debug tensor values from a dump."""
|
|
|
|
def __init__(self, dump):
|
|
"""Constructor of ExpressionEvaluator.
|
|
|
|
Args:
|
|
dump: an instance of `DebugDumpDir`.
|
|
"""
|
|
self._dump = dump
|
|
self._cached_tensor_values = {}
|
|
|
|
def evaluate(self, expression):
|
|
"""Parse an expression.
|
|
|
|
Args:
|
|
expression: the expression to be parsed.
|
|
|
|
Returns:
|
|
The result of the evaluation.
|
|
|
|
Raises:
|
|
ValueError: If the value of one or more of the debug tensors in the
|
|
expression are not available.
|
|
"""
|
|
dump_tensors_iter = re.finditer(_DUMP_TENSOR_PATTERN, expression)
|
|
rewritten_expression = expression
|
|
for match in reversed(list(dump_tensors_iter)):
|
|
tensor_name = match.group(0)[1:-1].strip()
|
|
device_name, node_name, output_slot, debug_op, exec_index = (
|
|
_parse_debug_tensor_name(tensor_name))
|
|
if tensor_name not in self._cached_tensor_values:
|
|
try:
|
|
value = self._dump.get_tensors(
|
|
node_name, output_slot, debug_op,
|
|
device_name=device_name)[exec_index]
|
|
except debug_data.WatchKeyDoesNotExistInDebugDumpDirError:
|
|
raise ValueError(
|
|
"Eval failed due to the value of %s:%d:DebugIdentity being "
|
|
"unavailable" % (node_name, output_slot))
|
|
self._cached_tensor_values[tensor_name] = value
|
|
rewritten_expression = (
|
|
rewritten_expression[:match.start(0)] +
|
|
"self._cached_tensor_values['" + tensor_name + "']" +
|
|
rewritten_expression[match.end(0):])
|
|
|
|
return eval(rewritten_expression) # pylint: disable=eval-used
|