[tfdbg] Initial implementation of enable_dumping()

- Only tensor_debug_mode=NO_TENSOR is supported. This tensor_debug_mode simply
  traces *which* tensors are executed within `tf.function`s (graphs), without
  regard to the values of those tensors.
- Other tensor_debug_modes, such as CURT_HEALTH and SHAPE, will be implemented
  in follow-up CLs for ease of reviewing.
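- A minimal usage sketch (the module path follows TensorFlow's internal
  layout, since the tf.debugging.* export is still a TODO in this CL;
  "/tmp/tfdbg_dump" is a placeholder dump root):

    from tensorflow.python.debug.lib import dumping_callback
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import math_ops

    # NO_TENSOR is the only tensor_debug_mode supported in this CL.
    writer = dumping_callback.enable_dumping(
        "/tmp/tfdbg_dump", tensor_debug_mode="NO_TENSOR")
    math_ops.log(constant_op.constant(5.0))  # This eager op gets traced.
    # Execution events sit in circular buffers until explicitly flushed.
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()
    dumping_callback.disable_dumping()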

PiperOrigin-RevId: 275249810
Change-Id: I96802a2b05ef7f6c327e8a0704992f5eae8c5c3d
Author: Shanqing Cai (2019-10-17 07:00:28 -07:00)
Committed by: TensorFlower Gardener
Parent: 5bb4c25e6a
Commit: 2542bbea51
4 changed files with 1219 additions and 0 deletions

tensorflow/python/debug/BUILD

@@ -26,10 +26,12 @@ py_library(
deps = [
":check_numerics_callback",
":debug_data",
":debug_events_reader",
":debug_events_writer",
":debug_gradients",
":debug_graphs",
":debug_utils",
":dumping_callback",
":grpc_debug_server",
":grpc_debug_test_server",
":hooks",
@@ -76,12 +78,40 @@ py_library(
],
)
py_library(
name = "dumping_callback",
srcs = ["lib/dumping_callback.py"],
srcs_version = "PY2AND3",
tags = [
"no_windows", # TODO(b/142475891): Enable this test on Windows.
],
deps = [
":debug_events_writer",
":op_callbacks_common",
":source_utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:debug_ops_gen",
"//tensorflow/python:op_callbacks",
"//third_party/py/numpy",
],
)
py_library(
name = "common",
srcs = ["lib/common.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "debug_events_reader",
srcs = ["lib/debug_events_reader.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework",
],
)
py_library(
name = "debug_events_writer",
srcs = ["lib/debug_events_writer.py"],
@@ -662,11 +692,31 @@ cuda_py_test(
],
)
cuda_py_test(
name = "dumping_callback_test",
size = "medium",
srcs = ["lib/dumping_callback_test.py"],
additional_deps = [
":debug_events_reader",
":debug_events_writer",
":dumping_callback",
"//third_party/py/numpy",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
"//tensorflow/python/keras",
],
shard_count = 4,
xla_enable_strict_auto_jit = False, # Node names are different with autojit
)
cuda_py_test(
name = "debug_v2_ops_test",
size = "medium",
srcs = ["lib/debug_v2_ops_test.py"],
additional_deps = [
":debug_events_reader",
":debug_events_writer",
"//third_party/py/numpy",
"//tensorflow/python:debug_ops_gen",

tensorflow/python/debug/lib/debug_events_reader.py

@@ -0,0 +1,84 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reader class for tfdbg v2 debug events."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
from tensorflow.core.protobuf import debug_event_pb2
from tensorflow.python.lib.io import tf_record
def _check_debug_event_file_exists(file_path):
if not os.path.isfile(file_path):
raise ValueError("DebugEvent data file does not exist: %s" % file_path)
class DebugEventsDir(object):
"""Reader class for a tfdbg v2 DebugEvents directory."""
def __init__(self, dump_root):
if not os.path.isdir(dump_root):
raise ValueError("Specified dump_root is not a directory: %s" % dump_root)
metadata_paths = glob.glob(os.path.join(dump_root, "*.metadata"))
if not metadata_paths:
raise ValueError("Cannot find any metadata file in directory: %s" %
dump_root)
elif len(metadata_paths) > 1:
raise ValueError(
"Unexpected: Found multiple (%d) metadata in directory: %s" %
(len(metadata_paths), dump_root))
self._metadata_path = metadata_paths[0]
prefix = metadata_paths[0][:-len(".metadata")]
self._source_files_path = "%s.source_files" % prefix
self._stack_frames_path = "%s.stack_frames" % prefix
self._graphs_path = "%s.graphs" % prefix
self._execution_path = "%s.execution" % prefix
self._graph_execution_traces_path = ("%s.graph_execution_traces" %
prefix)
def metadata_iterator(self):
for r in tf_record.tf_record_iterator(self._metadata_path):
yield debug_event_pb2.DebugEvent.FromString(r)
def source_files_iterator(self):
_check_debug_event_file_exists(self._source_files_path)
for r in tf_record.tf_record_iterator(self._source_files_path):
yield debug_event_pb2.DebugEvent.FromString(r)
def stack_frames_iterator(self):
_check_debug_event_file_exists(self._stack_frames_path)
for r in tf_record.tf_record_iterator(self._stack_frames_path):
yield debug_event_pb2.DebugEvent.FromString(r)
def graphs_iterator(self):
_check_debug_event_file_exists(self._graphs_path)
for r in tf_record.tf_record_iterator(self._graphs_path):
yield debug_event_pb2.DebugEvent.FromString(r)
def execution_iterator(self):
_check_debug_event_file_exists(self._execution_path)
for r in tf_record.tf_record_iterator(self._execution_path):
yield debug_event_pb2.DebugEvent.FromString(r)
def graph_execution_traces_iterator(self):
_check_debug_event_file_exists(self._graph_execution_traces_path)
for r in tf_record.tf_record_iterator(self._graph_execution_traces_path):
yield debug_event_pb2.DebugEvent.FromString(r)
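# A minimal reading sketch, assuming a dump root previously populated by the
# dumping callback ("/tmp/tfdbg_dump" is a placeholder path and
# _example_read_dump is a hypothetical helper name):
def _example_read_dump(dump_root="/tmp/tfdbg_dump"):
  events_dir = DebugEventsDir(dump_root)
  # The metadata file records the TensorFlow version, among other things.
  for event in events_dir.metadata_iterator():
    print(event.debug_metadata.tensorflow_version)
  # Each graph-execution trace carries the op name, the output slot and the
  # ID of the enclosing graph context.
  for event in events_dir.graph_execution_traces_iterator():
    trace = event.graph_execution_trace
    print(trace.op_name, trace.output_slot, trace.tfdbg_context_id)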

tensorflow/python/debug/lib/dumping_callback.py

@@ -0,0 +1,363 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dumping op callbacks: Enables dump-based features in tfdbg v2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import socket
import threading
import uuid
from tensorflow.core.protobuf import debug_event_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.debug.lib import debug_events_writer
from tensorflow.python.debug.lib import op_callbacks_common
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_debug_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import tf_stack
TracingConfig = collections.namedtuple(
"TracingConfig", "dump_root tensor_debug_mode circular_buffer_size")
_state = threading.local()
def _get_writer():
"""Get the debug events writer for the currently configured dump root."""
# TODO(cais): Explore caching the object for possible performance gain.
# TODO(cais): Rename cyclic_buffer_size to circular_buffer_size in C++ and
# Python-binding code.
return debug_events_writer.DebugEventsWriter(
_state.config.dump_root,
cyclic_buffer_size=_state.config.circular_buffer_size)
def _get_id():
"""Get a short unique ID."""
return str(uuid.uuid4())
def _get_context_id(context):
"""Get a unique ID for an op-construction context (e.g., a graph).
If the graph has been encountered before, reuse the same unique ID.
Args:
context: A context to get the unique ID for. Must be hashable. E.g., a Graph
object.
Returns:
A unique ID for the context.
"""
if context not in _state.context_to_id:
_state.context_to_id[context] = _get_id()
return _state.context_to_id[context]
def _write_source_file_content(file_path):
"""Send the content of a source file via debug-events writer.
Args:
file_path: Path to the source file.
Returns:
An int index for the file.
"""
if file_path not in _state.source_file_paths:
lines = None
if source_utils.is_extension_uncompiled_python_source(file_path):
try:
lines, _ = source_utils.load_source(file_path)
except IOError:
# Accept the fact that some source files are not readable. Here we use
# best effort to send the source-file contents.
pass
writer = _get_writer()
writer.WriteSourceFile(debug_event_pb2.SourceFile(
file_path=file_path, host_name=_state.hostname, lines=lines))
_state.source_file_paths.append(file_path)
return _state.source_file_paths.index(file_path)
def _process_stack_frames():
"""Process stack frames.
Send the content of source-files, on a best-effort basis.
Returns:
A CodeLocation proto containing the IDs of the stack frames.
"""
stack_frames = tf_stack.extract_stack()
stack_frame_ids = []
writer = None
for file_path, lineno, func, _ in stack_frames:
if (file_path, lineno, func) not in _state.stack_frame_to_id:
stack_frame_id = _get_id()
_state.stack_frame_to_id[(file_path, lineno, func)] = stack_frame_id
file_index = _write_source_file_content(file_path)
file_line_col = graph_debug_info_pb2.GraphDebugInfo.FileLineCol(
file_index=file_index, line=lineno, func=func)
stack_frame_with_id = debug_event_pb2.StackFrameWithId(
id=stack_frame_id, file_line_col=file_line_col)
writer = _get_writer()
writer.WriteStackFrameWithId(stack_frame_with_id)
stack_frame_ids.append(_state.stack_frame_to_id[(file_path, lineno, func)])
code_location = debug_event_pb2.CodeLocation(
host_name=_state.hostname, stack_frame_ids=stack_frame_ids)
return code_location
def _instrument_symbolic_tensors(tensors, op_name, tfdbg_context_id):
"""Add debugging instrumentation for symbolic (i.e., non-eager) tensors.
The detailed fashion in which the tensors are instrumented is determined
by the tensor_debug_mode configured for the currently enabled dumping
callback.
Args:
tensors: A tuple of Tensors to instrument. It is assumed that their ordering
corresponds to the ordering of output tensors of an original op. Output
slot indices (0-based) will be generated based on the ordering.
op_name: Name of the op that emits the Tensors.
tfdbg_context_id: A unique ID for the context that the op belongs to (e.g.,
a graph).
Returns:
Non-eager Tensors that override the `tensors` as the output of the op
that originally generated `tensors`. In some cases (e.g., non-V1 graph
mode), this may be `None`, as the instrumentation can simply rely on
automatic control dependencies (see `auto_control_deps.py`) instead of
tensor overriding.
"""
if (_state.config.tensor_debug_mode ==
debug_event_pb2.TensorDebugMode.NO_TENSOR):
is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
instrumented_tensors = [] if is_v1_graph_mode else None
for slot, tensor in enumerate(tensors):
with ops.colocate_with(None, ignore_existing=True):
# Except in V1 graph mode + control flow, debug_identity_v2 triggers an
# automatic control dependency because it is a stateful op.
debug_tensor = gen_debug_ops.debug_identity_v2(
# Use an empty (shape=[0]) float32 tensor for the NO_TENSOR mode.
constant_op.constant([], dtype=dtypes.float32),
tfdbg_context_id=tfdbg_context_id,
op_name=op_name,
output_slot=slot,
tensor_debug_mode=_state.config.tensor_debug_mode,
debug_urls=["file://%s" % _state.config.dump_root])
if is_v1_graph_mode:
# TODO(cais): Evaluate performance optimization options. For the
# `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
# control dependency of `tensor.op` without an additional identity op.
identity = array_ops.identity(tensor)
identity.op._add_control_input( # pylint: disable=protected-access
debug_tensor.op)
instrumented_tensors.append(identity)
return instrumented_tensors
else:
raise NotImplementedError(
"Symbolic tensor instrumentation is not implemented for debug mode %s" %
_state.config.tensor_debug_mode)
def _dump_eager_tensors(tensors, op_type, input_tensor_ids):
"""Dump the value of eager tensors.
The destination of the dumping is determined by the dump_root of the currently
enabled dumping callback. The tensors may be transformed prior to dumping
(e.g., reduced as summary statistics such as minimum, maximum and arithmetic
mean). The details of this transformation (if any) depend on the
tensor_debug_mode of the currently enabled dumping callback.
Args:
tensors: The EagerTensors whose values are to be dumped, with or without
value transform.
op_type: Type of the op that generates the tensors, as a string.
input_tensor_ids: IDs of the input EagerTensors to the op.
Returns:
A tfdbg Execution protocol buffer.
"""
if (_state.config.tensor_debug_mode ==
debug_event_pb2.TensorDebugMode.NO_TENSOR):
return debug_event_pb2.Execution(
op_type=op_type,
num_outputs=len(tensors),
input_tensor_ids=input_tensor_ids,
output_tensor_ids=[
t._id for t in tensors], # pylint:disable=protected-access
tensor_debug_mode=_state.config.tensor_debug_mode,
code_location=_process_stack_frames())
else:
raise NotImplementedError(
"Tensor instrumentation is not implemented for debug mode %s yet " %
_state.config.tensor_debug_mode)
def _dumping_callback(op_type,
inputs,
attrs,
outputs,
op_name=None,
graph=None):
"""Op callback for tracing a TF program's execution."""
del attrs # Unused
writer = _get_writer()
if graph:
context_id = _get_context_id(graph)
assert op_name is not None
graph_op_creation = debug_event_pb2.GraphOpCreation(
op_type=op_type,
op_name=op_name,
graph_name=graph.name if hasattr(graph, "name") else None,
graph_id=context_id,
input_names=[input_tensor.name for input_tensor in inputs],
num_outputs=len(outputs),
code_location=_process_stack_frames())
writer.WriteGraphOpCreation(graph_op_creation)
if outputs and compat.as_bytes(
op_type) not in op_callbacks_common.OP_CALLBACK_SKIP_OPS:
return _instrument_symbolic_tensors(outputs, op_name, context_id)
else:
input_ids = [t._id for t in inputs] # pylint:disable=protected-access
writer.WriteExecution(_dump_eager_tensors(outputs, op_type, input_ids))
DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"
def enable_dumping(dump_root,
tensor_debug_mode=DEFAULT_TENSOR_DEBUG_MODE,
circular_buffer_size=1000):
"""Enable dumping debugging information from a TensorFlow program.
The debugging information is dumped to a directory on the file system
specified as `dump_root`.
The dumped debugging information can be ingested by debugger UIs.
The files in the dump directory contain the following information:
- TensorFlow Function construction (e.g., compilation of Python functions
decorated with @tf.function), the op types, names (if available), context,
the input and output tensors, and the associated stack traces.
- Execution of TensorFlow operations (ops) and Functions and their stack
traces, op types, names (if available) and contexts. In addition,
depending on the value of the `tensor_debug_mode` argument (see Args
section below), the value(s) of the output tensors or more concise
summaries of the tensor values will be dumped.
- A snapshot of Python source files involved in the execution of the
TensorFlow program.
Once enabled, the dumping can be disabled with the corresponding
`disable_dumping()` method under the same Python namespace.
Calling this method more than once with the same `dump_root` is idempotent.
Calling this method with a different `dump_root` replaces the
previously enabled `dump_root`.
Args:
dump_root: The directory path where the dumping information will be written.
tensor_debug_mode: Debug mode for tensor values, as a string.
The currently supported options are:
- "NO_TENSOR": (Default) Only traces the execution of ops' output
tensors, while not dumping the value of the ops' output tensors
or any form of concise summary of them.
circular_buffer_size: Size of the circular buffers for execution events.
These circular buffers are designed to reduce the overhead of debugging
dumping. They hold the most recent debug events concerning eager execution
of ops and `tf.function`s and traces of tensor values computed inside
`tf.function`s. They are written to the file system only when the proper
flushing method is called (see description of return values below).
Expected to be an integer. If <= 0, the circular-buffer behavior will be
disabled, i.e., the execution debug events will be written to the file
writers in the same way as non-execution events such as op creations and
source-file snapshots.
Returns:
A DebugEventsWriter instance used by the dumping callback. The caller
may use its flushing methods, including `FlushNonExecutionFiles()` and
`FlushExecutionFiles()`.
"""
# TODO(cais): Revise the "UIs (currently under construction)" part of the doc
# string above.
# TODO(cais): Add Python code example to the doc string above.
# TODO(cais): Once UIs are ready, expose this method and the associated
# `disable_` method under the `tf.debugging.*` namespace.
if tensor_debug_mode not in debug_event_pb2.TensorDebugMode.keys():
raise ValueError(
"Invalid value in tensor_debug_mode ('%s'). Valid options are: %s" %
(tensor_debug_mode, debug_event_pb2.TensorDebugMode.keys()))
if (hasattr(_state, "config") and
_state.config.circular_buffer_size != circular_buffer_size):
logging.warning(
"There is already a dumping callback configured with a different "
"circular-buffer size (%d). Therefore the newly request "
"circular-buffer size will not be honored.",
_state.config.circular_buffer_size, circular_buffer_size)
if not hasattr(_state, "config") or _state.config.dump_root != dump_root:
_state.config = TracingConfig(
dump_root=dump_root,
tensor_debug_mode=debug_event_pb2.TensorDebugMode.Value(
tensor_debug_mode),
circular_buffer_size=int(circular_buffer_size))
if (_state.config.tensor_debug_mode !=
debug_event_pb2.TensorDebugMode.NO_TENSOR):
raise NotImplementedError(
"tfdbg dumping: support for tensor debug mode %s is not "
"implemented yet" % _state.config.tensor_debug_mode)
_state.hostname = socket.gethostname()
# A list of source-file paths.
_state.source_file_paths = []
# A map from stack frame (FileLineCol) to unique ID.
_state.stack_frame_to_id = dict()
# Mapping op context to unique ID.
_state.context_to_id = dict()
op_callbacks.add_op_callback(_dumping_callback)
logging.info(
"Enabled dumping callback in thread %s "
"(dump root: %s, tensor debug mode: %s)",
threading.current_thread().name, _state.config.dump_root,
tensor_debug_mode)
return _get_writer()
def disable_dumping():
"""Disable the currently-enabled debugging dumping.
If the `enable_dumping()` method under the same Python namespace has been
invoked before, calling this method disables it. If no call to
`enable_dumping()` has been made, calling this method is a no-op.
Calling this method more than once is idempotent.
"""
try:
op_callbacks.remove_op_callback(_dumping_callback)
logging.info("Disabled dumping callback in thread %s (dump root: %s)",
threading.current_thread().name, _state.config.dump_root)
except KeyError:
# Tolerate disabling the dumping callback without enable_dumping() being
# called first.
pass
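# For reference, a minimal sketch of the op-callback contract that
# _dumping_callback implements. The signature mirrors _dumping_callback
# above; _noop_callback is a hypothetical example name.
def _noop_callback(op_type, inputs, attrs, outputs, op_name=None, graph=None):
  # graph is non-None during graph (tf.function) construction; op_name is
  # None for eager op execution. Returning None keeps the op's original
  # output tensors rather than overriding them.
  return None

# Registration and removal work the same way as for _dumping_callback:
#   op_callbacks.add_op_callback(_noop_callback)
#   ...
#   op_callbacks.remove_op_callback(_noop_callback)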

tensorflow/python/debug/lib/dumping_callback_test.py

@@ -0,0 +1,722 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for tfdbg v2 dumping callback."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import socket
import tempfile
import threading
import numpy as np
from tensorflow.core.protobuf import debug_event_pb2
from tensorflow.python.debug.lib import debug_events_reader
from tensorflow.python.debug.lib import dumping_callback
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.keras import models
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import recurrent_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
def _create_simple_recurrent_keras_model():
"""Create a simple tf.keras model containing a recurrent layer for testing."""
model = models.Sequential()
model.add(recurrent_v2.LSTM(
10,
input_shape=[8, 4],
kernel_initializer="zeros",
recurrent_initializer="zeros"))
model.add(core.Dense(1, kernel_initializer="zeros"))
model.compile(loss="mse", optimizer="sgd")
return model
class TracingCallbackTest(test_util.TensorFlowTestCase):
def setUp(self):
super(TracingCallbackTest, self).setUp()
self.dump_root = tempfile.mkdtemp()
def tearDown(self):
if os.path.isdir(self.dump_root):
shutil.rmtree(self.dump_root, ignore_errors=True)
dumping_callback.disable_dumping()
super(TracingCallbackTest, self).tearDown()
def _readAndCheckMetadataFile(self):
"""Read and check the .metadata debug-events file."""
reader = debug_events_reader.DebugEventsDir(self.dump_root)
metadata_iter = reader.metadata_iterator()
metadata = next(metadata_iter).debug_metadata
self.assertEqual(metadata.tensorflow_version, versions.__version__)
self.assertTrue(metadata.file_version.startswith("debug.Event"))
def _readAndCheckSourceFilesAndStackFrames(self):
"""Read and verify the .source_files & .stack_frames debug-event files.
Returns:
A dict mapping stack frame IDs to stack frames (FileLineCol).
"""
reader = debug_events_reader.DebugEventsDir(self.dump_root)
# Check the content of the .source_files file.
source_files_iter = reader.source_files_iterator()
source_file_paths = []
prev_wall_time = 1
for debug_event in source_files_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
source_file = debug_event.source_file
self.assertEqual(source_file.host_name, socket.gethostname())
self.assertTrue(source_file.file_path)
if source_file.lines:
self.assertTrue(os.path.isfile(source_file.file_path))
source_file_paths.append(source_file.file_path)
# Assert the file paths are unique.
self.assertEqual(len(source_file_paths), len(set(source_file_paths)))
# Check the content of the .stack_frames file.
stack_frame_by_id = dict() # A map from ID to stack frame.
stack_frames_iter = reader.stack_frames_iterator()
prev_wall_time = 0
for debug_event in stack_frames_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
stack_frame_with_id = debug_event.stack_frame_with_id
stack_frame_id = stack_frame_with_id.id
file_line_col = stack_frame_with_id.file_line_col
self.assertTrue(stack_frame_id)
self.assertNotIn(stack_frame_id, stack_frame_by_id,
"Duplicate stack frame ID: %s" % id)
stack_frame_by_id[stack_frame_id] = (file_line_col.file_index,
file_line_col.line,
file_line_col.func)
self.assertGreaterEqual(file_line_col.file_index, 0)
self.assertLess(file_line_col.file_index, len(source_file_paths))
self.assertTrue(file_line_col.line) # Line numbers are 1-based.
self.assertTrue(file_line_col.func)
# Assert the stack frames are unique.
self.assertEqual(
len(stack_frame_by_id.values()), len(set(stack_frame_by_id.values())))
return stack_frame_by_id
def _readAndCheckGraphsFile(self, stack_frame_by_id):
"""Read and verify the content of the .graphs debug-event file.
Args:
stack_frame_by_id: A dict mapping unique string IDs to stack frames.
It is used by this method to look up stack frames.
Returns:
context_ids: IDs of op creation contexts (e.g., TensorFlow graphs), as a
`list` of `str`s.
op_types: Types of the ops that are created, as a `list` of `str`s with
the same length as `context_ids`.
op_name_to_op_type: A `dict` mapping op name to op type.
"""
reader = debug_events_reader.DebugEventsDir(self.dump_root)
graphs_iter = reader.graphs_iterator()
prev_wall_time = 0
op_types = []
op_name_to_op_type = dict()
context_ids = set()
for debug_event in graphs_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
graph_op_creation = debug_event.graph_op_creation
self.assertTrue(graph_op_creation.op_type)
op_types.append(graph_op_creation.op_type)
self.assertTrue(graph_op_creation.op_name)
op_name_to_op_type[graph_op_creation.op_name] = graph_op_creation.op_type
self.assertTrue(graph_op_creation.graph_id)
context_ids.add(graph_op_creation.graph_id)
self.assertTrue(graph_op_creation.code_location)
for stack_frame_id in graph_op_creation.code_location.stack_frame_ids:
self.assertIn(stack_frame_id, stack_frame_by_id)
return context_ids, op_types, op_name_to_op_type
def _readAndCheckExecutionFile(self):
"""Read and verify the content of the .execution debug-event file.
Returns:
executed_op_types: Types of the ops that are executed, as a `list` of `str`s.
input_tensor_ids: Input tensor IDs for each of the ops executed, as a
`list` of `list` of `int`s, with the same length as `executed_op_types`.
output_tensor_ids: Output tensor IDs for each of the ops executed, as a
`list` of `list` of `int`s, with the same length as `executed_op_types`.
tensor_debug_modes: Tensor debug modes used to instrument each of the ops
executed.
"""
reader = debug_events_reader.DebugEventsDir(self.dump_root)
execution_iter = reader.execution_iterator()
prev_wall_time = 1
executed_op_types = []
input_tensor_ids = []
output_tensor_ids = []
tensor_debug_modes = []
for debug_event in execution_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
execution = debug_event.execution
executed_op_types.append(execution.op_type)
input_tensor_ids.append(execution.input_tensor_ids)
output_tensor_ids.append(execution.output_tensor_ids)
tensor_debug_modes.append(execution.tensor_debug_mode)
# TODO(cais): When tensor debug modes other than NO_TENSOR are supported,
# return tensor_values as well.
return (executed_op_types, input_tensor_ids, output_tensor_ids,
tensor_debug_modes)
def _readAndCheckGraphExecutionTracesFile(self, context_ids):
"""Read & verify the content of the .graph_execution_trace debug-event file.
Args:
context_ids: Op-creation context IDs from _readAndCheckGraphsFile().
Returns:
op_names: Names of the ops that are executed, as a `list` of `str`s.
output_slots: Output slots, as a `list` of `int`s, with the same length as
`op_names`. In other words, for an executed op with N output tensors,
there will be N entries in this `list` and in `op_names`, at
corresponding indices.
tensor_values: Tensor values or their concise summaries, depending on
TensorDebugMode.
"""
reader = debug_events_reader.DebugEventsDir(self.dump_root)
graph_execution_traces_iter = reader.graph_execution_traces_iterator()
op_names = []
output_slots = []
tensor_values = []
for debug_event in graph_execution_traces_iter:
self.assertGreaterEqual(debug_event.wall_time, 0)
graph_execution_trace = debug_event.graph_execution_trace
op_names.append(graph_execution_trace.op_name)
# All the ops in the graph have only one output.
self.assertTrue(graph_execution_trace.tfdbg_context_id)
self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids)
output_slots.append(graph_execution_trace.output_slot)
# Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
# be an empty float32 tensor.
tensor_values.append(
tensor_util.MakeNdarray(graph_execution_trace.tensor_proto))
return op_names, output_slots, tensor_values
def testInvalidTensorDebugModeCausesError(self):
with self.assertRaisesRegexp(
ValueError,
r"Invalid value in tensor_debug_mode \(\'NONSENSICAL\'\).*"
r"Valid options.*NO_TENSOR.*"):
dumping_callback.enable_dumping(
self.dump_root, tensor_debug_mode="NONSENSICAL")
def testDisablingTracingCallbackWithoutEnablingFirstIsTolerated(self):
dumping_callback.disable_dumping()
def testPureEagerOpExecution(self):
"""Test catching Infinity in eager op execution: float32."""
writer = dumping_callback.enable_dumping(
self.dump_root, tensor_debug_mode="NO_TENSOR")
x = constant_op.constant(10.0)
zero = constant_op.constant(0.0)
one = constant_op.constant(1.0)
two = constant_op.constant(2.0)
three = constant_op.constant(3.0)
# Use Collatz conjecture as a test case.
while x > one:
if math_ops.equal(x % two, zero):
x = x / two
else:
x = x * three + one
writer.FlushNonExecutionFiles()
self._readAndCheckMetadataFile()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
# Before FlushExecutionFiles() is called, the .execution file should be
# empty.
reader = debug_events_reader.DebugEventsDir(self.dump_root)
execution_iter = reader.execution_iterator()
with self.assertRaises(StopIteration):
next(execution_iter)
# After the flushing, the .execution file should hold the appropriate
# contents.
writer.FlushExecutionFiles()
execution_iter = reader.execution_iterator()
prev_wall_time = 1
executed_op_types = []
for debug_event in execution_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
execution = debug_event.execution
executed_op_types.append(execution.op_type)
self.assertTrue(execution.input_tensor_ids)
self.assertTrue(execution.output_tensor_ids)
# Due to the default NO_TENSOR tensor debug mode, tensor_protos ought to
# be empty.
self.assertFalse(execution.tensor_protos)
# Verify the code_location field.
self.assertTrue(execution.code_location.stack_frame_ids)
for stack_frame_id in execution.code_location.stack_frame_ids:
self.assertIn(stack_frame_id, stack_frame_by_id)
self.assertEqual(
executed_op_types,
[
"Greater",
"FloorMod",
"Equal",
"RealDiv", # 10 --> 5
"Greater",
"FloorMod",
"Equal",
"Mul",
"AddV2", # 5 --> 16
"Greater",
"FloorMod",
"Equal",
"RealDiv", # 16 --> 8
"Greater",
"FloorMod",
"Equal",
"RealDiv", # 8 --> 4
"Greater",
"FloorMod",
"Equal",
"RealDiv", # 4 --> 2
"Greater",
"FloorMod",
"Equal",
"RealDiv", # 2 --> 1
"Greater"
])
# Due to the pure eager op execution, the .graphs file and the
# .graph_execution_traces file ought to be empty.
graphs_iterator = reader.graphs_iterator()
with self.assertRaises(StopIteration):
next(graphs_iterator)
graph_trace_iter = reader.graph_execution_traces_iterator()
with self.assertRaises(StopIteration):
next(graph_trace_iter)
@test_util.run_in_graph_and_eager_modes
def testNestedFunctionExecutionWithoutControlFlow(self):
writer = dumping_callback.enable_dumping(self.dump_root)
@def_function.function
def log_sum(x, y):
return math_ops.log(x + y)
@def_function.function
def sin1p_log_sum(x, y):
return math_ops.sin(1.0 + log_sum(x, y))
x = constant_op.constant(2.0)
y = constant_op.constant(3.0)
self.assertAllClose(sin1p_log_sum(x, y), np.sin(1.0 + np.log(5.0)))
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
if context.executing_eagerly():
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, so it doesn't get logged to the
# .execution file.
executed_op_types, _, _, _ = self._readAndCheckExecutionFile()
self.assertEqual(len(executed_op_types), 1)
self.assertIn("sin1p_log_sum", executed_op_types[0])
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, op_types,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
self.assertIn("AddV2", op_types)
self.assertIn("Log", op_types)
self.assertIn("Sin", op_types)
(op_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
# Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
# be an empty float32 tensor.
for tensor_value in tensor_values:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, (0,))
@test_util.run_in_graph_and_eager_modes
def testFunctionExecutionWithControlFlow(self):
writer = dumping_callback.enable_dumping(self.dump_root)
@def_function.function
def iterative_doubling(x, times):
i = constant_op.constant(0, dtype=dtypes.int32)
while i < times:
x = x * 2.0
i += 1
return x
x = constant_op.constant(0.5, dtype=dtypes.float32)
times = constant_op.constant(4, dtype=dtypes.int32)
self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 8.0)
writer.FlushNonExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
# Verify the content of the .graphs file.
context_ids, op_types, op_name_to_op_type = (
self._readAndCheckGraphsFile(stack_frame_by_id))
self.assertIn("Less", op_types)
self.assertIn("Mul", op_types)
self.assertIn("AddV2", op_types)
# Before FlushExecutionFiles() is called, the .execution and
# .graph_execution_traces files should both be empty.
reader = debug_events_reader.DebugEventsDir(self.dump_root)
execution_iter = reader.execution_iterator()
graph_execution_traces_iter = reader.graph_execution_traces_iterator()
with self.assertRaises(StopIteration):
next(execution_iter)
with self.assertRaises(StopIteration):
next(graph_execution_traces_iter)
# TODO(cais): Backport execution instrumentation to tf.Session.
writer.FlushExecutionFiles()
# After the flushing, the .execution file should hold the appropriate
# contents.
if context.executing_eagerly():
(executed_op_types, input_tensor_ids, output_tensor_ids,
tensor_debug_modes) = self._readAndCheckExecutionFile()
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
# .execution file.
self.assertEqual(len(executed_op_types), 1)
self.assertIn("iterative_doubling", executed_op_types[0])
self.assertEqual(len(input_tensor_ids[0]), 2)
self.assertEqual(len(output_tensor_ids[0]), 1)
self.assertEqual(tensor_debug_modes[0],
debug_event_pb2.TensorDebugMode.NO_TENSOR)
(op_names, output_slots,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
# The Less op should have been executed 5 times.
self.assertEqual(executed_op_types.count("Less"), 5)
# The last executed op should be Less.
self.assertEqual(executed_op_types[-1], "Less")
# The Mul op should have been executed 4 times.
self.assertEqual(executed_op_types.count("Mul"), 4)
# The AddV2 op should have been run, but we refrain from asserting on how
# many times it's executed.
self.assertIn("AddV2", executed_op_types)
for output_slot in output_slots:
self.assertEqual(output_slot, 0)
# Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
# be an empty float32 tensor.
for tensor_value in tensor_values:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, (0,))
def testCallingEnableTracingTwiceWithTheSameDumpRootIsIdempotent(self):
dumping_callback.enable_dumping(self.dump_root)
writer = dumping_callback.enable_dumping(self.dump_root)
x = constant_op.constant([10.0, 12.0, 10.0])
for _ in range(2):
array_ops.unique(x)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
reader = debug_events_reader.DebugEventsDir(self.dump_root)
execution_iter = reader.execution_iterator()
for _ in range(2):
debug_event = next(execution_iter)
self.assertGreater(debug_event.wall_time, 0)
execution = debug_event.execution
self.assertEqual(execution.op_type, "Unique")
self.assertEqual(execution.num_outputs, 2)
self.assertTrue(execution.code_location)
with self.assertRaises(StopIteration):
next(execution_iter)
def testCallingEnableTracingTwiceWithDifferentDumpRootsOverwrites(self):
dumping_callback.enable_dumping(self.dump_root)
new_dump_root = self.dump_root + "_new_dump_root"
writer = dumping_callback.enable_dumping(new_dump_root)
x = constant_op.constant([10.0, 12.0, 10.0])
for _ in range(2):
array_ops.unique(x)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
reader = debug_events_reader.DebugEventsDir(new_dump_root)
execution_iter = reader.execution_iterator()
for _ in range(2):
debug_event = next(execution_iter)
self.assertGreater(debug_event.wall_time, 0)
execution = debug_event.execution
self.assertEqual(execution.op_type, "Unique")
self.assertEqual(execution.num_outputs, 2)
self.assertTrue(execution.code_location)
with self.assertRaises(StopIteration):
next(execution_iter)
old_dump_root_reader = debug_events_reader.DebugEventsDir(self.dump_root)
execution_iter = old_dump_root_reader.execution_iterator()
# The old dump root shouldn't have been written to.
with self.assertRaises(StopIteration):
next(execution_iter)
def testDisableTracingWorks(self):
writer = dumping_callback.enable_dumping(self.dump_root)
dumping_callback.disable_dumping()
x = constant_op.constant([10.0, 12.0, 10.0])
for _ in range(2):
array_ops.unique(x)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
reader = debug_events_reader.DebugEventsDir(self.dump_root)
source_files_iter = reader.source_files_iterator()
stack_frames_iter = reader.stack_frames_iterator()
execution_iter = reader.execution_iterator()
# No source-file, stack-frame or execution data should have been dumped.
with self.assertRaises(StopIteration):
next(source_files_iter)
with self.assertRaises(StopIteration):
next(stack_frames_iter)
with self.assertRaises(StopIteration):
next(execution_iter)
def testMultiThreadedExecution(self):
writer = dumping_callback.enable_dumping(self.dump_root)
x = variables.Variable(10.0, dtype=dtypes.float32)
y = variables.Variable(3.0, dtype=dtypes.float32)
@def_function.function
def increase_x():
return x.assign_add(y * 2.0)
increase_x()
num_threads = 3
threads = []
for _ in range(num_threads):
threads.append(threading.Thread(target=increase_x))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# 10 --> 16 --> 22 --> 28 --> 34.
self.assertAllClose(x.read_value(), 34.0)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
reader = debug_events_reader.DebugEventsDir(self.dump_root)
execution_iter = reader.execution_iterator()
prev_wall_time = 1
for debug_event in execution_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
(context_ids, _,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
(op_names, output_slots,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
self.assertEqual(
executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads))
for output_slot in output_slots:
self.assertEqual(output_slot, 0)
for tensor_value in tensor_values:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, (0,))
@test_util.run_in_graph_and_eager_modes
def testSimpleKerasRecurrentModelPredict(self):
writer = dumping_callback.enable_dumping(self.dump_root)
model = _create_simple_recurrent_keras_model()
xs = np.ones([5, 8, 4])
self.assertAllClose(model.predict(xs), np.zeros([5, 1]))
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, op_types,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
# Simply assert that graphs are recorded and refrain from asserting on the
# internal details of the Keras model.
self.assertTrue(context_ids)
self.assertTrue(op_types)
self.assertTrue(op_name_to_op_type)
if context.executing_eagerly():
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
# .execution file.
executed_op_types, _, _, _ = self._readAndCheckExecutionFile()
self.assertTrue(executed_op_types)
(op_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
# These are the ops that we can safely assume to have been executed during
# the model prediction.
self.assertIn("MatMul", executed_op_types)
self.assertIn("BiasAdd", executed_op_types)
# On the GPU, CudnnRNN is used in lieu of the default op-by-op
# implementation.
self.assertTrue(
("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
"CudnnRNN" in executed_op_types))
# Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
# be an empty float32 tensor.
for tensor_value in tensor_values:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, (0,))
@test_util.run_in_graph_and_eager_modes
def testSimpleKerasRecurrentModelFit(self):
writer = dumping_callback.enable_dumping(self.dump_root)
model = _create_simple_recurrent_keras_model()
xs = np.ones([5, 8, 4])
ys = np.ones([5, 1])
history = model.fit(xs, ys, epochs=3, verbose=0)
self.assertAllClose(
history.history["loss"], [1.0, 0.9603999853134155, 0.9223681688308716])
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, op_types,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
# Simply assert that graphs are recorded and refrain from asserting on the
# internal details of the Keras model.
self.assertTrue(context_ids)
self.assertTrue(op_types)
self.assertTrue(op_name_to_op_type)
if context.executing_eagerly():
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
# .execution file.
executed_op_types, _, _, _ = self._readAndCheckExecutionFile()
self.assertTrue(executed_op_types)
(op_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
# These are the ops that we can safely assume to have been executed during
# the recurrent model's fit() call.
self.assertIn("MatMul", executed_op_types)
self.assertIn("BiasAdd", executed_op_types)
# On the GPU, CudnnRNN is used in lieu of the default op-by-op
# implementation.
self.assertTrue(
("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
"CudnnRNN" in executed_op_types))
self.assertTrue(
("SigmoidGrad" in executed_op_types and
"TanhGrad" in executed_op_types or
"CudnnRNNBackprop" in executed_op_types))
# Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
# be an empty float32 tensor.
for tensor_value in tensor_values:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, (0,))
@test_util.run_in_graph_and_eager_modes
@test_util.disable_xla("TODO(cais): Investigate timeout.")
def testMobileNetV2Fit(self):
"""Test training Keras MobileNetV2 application works w/ check numerics."""
# Use a large circular-buffer to make sure we capture all the executed ops.
writer = dumping_callback.enable_dumping(self.dump_root,
circular_buffer_size=100000)
model = mobilenet_v2.MobileNetV2(alpha=0.1, weights=None)
xs = np.zeros([2] + list(model.input_shape[1:]))
ys = np.zeros([2] + list(model.output_shape[1:]))
model.compile(optimizer="sgd", loss="categorical_crossentropy")
epochs = 1
history = model.fit(xs, ys, epochs=epochs, verbose=0)
self.assertEqual(len(history.history["loss"]), epochs)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids, op_types,
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
# Simply assert that graphs are recorded and refrain from asserting on the
# internal details of the Keras model.
self.assertTrue(context_ids)
self.assertTrue(op_types)
self.assertTrue(op_name_to_op_type)
if context.executing_eagerly():
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
# .execution file.
executed_op_types, _, _, _ = self._readAndCheckExecutionFile()
self.assertTrue(executed_op_types)
(op_names, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
# These are the ops that we can safely assume to have been executed during
# the model's fit() call.
self.assertIn("Conv2D", executed_op_types)
self.assertIn("Relu6", executed_op_types)
self.assertIn("Conv2DBackpropFilter", executed_op_types)
self.assertIn("Relu6Grad", executed_op_types)
# Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
# be an empty float32 tensor.
for tensor_value in tensor_values:
self.assertEqual(tensor_value.dtype, np.float32)
self.assertEqual(tensor_value.shape, (0,))
if __name__ == "__main__":
ops.enable_eager_execution()
googletest.main()