tfdbg: add auto-generated Python API doc to gen_docs_combined.py

Also make some touch-ups on the relevant doc strings.
Change: 143128831
This commit is contained in:
Shanqing Cai 2016-12-28 13:32:22 -08:00 committed by TensorFlower Gardener
parent b61c142e89
commit b04d5686b9
6 changed files with 301 additions and 190 deletions

View File

@ -12,7 +12,53 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Public Python API of TensorFlow Debugger (tfdbg).""" """Public Python API of TensorFlow Debugger (tfdbg).
## Functions for adding debug watches
These functions help you modify `RunOptions` to specify which `Tensor`s are to
be watched when the TensorFlow graph is executed at runtime.
@@add_debug_tensor_watch
@@watch_graph
@@watch_graph_with_blacklists
## Classes for debug-dump data and directories
These classes allow you to load and inspect tensor values dumped from
TensorFlow graphs during runtime.
@@DebugTensorDatum
@@DebugDumpDir
## Functions for loading debug-dump data
@@load_tensor_from_event_file
## Tensor-value predicates
Built-in tensor-filter predicates to support conditional breakpoint between
runs. See `DebugDumpDir.find()` for more details.
@@has_inf_or_nan
## Session wrapper class and `SessionRunHook` implementations
These classes allow you to
* wrap around TensorFlow `Session` objects to debug plain TensorFlow models
(see `LocalCLIDebugWrapperSession`), or
* generate `SessionRunHook` objects to debug `tf.contrib.learn` models (see
`LocalCLIDebugHook`).
@@LocalCLIDebugHook
@@LocalCLIDebugWrapperSession
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Data structures and helpers for TensorFlow Debugger (tfdbg).""" """Classes and functions to handle debug-dump data of TensorFlow Debugger."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -32,15 +33,15 @@ from tensorflow.python.platform import gfile
def load_tensor_from_event_file(event_file_path): def load_tensor_from_event_file(event_file_path):
"""Load a tensor from an event file. """Load a tensor from an event file.
Assumes that the event file contains a Event protobuf and the Event protobuf Assumes that the event file contains an `Event` protobuf and the `Event`
contains a tensor. protobuf contains a `Tensor` value.
Args: Args:
event_file_path: Path to the event file. event_file_path: (`str`) path to the event file.
Returns: Returns:
The tensor value loaded from the event file. For uninitialized tensors, The tensor value loaded from the event file, as a `numpy.ndarray`. For
return None. uninitialized tensors, returns None.
""" """
event = event_pb2.Event() event = event_pb2.Event()
@ -105,10 +106,10 @@ def get_output_slot(element_name):
assumed. assumed.
Args: Args:
element_name: (str) name of the graph element in question. element_name: (`str`) name of the graph element in question.
Returns: Returns:
(int) output slot number. (`int`) output slot number.
""" """
return int(element_name.split(":")[-1]) if ":" in element_name else 0 return int(element_name.split(":")[-1]) if ":" in element_name else 0
@ -220,17 +221,17 @@ def has_inf_or_nan(datum, tensor):
"""A predicate for whether a tensor consists of any bad numerical values. """A predicate for whether a tensor consists of any bad numerical values.
This predicate is common enough to merit definition in this module. This predicate is common enough to merit definition in this module.
Bad numerical values include nans and infs. Bad numerical values include `nan`s and `inf`s.
The signature of this function follows the requiremnet of DebugDumpDir's The signature of this function follows the requirement of the method
find() method. `DebugDumpDir.find()`.
Args: Args:
datum: (DebugTensorDatum) Datum metadata. datum: (`DebugTensorDatum`) Datum metadata.
tensor: (numpy.ndarray or None) Value of the tensor. None represents tensor: (`numpy.ndarray` or None) Value of the tensor. None represents
an uninitialized tensor. an uninitialized tensor.
Returns: Returns:
(bool) True if and only if tensor consists of any nan or inf values. (`bool`) True if and only if tensor consists of any nan or inf values.
""" """
_ = datum # Datum metadata is unused in this predicate. _ = datum # Datum metadata is unused in this predicate.
@ -247,32 +248,33 @@ def has_inf_or_nan(datum, tensor):
class DebugTensorDatum(object): class DebugTensorDatum(object):
"""A single tensor dumped by tfdbg. """A single tensor dumped by TensorFlow Debugger (tfdbg).
Contains "metadata" for the dumped tensor, including node name, output slot, Contains metadata about the dumped tensor, including `timestamp`,
debug op and timestamp. `node_name`, `output_slot`, `debug_op`, and path to the dump file
(`file_path`).
This type does not contain the space-expensive tensor (numpy array) itself. This type does not hold the generally space-expensive tensor value (numpy
It just points to the file path from which the tensor can be loaded if array). Instead, it points to the file from which the tensor value can be
needed. loaded (with the `get_tensor` method) if needed.
""" """
def __init__(self, dump_root, debug_dump_rel_path): def __init__(self, dump_root, debug_dump_rel_path):
"""DebugTensorDatum constructor. """`DebugTensorDatum` constructor.
Args: Args:
dump_root: Debug dump root directory. dump_root: (`str`) Debug dump root directory.
debug_dump_rel_path: Path to a debug dump file, relative to the debug debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the
dump root directory. For example, suppose the debug dump root `dump_root`. For example, suppose the debug dump root
directory is "/tmp/tfdbg_1" and the dump file is at directory is `/tmp/tfdbg_1` and the dump file is at
"/tmp/tfdbg_1/ns_1/node_a_0_DebugIdentity_123456789", then `/tmp/tfdbg_1/ns_1/node_a_0_DebugIdentity_123456789`, then
the value of the debug_dump_rel_path should be the value of the debug_dump_rel_path should be
"ns_1/node_a_0_DebugIdenity_1234456789". `ns_1/node_a_0_DebugIdentity_123456789`.
Raises: Raises:
ValueError: If the base file name of the dump file does not conform to ValueError: If the base file name of the dump file does not conform to
the dump file naming pattern: the dump file naming pattern:
<op_name>_<output_slot>_<debug_op_name>_<timestamp_microsec> `node_name`_`output_slot`_`debug_op`_`timestamp`
""" """
base = os.path.basename(debug_dump_rel_path) base = os.path.basename(debug_dump_rel_path)
@ -307,31 +309,62 @@ class DebugTensorDatum(object):
return self.__str__() return self.__str__()
def get_tensor(self): def get_tensor(self):
"""Get tensor from the dump (Event) file. """Get tensor from the dump (`Event`) file.
Returns: Returns:
The tensor loaded from the dump (Event) file. The tensor loaded from the dump (`Event`) file.
""" """
return load_tensor_from_event_file(self.file_path) return load_tensor_from_event_file(self.file_path)
@property @property
def timestamp(self): def timestamp(self):
"""Timestamp of when this tensor value was dumped.
Returns:
(`int`) The timestamp in microseconds.
"""
return self._timestamp return self._timestamp
@property @property
def debug_op(self): def debug_op(self):
"""Name of the debug op.
Returns:
(`str`) debug op name (e.g., `DebugIdentity`).
"""
return self._debug_op return self._debug_op
@property @property
def node_name(self): def node_name(self):
"""Name of the node from which the tensor value was dumped.
Returns:
(`str`) name of the node watched by the debug op.
"""
return self._node_name return self._node_name
@property @property
def output_slot(self): def output_slot(self):
"""Output slot index from which the tensor value was dumped.
Returns:
(`int`) output slot index watched by the debug op.
"""
return self._output_slot return self._output_slot
@property @property
def tensor_name(self): def tensor_name(self):
"""Name of the tensor watched by the debug op.
Returns:
(`str`) `Tensor` name, in the form of `node_name`:`output_slot`
"""
return _get_tensor_name(self.node_name, self.output_slot) return _get_tensor_name(self.node_name, self.output_slot)
@property @property
@ -339,32 +372,34 @@ class DebugTensorDatum(object):
"""Watch key identifies a debug watch on a tensor. """Watch key identifies a debug watch on a tensor.
Returns: Returns:
A watch key, in the form of <tensor_name>:<debug_op>. (`str`) A watch key, in the form of `tensor_name`:`debug_op`.
""" """
return _get_tensor_watch_key(self.node_name, self.output_slot, return _get_tensor_watch_key(self.node_name, self.output_slot,
self.debug_op) self.debug_op)
@property @property
def file_path(self): def file_path(self):
"""Path to the file which stores the value of the dumped tensor."""
return self._file_path return self._file_path
class DebugDumpDir(object): class DebugDumpDir(object):
"""Data set from a debug dump directory on filesystem. """Data set from a debug-dump directory on filesystem.
An instance of DebugDumpDir contains all DebugTensorDatum in a tfdbg dump An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances
root directory. This is an immutable object, of which all constitute tensor in a tfdbg dump root directory.
dump files and partition_graphs are loaded during the __init__ call.
""" """
def __init__(self, dump_root, partition_graphs=None, validate=True): def __init__(self, dump_root, partition_graphs=None, validate=True):
"""DebugDumpDir constructor. """`DebugDumpDir` constructor.
Args: Args:
dump_root: Path to the dump root directory. dump_root: (`str`) path to the dump root directory.
partition_graphs: A repeated field of GraphDefs representing the partition_graphs: A repeated field of GraphDefs representing the
partition graphs executed by the TensorFlow runtime. partition graphs executed by the TensorFlow runtime.
validate: Whether the dump files are to be validated against the validate: (`bool`) whether the dump files are to be validated against the
partition graphs. partition graphs.
Raises: Raises:
@ -381,10 +416,10 @@ class DebugDumpDir(object):
self._python_graph = None self._python_graph = None
def _load_dumps(self, dump_root): def _load_dumps(self, dump_root):
"""Load DebugTensorDatum instances from the dump root. """Load `DebugTensorDatum` instances from the dump root.
Populates a list of DebugTensorDatum and sort the list by ascending Populates a list of `DebugTensorDatum` instances and sorts the list by
timestamp. ascending timestamp.
This sorting order reflects the order in which the TensorFlow executor This sorting order reflects the order in which the TensorFlow executor
processed the nodes of the graph. It is (one of many possible) topological processed the nodes of the graph. It is (one of many possible) topological
@ -404,7 +439,7 @@ class DebugDumpDir(object):
graphs may not be available, e.g., when the run errors out. graphs may not be available, e.g., when the run errors out.
Args: Args:
dump_root: (str) Dump root directory. dump_root: (`str`) Dump root directory.
""" """
self._dump_root = dump_root self._dump_root = dump_root
@ -439,11 +474,11 @@ class DebugDumpDir(object):
"""Obtain a DebugTensorDatum from the directory and file name. """Obtain a DebugTensorDatum from the directory and file name.
Args: Args:
dir_name: (str) Name of the directory in which the dump file resides. dir_name: (`str`) Name of the directory in which the dump file resides.
file_name: (str) Base name of the dump file. file_name: (`str`) Base name of the dump file.
Returns: Returns:
(DebugTensorDatum) The DebugTensorDatum loaded from the dump file. (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file.
""" """
# Calculate the relative path of the dump file with respect to the root. # Calculate the relative path of the dump file with respect to the root.
@ -455,7 +490,7 @@ class DebugDumpDir(object):
def _create_tensor_watch_maps(self): def _create_tensor_watch_maps(self):
"""Create maps from tensor watch keys to datum and to timestamps. """Create maps from tensor watch keys to datum and to timestamps.
Create a map from watch key (tensor name + debug op) to DebugTensorDatum Create a map from watch key (tensor name + debug op) to `DebugTensorDatum`
item. Also make a map from watch key to relative timestamp. item. Also make a map from watch key to relative timestamp.
"relative" means (absolute timestamp - t0). "relative" means (absolute timestamp - t0).
""" """
@ -478,7 +513,7 @@ class DebugDumpDir(object):
Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph` Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph`
is a Python object and carries additional information such as the traceback is a Python object and carries additional information such as the traceback
of nodes in the graph. of the construction of the nodes in the graph.
Args: Args:
python_graph: (ops.Graph) The Python Graph object. python_graph: (ops.Graph) The Python Graph object.
@ -499,8 +534,9 @@ class DebugDumpDir(object):
"""Absolute timestamp of the first dumped tensor. """Absolute timestamp of the first dumped tensor.
Returns: Returns:
Absolute timestamp of the first dumped tensor. (`int`) absolute timestamp of the first dumped tensor, in microseconds.
""" """
return self._t0 return self._t0
@property @property
@ -508,8 +544,9 @@ class DebugDumpDir(object):
"""Total number of dumped tensors in the dump root directory. """Total number of dumped tensors in the dump root directory.
Returns: Returns:
Total number of dumped tensors in the dump root directory. (`int`) total number of dumped tensors in the dump root directory.
""" """
return len(self._dump_tensor_data) return len(self._dump_tensor_data)
def _load_partition_graphs(self, partition_graphs, validate): def _load_partition_graphs(self, partition_graphs, validate):
@ -524,7 +561,7 @@ class DebugDumpDir(object):
partition_graphs: Partition graphs executed by the TensorFlow runtime, partition_graphs: Partition graphs executed by the TensorFlow runtime,
represented as repeated fields of GraphDef. represented as repeated fields of GraphDef.
If no partition_graph is available, use None. If no partition_graph is available, use None.
validate: (bool) Whether the dump files are to be validated against the validate: (`bool`) Whether the dump files are to be validated against the
partition graphs. partition graphs.
""" """
@ -619,7 +656,7 @@ class DebugDumpDir(object):
"""Prune nodes out of input and recipient maps. """Prune nodes out of input and recipient maps.
Args: Args:
nodes_to_prune: (list of str) Names of the nodes to be pruned. nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
""" """
for node in nodes_to_prune: for node in nodes_to_prune:
@ -759,6 +796,7 @@ class DebugDumpDir(object):
Raises: Raises:
LookupError: If no partition graphs have been loaded. LookupError: If no partition graphs have been loaded.
""" """
if self._partition_graphs is None: if self._partition_graphs is None:
raise LookupError("No partition graphs have been loaded.") raise LookupError("No partition graphs have been loaded.")
@ -773,13 +811,14 @@ class DebugDumpDir(object):
Raises: Raises:
LookupError: If no partition graphs have been loaded. LookupError: If no partition graphs have been loaded.
""" """
if self._partition_graphs is None: if self._partition_graphs is None:
raise LookupError("No partition graphs have been loaded.") raise LookupError("No partition graphs have been loaded.")
return [node_name for node_name in self._node_inputs] return [node_name for node_name in self._node_inputs]
def node_attributes(self, node_name): def node_attributes(self, node_name):
"""Get attributes of a node. """Get the attributes of a node.
Args: Args:
node_name: Name of the node in question. node_name: Name of the node in question.
@ -791,6 +830,7 @@ class DebugDumpDir(object):
LookupError: If no partition graphs have been loaded. LookupError: If no partition graphs have been loaded.
ValueError: If no node named node_name exists. ValueError: If no node named node_name exists.
""" """
if self._partition_graphs is None: if self._partition_graphs is None:
raise LookupError("No partition graphs have been loaded.") raise LookupError("No partition graphs have been loaded.")
@ -804,11 +844,11 @@ class DebugDumpDir(object):
Args: Args:
node_name: Name of the node. node_name: Name of the node.
is_control: Whether control inputs, rather than non-control inputs, are is_control: (`bool`) Whether control inputs, rather than non-control
to be returned. inputs, are to be returned.
Returns: Returns:
All non-control inputs to the node, as a list of node names. (`list` of `str`) inputs to the node, as a list of node names.
Raises: Raises:
LookupError: If node inputs and control inputs have not been loaded LookupError: If node inputs and control inputs have not been loaded
@ -837,7 +877,8 @@ class DebugDumpDir(object):
include_control: Include control inputs (True by default). include_control: Include control inputs (True by default).
Returns: Returns:
All transitive inputs to the node, as a list of node names. (`list` of `str`) all transitive inputs to the node, as a list of node
names.
Raises: Raises:
LookupError: If node inputs and control inputs have not been loaded LookupError: If node inputs and control inputs have not been loaded
@ -900,12 +941,12 @@ class DebugDumpDir(object):
"""Get recipient of the given node's output according to partition graphs. """Get recipient of the given node's output according to partition graphs.
Args: Args:
node_name: Name of the node. node_name: (`str`) name of the node.
is_control: Whether control outputs, rather than non-control outputs, is_control: (`bool`) whether control outputs, rather than non-control
are to be returned. outputs, are to be returned.
Returns: Returns:
All non-control inputs to the node, as a list of node names. (`list` of `str`) all recipients of the node, as a list of node names.
Raises: Raises:
LookupError: If node inputs and control inputs have not been loaded LookupError: If node inputs and control inputs have not been loaded
@ -930,7 +971,7 @@ class DebugDumpDir(object):
"""Get the list of devices. """Get the list of devices.
Returns: Returns:
Number of devices. (`list` of `str`) names of the devices.
Raises: Raises:
LookupError: If node inputs and control inputs have not been loaded LookupError: If node inputs and control inputs have not been loaded
@ -946,7 +987,7 @@ class DebugDumpDir(object):
"""Test if a node exists in the partition graphs. """Test if a node exists in the partition graphs.
Args: Args:
node_name: Name of the node to be checked, as a str. node_name: (`str`) name of the node to be checked.
Returns: Returns:
A boolean indicating whether the node exists. A boolean indicating whether the node exists.
@ -965,16 +1006,17 @@ class DebugDumpDir(object):
"""Get the device of a node. """Get the device of a node.
Args: Args:
node_name: Name of the node. node_name: (`str`) name of the node.
Returns: Returns:
Name of the device on which the node is placed, as a str. (`str`) name of the device on which the node is placed.
Raises: Raises:
LookupError: If node inputs and control inputs have not been loaded LookupError: If node inputs and control inputs have not been loaded
from partition graphs yet. from partition graphs yet.
ValueError: If the node does not exist in partition graphs. ValueError: If the node does not exist in partition graphs.
""" """
if self._partition_graphs is None: if self._partition_graphs is None:
raise LookupError( raise LookupError(
"Node devices are not loaded from partition graphs yet.") "Node devices are not loaded from partition graphs yet.")
@ -989,16 +1031,17 @@ class DebugDumpDir(object):
"""Get the op type of given node. """Get the op type of given node.
Args: Args:
node_name: Name of the node. node_name: (`str`) name of the node.
Returns: Returns:
Type of the node's op, as a str. (`str`) op type of the node.
Raises: Raises:
LookupError: If node op types have not been loaded LookupError: If node op types have not been loaded
from partition graphs yet. from partition graphs yet.
ValueError: If the node does not exist in partition graphs. ValueError: If the node does not exist in partition graphs.
""" """
if self._partition_graphs is None: if self._partition_graphs is None:
raise LookupError( raise LookupError(
"Node op types are not loaded from partition graphs yet.") "Node op types are not loaded from partition graphs yet.")
@ -1013,14 +1056,14 @@ class DebugDumpDir(object):
"""Get all tensor watch keys of given node according to partition graphs. """Get all tensor watch keys of given node according to partition graphs.
Args: Args:
node_name: Name of the node. node_name: (`str`) name of the node.
Returns: Returns:
All debug tensor watch keys, as a list of strings. Returns an empty list (`list` of `str`) all debug tensor watch keys. Returns an empty list if
if the node name does not correspond to any debug watch keys. the node name does not correspond to any debug watch keys.
Raises: Raises:
LookupError: If debug watch information has not been loaded from `LookupError`: If debug watch information has not been loaded from
partition graphs yet. partition graphs yet.
""" """
@ -1037,13 +1080,13 @@ class DebugDumpDir(object):
return watch_keys return watch_keys
def watch_key_to_data(self, debug_watch_key): def watch_key_to_data(self, debug_watch_key):
"""Get all DebugTensorDatum instances corresponding to a debug watch key. """Get all `DebugTensorDatum` instances corresponding to a debug watch key.
Args: Args:
debug_watch_key: A debug watch key, as a str. debug_watch_key: (`str`) debug watch key.
Returns: Returns:
A list of DebugTensorDatuminstances that correspond to the debug watch A list of `DebugTensorDatum` instances that correspond to the debug watch
key. If the watch key does not exist, returns an empty list. key. If the watch key does not exist, returns an empty list.
Raises: Raises:
@ -1057,18 +1100,24 @@ class DebugDumpDir(object):
Args: Args:
predicate: A callable that takes two input arguments: predicate: A callable that takes two input arguments:
predicate(debug_tensor_datum, tensor),
where "debug_tensor_datum" is an instance of DebugTensorDatum, which ```python
carries "metadata", such as the name of the node, the tensor's slot def predicate(debug_tensor_datum, tensor):
index on the node, timestamp, debug op name, etc; and "tensor" is # returns a bool
the dumped tensor value as a numpy array. ```
first_n: Return only the first n dumped tensor data (in time order) for
which the predicate is True. To return all such data, let first_n be where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which
<= 0. carries the metadata, such as the `Tensor`'s node name, output slot,
timestamp, debug op name, etc.; and `tensor` is the dumped tensor value
as a `numpy.ndarray`.
first_n: (`int`) return only the first n `DebugTensorDatum` instances (in
time order) for which the predicate returns True. To return all the
`DebugTensorDatum` instances, let first_n be <= 0.
Returns: Returns:
A list of all DebugTensorDatum objects in this DebugDumpDir object for A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
which predicate returns True, sorted in ascending order of the timestamp. for which predicate returns True, sorted in ascending order of the
timestamp.
""" """
matched_data = [] matched_data = []
@ -1085,16 +1134,16 @@ class DebugDumpDir(object):
"""Get the file paths from a debug-dumped tensor. """Get the file paths from a debug-dumped tensor.
Args: Args:
node_name: Name of the node that the tensor is produced by. node_name: (`str`) name of the node that the tensor is produced by.
output_slot: Output slot index of tensor. output_slot: (`int`) output slot index of tensor.
debug_op: Name of the debug op. debug_op: (`str`) name of the debug op.
Returns: Returns:
List of file path(s) loaded. This is a list because each debugged tensor List of file path(s) loaded. This is a list because each debugged tensor
may be dumped multiple times. may be dumped multiple times.
Raises: Raises:
ValueError: If the tensor does not exist in the debub dump data. ValueError: If the tensor does not exist in the debug-dump data.
""" """
watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
@ -1108,18 +1157,18 @@ class DebugDumpDir(object):
"""Get the tensor value for a debug-dumped tensor. """Get the tensor value for a debug-dumped tensor.
The tensor may be dumped multiple times in the dump root directory, so a The tensor may be dumped multiple times in the dump root directory, so a
list of tensors (numpy arrays) is returned. list of tensors (`numpy.ndarray`) is returned.
Args: Args:
node_name: Name of the node that the tensor is produced by. node_name: (`str`) name of the node that the tensor is produced by.
output_slot: Output slot index of tensor. output_slot: (`int`) output slot index of tensor.
debug_op: Name of the debug op. debug_op: (`str`) name of the debug op.
Returns: Returns:
List of tensor(s) loaded from the tensor dump file(s). List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s).
Raises: Raises:
ValueError: If the tensor does not exist in the debub dump data. ValueError: If the tensor does not exist in the debug-dump data.
""" """
watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
@ -1132,18 +1181,18 @@ class DebugDumpDir(object):
def get_rel_timestamps(self, node_name, output_slot, debug_op): def get_rel_timestamps(self, node_name, output_slot, debug_op):
"""Get the relative timestamp for a debug-dumped tensor. """Get the relative timestamp for a debug-dumped tensor.
Relative timestamp means (absolute timestamp - t0), t0 being the absolute Relative timestamp means (absolute timestamp - `t0`), where `t0` is the
timestamp of the first dumped tensor in the dump root. The tensor may be absolute timestamp of the first dumped tensor in the dump root. The tensor
dumped multiple times in the dump root directory, so a list of relative may be dumped multiple times in the dump root directory, so a list of
timestamp (numpy arrays) is returned. relative timestamps (`int`) is returned.
Args: Args:
node_name: Name of the node that the tensor is produced by. node_name: (`str`) name of the node that the tensor is produced by.
output_slot: Output slot index of tensor. output_slot: (`int`) output slot index of tensor.
debug_op: Name of the debug op. debug_op: (`str`) name of the debug op.
Returns: Returns:
List of relative timestamps. (list of int) list of relative timestamps.
Raises: Raises:
ValueError: If the tensor does not exist in the debug-dump data. ValueError: If the tensor does not exist in the debug-dump data.
@ -1160,7 +1209,7 @@ class DebugDumpDir(object):
"""Try to retrieve the Python traceback of node's construction. """Try to retrieve the Python traceback of node's construction.
Args: Args:
element_name: (str) Name of a graph element (node or tensor). element_name: (`str`) Name of a graph element (node or tensor).
Returns: Returns:
(list) The traceback list object as returned by the `extract_trace` (list) The traceback list object as returned by the `extract_trace`

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""TensorFlow Debugger (tfdbg) Utilities.""" """TensorFlow Debugger (tfdbg) Utilities."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -27,18 +28,20 @@ def add_debug_tensor_watch(run_options,
output_slot=0, output_slot=0,
debug_ops="DebugIdentity", debug_ops="DebugIdentity",
debug_urls=None): debug_urls=None):
"""Add debug tensor watch option to RunOptions. """Add watch on a `Tensor` to `RunOptions`.
N.B.: Under certain circumstances, the `Tensor` may not be actually watched
(e.g., if the node of the `Tensor` is constant-folded during runtime).
Args: Args:
run_options: An instance of tensorflow.core.protobuf.config_pb2.RunOptions run_options: An instance of `config_pb2.RunOptions` to be modified.
node_name: Name of the node to watch. node_name: (`str`) name of the node to watch.
output_slot: Output slot index of the tensor from the watched node. output_slot: (`int`) output slot index of the tensor from the watched node.
debug_ops: Name(s) of the debug op(s). Default: "DebugIdentity". debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s). Can be a
Can be a list of strings or a single string. The latter case is `list` of `str` or a single `str`. The latter case is equivalent to a
equivalent to a list of string with only one element. `list` of `str` with only one element.
debug_urls: URLs to send debug signals to: a non-empty list of strings or debug_urls: (`str` or `list` of `str`) URL(s) to send debug values to,
a string, or None. The case of a string is equivalent to a list of e.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
string with only one element.
""" """
watch_opts = run_options.debug_options.debug_tensor_watch_opts watch_opts = run_options.debug_options.debug_tensor_watch_opts
@ -65,27 +68,31 @@ def watch_graph(run_options,
debug_urls=None, debug_urls=None,
node_name_regex_whitelist=None, node_name_regex_whitelist=None,
op_type_regex_whitelist=None): op_type_regex_whitelist=None):
"""Add debug tensor watch options to RunOptions based on a TensorFlow graph. """Add debug watches to `RunOptions` for a TensorFlow graph.
To watch all tensors on the graph, set both node_name_regex_whitelist To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist`
and op_type_regex_whitelist to None. and `op_type_regex_whitelist` be the default (`None`).
N.B.: Under certain circumstances, not all specified `Tensor`s will be
actually watched (e.g., nodes that are constant-folded during runtime will
not be watched).
Args: Args:
run_options: An instance of tensorflow.core.protobuf.config_pb2.RunOptions run_options: An instance of `config_pb2.RunOptions` to be modified.
graph: An instance of tensorflow.python.framework.ops.Graph graph: An instance of `ops.Graph`.
debug_ops: Name of the debug op to use. Default: "DebugIdentity". debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
Can be a list of strings of a single string. The latter case is debug_urls: URLs to send debug values to. Can be a list of strings,
equivalent to a list of a single string. a single string, or None. The case of a single string is equivalent to
debug_urls: Debug urls. Can be a list of strings, a single string, or a list consisting of a single string, e.g., `file:///tmp/tfdbg_dump_1`,
None. The case of a single string is equivalen to a list consisting `grpc://localhost:12345`.
of a single string. node_name_regex_whitelist: Regular-expression whitelist for node_name,
node_name_regex_whitelist: Regular-expression whitelist for node_name. e.g., `"(weight_[0-9]+|bias_.*)"`
This should be a string, e.g., "(weight_[0-9]+|bias_.*)"
op_type_regex_whitelist: Regular-expression whitelist for the op type of op_type_regex_whitelist: Regular-expression whitelist for the op type of
nodes. If both node_name_regex_whitelist and op_type_regex_whitelist nodes, e.g., `"(Variable|Add)"`.
are none, the two filtering operations will occur in an "AND" If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
relation. In other words, a node will be included if and only if it are set, the two filtering operations will occur in a logical `AND`
hits both whitelists. This should be a string, e.g., "(Variable|Add)". relation. In other words, a node will be included if and only if it
hits both whitelists.
""" """
if isinstance(debug_ops, str): if isinstance(debug_ops, str):
@ -130,29 +137,30 @@ def watch_graph_with_blacklists(run_options,
debug_urls=None, debug_urls=None,
node_name_regex_blacklist=None, node_name_regex_blacklist=None,
op_type_regex_blacklist=None): op_type_regex_blacklist=None):
"""Add debug tensor watch options, blacklisting nodes and op types. """Add debug tensor watches, blacklisting nodes and op types.
This is similar to watch_graph(), but the node names and op types can be This is similar to `watch_graph()`, but the node names and op types are
blacklisted, instead of whitelisted. blacklisted, instead of whitelisted.
N.B.: Under certain circumstances, not all specified `Tensor`s will be
actually watched (e.g., nodes that are constant-folded during runtime will
not be watched).
Args: Args:
run_options: An instance of tensorflow.core.protobuf.config_pb2.RunOptions run_options: An instance of `config_pb2.RunOptions` to be modified.
graph: An instance of tensorflow.python.framework.ops.Graph graph: An instance of `ops.Graph`.
debug_ops: Name of the debug op to use. Default: "DebugIdentity". debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
Can be a list of strings of a single string. The latter case is debug_urls: URL(s) to send debug values to, e.g.,
equivalent to a list of a single string. `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
debug_urls: Debug urls. Can be a list of strings, a single string, or
None. The case of a single string is equivalen to a list consisting
of a single string.
node_name_regex_blacklist: Regular-expression blacklist for node_name. node_name_regex_blacklist: Regular-expression blacklist for node_name.
This should be a string, e.g., "(weight_[0-9]+|bias_.*)" This should be a string, e.g., `"(weight_[0-9]+|bias_.*)"`.
op_type_regex_blacklist: Regular-expression blacklist for the op type of op_type_regex_blacklist: Regular-expression blacklist for the op type of
nodes. If both node_name_regex_blacklist and op_type_regex_blacklist nodes, e.g., `"(Variable|Add)"`.
are none, the two filtering operations will occur in an "OR" If both `node_name_regex_blacklist` and `op_type_regex_blacklist`
relation. In other words, a node will be excluded if it hits either of are set, the two filtering operations will occur in a logical `OR`
the two blacklists; a node will be included if and only if it hits relation. In other words, a node will be excluded if it hits either of
none of the blacklists. This should be a string, e.g., the two blacklists; a node will be included if and only if it hits
"(Variable|Add)". neither of the blacklists.
""" """
if isinstance(debug_ops, str): if isinstance(debug_ops, str):

View File

@ -176,7 +176,7 @@ class OnSessionInitResponse(object):
"""Constructor. """Constructor.
Args: Args:
action: (OnSessionInitAction) Debugger action to take on session init. action: (`OnSessionInitAction`) Debugger action to take on session init.
""" """
_check_type(action, str) _check_type(action, str)
self.action = action self.action = action
@ -191,7 +191,7 @@ class OnRunStartRequest(object):
def __init__(self, fetches, feed_dict, run_options, run_metadata, def __init__(self, fetches, feed_dict, run_options, run_metadata,
run_call_count): run_call_count):
"""Constructor of OnRunStartRequest. """Constructor of `OnRunStartRequest`.
Args: Args:
fetches: Fetch targets of the run() call. fetches: Fetch targets of the run() call.
@ -233,10 +233,10 @@ class OnRunStartResponse(object):
""" """
def __init__(self, action, debug_urls): def __init__(self, action, debug_urls):
"""Constructor of OnRunStartResponse. """Constructor of `OnRunStartResponse`.
Args: Args:
action: (OnRunStartAction) the action actually taken by the wrapped action: (`OnRunStartAction`) the action actually taken by the wrapped
session for the run() call. session for the run() call.
debug_urls: (list of str) debug_urls used in watching the tensors during debug_urls: (list of str) debug_urls used in watching the tensors during
the run() call. the run() call.
@ -260,10 +260,10 @@ class OnRunEndRequest(object):
run_metadata=None, run_metadata=None,
client_graph_def=None, client_graph_def=None,
tf_error=None): tf_error=None):
"""Constructor for OnRunEndRequest. """Constructor for `OnRunEndRequest`.
Args: Args:
performed_action: (OnRunStartAction) Actually-performed action by the performed_action: (`OnRunStartAction`) Actually-performed action by the
debug-wrapper session. debug-wrapper session.
run_metadata: run_metadata output from the run() call (if any). run_metadata: run_metadata output from the run() call (if any).
client_graph_def: (GraphDef) GraphDef from the client side, i.e., from client_graph_def: (GraphDef) GraphDef from the client side, i.e., from
@ -303,13 +303,13 @@ class BaseDebugWrapperSession(session.SessionInterface):
# is available. # is available.
def __init__(self, sess): def __init__(self, sess):
"""Constructor of BaseDebugWrapperSession. """Constructor of `BaseDebugWrapperSession`.
Args: Args:
sess: An (unwrapped) TensorFlow session instance. sess: An (unwrapped) TensorFlow session instance.
Raises: Raises:
ValueError: On invalid OnSessionInitAction value. ValueError: On invalid `OnSessionInitAction` value.
""" """
_check_type(sess, session.BaseSession) _check_type(sess, session.BaseSession)
@ -352,16 +352,16 @@ class BaseDebugWrapperSession(session.SessionInterface):
"""Wrapper around Session.run() that inserts tensor watch options. """Wrapper around Session.run() that inserts tensor watch options.
Args: Args:
fetches: Same as the fetches arg to regular Session.run() fetches: Same as the `fetches` arg to regular `Session.run()`.
feed_dict: Same as the feed_dict arg to regular Session.run() feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
options: Same as the options arg to regular Session.run() options: Same as the `options` arg to regular `Session.run()`.
run_metadata: Same as the run_metadata to regular Session.run() run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
Returns: Returns:
Simply forwards the output of the wrapped Session.run() call. Simply forwards the output of the wrapped `Session.run()` call.
Raises: Raises:
ValueError: On invalid OnRunStartAction value. ValueError: On invalid `OnRunStartAction` value.
""" """
self._run_call_count += 1 self._run_call_count += 1
@ -458,11 +458,11 @@ class BaseDebugWrapperSession(session.SessionInterface):
The invocation happens right before the constructor ends. The invocation happens right before the constructor ends.
Args: Args:
request: (OnSessionInitRequest) callback request carrying information request: (`OnSessionInitRequest`) callback request carrying information
such as the session being wrapped. such as the session being wrapped.
Returns: Returns:
An instance of OnSessionInitResponse. An instance of `OnSessionInitResponse`.
""" """
@abc.abstractmethod @abc.abstractmethod
@ -474,12 +474,13 @@ class BaseDebugWrapperSession(session.SessionInterface):
after an increment of run call counter. after an increment of run call counter.
Args: Args:
request: (OnRunStartRequest) callback request object carrying information request: (`OnRunStartRequest`) callback request object carrying
about the run call such as the fetches, feed dict, run options, run information about the run call such as the fetches, feed dict, run
metadata, and how many run() calls to this wrapper session has occurred. options, run metadata, and how many `run()` calls to this wrapper
session have occurred.
Returns: Returns:
An instance of OnRunStartResponse, carrying information to An instance of `OnRunStartResponse`, carrying information to
1) direct the wrapper session to perform a specified action (e.g., run 1) direct the wrapper session to perform a specified action (e.g., run
with or without debug tensor watching, invoking the stepper.) with or without debug tensor watching, invoking the stepper.)
2) debug URLs used to watch the tensors. 2) debug URLs used to watch the tensors.
@ -493,12 +494,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
The invocation happens right before the wrapper exits its run() call. The invocation happens right before the wrapper exits its run() call.
Args: Args:
request: (OnRunEndRequest) callback request object carrying information request: (`OnRunEndRequest`) callback request object carrying information
such as the actual action performed by the session wrapper for the such as the actual action performed by the session wrapper for the
run() call. run() call.
Returns: Returns:
An instance of OnRunStartResponse. An instance of `OnRunEndResponse`.
""" """
def __enter__(self): def __enter__(self):

View File

@ -37,18 +37,23 @@ _DUMP_ROOT_PREFIX = "tfdbg_"
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
"""Concrete subclass of BaseDebugWrapperSession implementing a local CLI.""" """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.
This class has all the methods that a `session.Session` object has, in order
to support debugging with minimal code changes. Invoking its `run()` method
will launch the command-line interface (CLI) of tfdbg.
"""
def __init__(self, sess, dump_root=None, log_usage=True): def __init__(self, sess, dump_root=None, log_usage=True):
"""Constructor of LocalCLIDebugWrapperSession. """Constructor of LocalCLIDebugWrapperSession.
Args: Args:
sess: (BaseSession subtypes) The TensorFlow Session object being wrapped. sess: The TensorFlow `Session` object being wrapped.
dump_root: (str) Optional path to the dump root directory. Must be either dump_root: (`str`) optional path to the dump root directory. Must be a
a directory that does not exist or an empty directory. If the directory directory that does not exist or an empty directory. If the directory
does not exist, it will be created by the debugger core during debug does not exist, it will be created by the debugger core during debug
run() calls and removed afterwards. `run()` calls and removed afterwards.
log_usage: (bool) Whether the usage of this class is to be logged. log_usage: (`bool`) whether the usage of this class is to be logged.
Raises: Raises:
ValueError: If dump_root is an existing and non-empty directory or if ValueError: If dump_root is an existing and non-empty directory or if
@ -137,14 +142,10 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
def add_tensor_filter(self, filter_name, tensor_filter): def add_tensor_filter(self, filter_name, tensor_filter):
"""Add a tensor filter. """Add a tensor filter.
The signature of this command is identical to that of
debug_data.DebugDumpDir.add_tensor_filter(). This method is a thin wrapper
around that method.
Args: Args:
filter_name: (str) Name of the filter. filter_name: (`str`) name of the filter.
tensor_filter: (callable) The filter callable. See the doc string of tensor_filter: (`callable`) the filter callable. See the doc string of
debug_data.DebugDumpDir.add_tensor_filter() for more details. `DebugDumpDir.find()` for more details about its signature.
""" """
self._tensor_filters[filter_name] = tensor_filter self._tensor_filters[filter_name] = tensor_filter
@ -153,7 +154,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
"""Overrides on-session-init callback. """Overrides on-session-init callback.
Args: Args:
request: An instance of OnSessionInitRequest. request: An instance of `OnSessionInitRequest`.
Returns: Returns:
An instance of OnSessionInitResponse. An instance of OnSessionInitResponse.
@ -166,13 +167,13 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
"""Overrides on-run-start callback. """Overrides on-run-start callback.
Invoke the CLI to let user choose what action to take: Invoke the CLI to let user choose what action to take:
run / run --no_debug / step. `run` / `invoke_stepper`.
Args: Args:
request: An instance of OnSessionInitRequest. request: An instance of `OnRunStartRequest`.
Returns: Returns:
An instance of OnSessionInitResponse. An instance of `OnRunStartResponse`.
Raises: Raises:
RuntimeError: If user chooses to prematurely exit the debugger. RuntimeError: If user chooses to prematurely exit the debugger.
@ -483,10 +484,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
"""Overrides method in base class to implement interactive node stepper. """Overrides method in base class to implement interactive node stepper.
Args: Args:
node_stepper: (stepper.NodeStepper) The underlying NodeStepper API object. node_stepper: (`stepper.NodeStepper`) The underlying NodeStepper API
restore_variable_values_on_exit: (bool) Whether any variables whose values object.
have been altered during this node-stepper invocation should be restored restore_variable_values_on_exit: (`bool`) Whether any variables whose
to their old values when this invocation ends. values have been altered during this node-stepper invocation should be
restored to their old values when this invocation ends.
Returns: Returns:
The same return values as the `Session.run()` call on the same fetches as The same return values as the `Session.run()` call on the same fetches as

View File

@ -25,6 +25,7 @@ import sys
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib import ffmpeg from tensorflow.contrib import ffmpeg
from tensorflow.python import debug as tf_debug
from tensorflow.python.client import client_lib from tensorflow.python.client import client_lib
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import docs from tensorflow.python.framework import docs
@ -79,6 +80,7 @@ def module_names():
"tf.contrib.solvers", "tf.contrib.solvers",
"tf.contrib.training", "tf.contrib.training",
"tf.contrib.util", "tf.contrib.util",
"tf_debug",
] ]
@ -89,6 +91,8 @@ def find_module(base_module, name):
# to size concerns. # to size concerns.
elif name == "tf.contrib.ffmpeg": elif name == "tf.contrib.ffmpeg":
return ffmpeg return ffmpeg
elif name == "tf_debug":
return tf_debug
elif name.startswith("tf."): elif name.startswith("tf."):
subname = name[3:] subname = name[3:]
subnames = subname.split(".") subnames = subname.split(".")
@ -240,6 +244,7 @@ def all_libraries(module_to_name, members, documented):
library("contrib.util", "Utilities (contrib)", tf.contrib.util), library("contrib.util", "Utilities (contrib)", tf.contrib.util),
library("contrib.copy_graph", "Copying Graph Elements (contrib)", library("contrib.copy_graph", "Copying Graph Elements (contrib)",
tf.contrib.copy_graph), tf.contrib.copy_graph),
library("tf_debug", "TensorFlow Debugger", tf_debug),
]) ])
_hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange", _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",