From e724d9e162b92c8e039dfbaf976762a725681584 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Mon, 23 Dec 2019 16:04:06 -0800
Subject: [PATCH] [tfdbg] Implement random file read & DebugDataReader;
 Simplify & improve tests.

- Support offset-based random read access in DebugEventsReader
- Support yielding offsets from the iterators of DebugEventsReader to enable subsequent random-access reading
- Check the tensor ID in the debug tensor values under the CURT_HEALTH, CONCISE_HEALTH and SHAPE modes: Tackling multiple TODO items.
- Use new DebugDataReader in tests to simplify code.

Per design for scalable reading of large tfdbg v2 datasets:
- Create light-weight digest classes: ExecutionDigest and GraphExecutionTraceDigest
  - Loaded by DebugDataReader.executions() and .graph_execution_traces() with
    kwarg digest=True.
- Corresponding detailed data classes: Execution and GraphExecutionTrace.
- Other data classes:
  - DebuggedGraph
  - GraphOpCreationDigest

PiperOrigin-RevId: 286955104
Change-Id: I750fc085fd75a7df11637413389b68dd0a6733c6
---
 tensorflow/core/protobuf/debug_event.proto    |    1 +
 tensorflow/python/debug/BUILD                 |    1 +
 .../python/debug/lib/debug_events_reader.py   |  883 +++++++++++-
 .../debug/lib/debug_events_writer_test.py     |   52 +-
 .../python/debug/lib/debug_v2_ops_test.py     |   10 +-
 .../python/debug/lib/dumping_callback_test.py | 1225 ++++++++---------
 .../debug/lib/dumping_callback_test_lib.py    |   12 +-
 7 files changed, 1522 insertions(+), 662 deletions(-)

diff --git a/tensorflow/core/protobuf/debug_event.proto b/tensorflow/core/protobuf/debug_event.proto
index 8f9680f38d9..ebbb93ee049 100644
--- a/tensorflow/core/protobuf/debug_event.proto
+++ b/tensorflow/core/protobuf/debug_event.proto
@@ -162,6 +162,7 @@ message GraphOpCreation {
   string graph_name = 3;
 
   // Unique ID of the graph (generated by debugger).
+  // This is the ID of the immediately-enclosing graph.
   string graph_id = 4;
 
   // Name of the device that the op is assigned to (if available).
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 4ee84e512bc..58bebfa6bd9 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -120,6 +120,7 @@ py_library(
     deps = [
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework",
+        "@six_archive//:six",
     ],
 )
 
diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py
index c6142c6e309..1594b5f27f8 100644
--- a/tensorflow/python/debug/lib/debug_events_reader.py
+++ b/tensorflow/python/debug/lib/debug_events_reader.py
@@ -18,17 +18,24 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import glob
 import os
 import threading
 
-from six.moves import map
+import six
 
 from tensorflow.core.protobuf import debug_event_pb2
-from tensorflow.python.lib.io import tf_record
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.util import compat
 
 
+DebugEventWithOffset = collections.namedtuple(
+    "DebugEventWithOffset", "debug_event offset")
+
+
 class DebugEventsReader(object):
   """Reader class for a tfdbg v2 DebugEvents directory."""
 
@@ -56,6 +63,8 @@ class DebugEventsReader(object):
     self._readers = dict()  # A map from file path to reader.
     self._readers_lock = threading.Lock()
 
+    self._offsets = dict()
+
   def __enter__(self):
     return self
 
@@ -64,15 +73,48 @@ class DebugEventsReader(object):
     self.close()
 
   def _generic_iterator(self, file_path):
-    """A helper method that makes an iterator given a debug-events file path."""
+    """A helper method that makes an iterator given a debug-events file path.
+
+    Repeated calls to this method create iterators that remember the last
+    successful reading position (offset) for each given `file_path`. So the
+    iterators are meant for incremental reading of the file.
+
+    Args:
+      file_path: Path to the file to create the iterator for.
+
+    Yields:
+      A tuple of (offset, debug_event_proto) on each `next()` call.
+    """
     # The following code uses the double-checked locking pattern to optimize
     # the common case (where the reader is already initialized).
     if file_path not in self._readers:  # 1st check, without lock.
       with self._readers_lock:
         if file_path not in self._readers:  # 2nd check, with lock.
-          self._readers[file_path] = tf_record.tf_record_iterator(file_path)
+          with errors.raise_exception_on_not_ok_status() as status:
+            # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
+            # supports offset.
+            self._readers[file_path] = pywrap_tensorflow.PyRecordReader_New(
+                compat.as_bytes(file_path), 0, b"", status)
+    reader = self._readers[file_path]
+    while True:
+      offset = reader.offset()
+      try:
+        reader.GetNext()
+      except (errors.DataLossError, errors.OutOfRangeError):
+        # We ignore partial read exceptions, because a record may be truncated.
+        # PyRecordReader holds the offset prior to the failed read, so retrying
+        # will succeed.
+        break
+      yield DebugEventWithOffset(
+          debug_event=debug_event_pb2.DebugEvent.FromString(reader.record()),
+          offset=offset)
 
-    return map(debug_event_pb2.DebugEvent.FromString, self._readers[file_path])
+  def _create_offset_reader(self, file_path, offset):
+    with errors.raise_exception_on_not_ok_status() as status:
+      # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
+      # supports ofset.
+      return pywrap_tensorflow.PyRecordReader_New(
+          file_path, offset, b"", status)
 
   def metadata_iterator(self):
     return self._generic_iterator(self._metadata_path)
@@ -86,12 +128,839 @@ class DebugEventsReader(object):
   def graphs_iterator(self):
     return self._generic_iterator(self._graphs_path)
 
+  def read_graphs_event(self, offset):
+    """Read a DebugEvent proto at a given offset from the .graphs file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEventProto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+    """
+    # TODO(cais): After switching to new Python wrapper of tfrecord reader,
+    # use seeking instead of repeated file opening. Same below.
+    reader = self._create_offset_reader(self._graphs_path, offset)
+    reader.GetNext()
+    debug_event = debug_event_pb2.DebugEvent.FromString(reader.record())
+    reader.Close()
+    return debug_event
+
   def execution_iterator(self):
     return self._generic_iterator(self._execution_path)
 
+  def read_execution_debug_event(self, offset):
+    """Read a DebugEvent proto at a given offset from the .execution file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEventProto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+    """
+    reader = self._create_offset_reader(self._execution_path, offset)
+    reader.GetNext()
+    debug_event = debug_event_pb2.DebugEvent.FromString(reader.record())
+    reader.Close()
+    return debug_event
+
   def graph_execution_traces_iterator(self):
     return self._generic_iterator(self._graph_execution_traces_path)
 
+  def read_graph_execution_traces_event(self, offset):
+    """Read DebugEvent at given offset from .graph_execution_traces file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEventProto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+    """
+    reader = self._create_offset_reader(
+        self._graph_execution_traces_path, offset)
+    reader.GetNext()
+    debug_event = debug_event_pb2.DebugEvent.FromString(reader.record())
+    reader.Close()
+    return debug_event
+
   def close(self):
-    with self._readers_lock:
-      self._readers.clear()
+    for reader in self._readers.values():
+      reader.Close()
+
+
+class BaseDigest(object):
+  """Base class for digest.
+
+  Properties:
+    wall_time: A timestamp for the digest (unit: s).
+    offset: A offset number in the corresponding file that can be used for
+      fast random read access.
+  """
+
+  def __init__(self, wall_time, offset):
+    self._wall_time = wall_time
+    self._offset = offset
+
+  @property
+  def wall_time(self):
+    return self._wall_time
+
+  @property
+  def offset(self):
+    return self._offset
+
+
+class ExecutionDigest(BaseDigest):
+  """Light-weight digest summarizing top-level execution event.
+
+  Use `DebugDataReader.read_execution(execution_digest)` to load the more
+  detailed data object concerning the execution event (`Execution`).
+
+  Properties:
+    op_type: Type name of the executed op. In the case of the eager execution of
+      an individual op, it is the name of the op (e.g., "MatMul").
+      In the case of the execution of a tf.function (FuncGraph), this is the
+      internally-generated name of the function (e.g.,
+      "__inference_my_func_123").
+  """
+
+  def __init__(self,
+               wall_time,
+               offset,
+               op_type):
+    super(ExecutionDigest, self).__init__(wall_time, offset)
+    self._op_type = op_type
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  # TODO(cais): Implement to_json().
+
+
+class Execution(ExecutionDigest):
+  """Detailed data relating to a top-level execution event.
+
+  The execution is of an individual op or a tf.function, which may have any
+  number of output tensors.
+
+  Properties (beyond the base class `ExecutionDigest`):
+    stack_frame_ids: Reference IDs for stack frames, ordered from bottommost to
+      topmost. Use `DebugDataReader.read_execution_stack_trace()` to load the
+      detailed stack frames (filepath, lineno and function name).
+    tensor_debug_mode: TensorDebugMode enum value, as an `int`.
+    graph_id: ID of the executed FuncGraph (applicable only the execution of a
+      tf.function). `None` for the eager execution of an individual op.
+    input_tensor_ids: IDs of the input (eager) tensor(s) for this execution, if
+      any.
+    output_tensor_ids: IDs of the output (eager) tensor(s) from this execution,
+      if any.
+    debug_tensor_values: Values of the debug tensor(s), applicable only to
+      non-FULL_TENSOR tensor debug mode. A tuple of list of numbers. Each
+      element of the tuple corresponds to an output tensor of the execution.
+      See documentation of the various TensorDebugModes for the semantics of the
+      numbers.
+  """
+
+  def __init__(self,
+               execution_digest,
+               stack_frame_ids,
+               tensor_debug_mode,
+               graph_id=None,
+               input_tensor_ids=None,
+               output_tensor_ids=None,
+               debug_tensor_values=None):
+    super(Execution, self).__init__(
+        execution_digest.wall_time,
+        execution_digest.offset,
+        execution_digest.op_type)
+    self._stack_frame_ids = stack_frame_ids
+    self._tensor_debug_mode = tensor_debug_mode
+    self._graph_id = graph_id
+    self._input_tensor_ids = input_tensor_ids
+    self._output_tensor_ids = output_tensor_ids
+    self._debug_tensor_values = debug_tensor_values
+
+  @property
+  def stack_frame_ids(self):
+    return self._stack_frame_ids
+
+  @property
+  def tensor_debug_mode(self):
+    return self._tensor_debug_mode
+
+  @property
+  def graph_id(self):
+    return self._graph_id
+
+  @property
+  def input_tensor_ids(self):
+    return self._input_tensor_ids
+
+  @property
+  def num_outputs(self):
+    return len(self._output_tensor_ids)
+
+  @property
+  def output_tensor_ids(self):
+    return self._output_tensor_ids
+
+  @property
+  def debug_tensor_values(self):
+    return self._debug_tensor_values
+
+  # TODO(cais): Implement to_json().
+
+
+class DebuggedGraph(object):
+  """Data object representing debugging information about a tf.Graph.
+
+  Includes `FuncGraph`s.
+
+  Properties:
+    name: Name of the graph (if any). May be `None` for non-function graphs.
+    graph_id: Debugger-generated ID for the graph.
+    inner_graph_ids: A list of the debugger-generated IDs for the graphs
+      enclosed by this graph.
+    outer_graph_id: If this graph is nested within an outer graph, ID of the
+      outer graph. If this is an outermost graph, `None`.
+  """
+
+  def __init__(self,
+               name,
+               graph_id,
+               outer_graph_id=None):
+    self._name = name
+    self._graph_id = graph_id
+    self._outer_graph_id = outer_graph_id
+    self._inner_graph_ids = []
+    # A dictionary from op name to GraphOpCreationDigest.
+    self._op_by_name = dict()
+
+  def add_inner_graph_id(self, inner_graph_id):
+    """Add the debugger-generated ID of a graph nested within this graph.
+
+    Args:
+      inner_graph_id: The debugger-generated ID of the nested inner graph.
+    """
+    assert isinstance(inner_graph_id, six.string_types)
+    self._inner_graph_ids.append(inner_graph_id)
+
+  def add_op(self, graph_op_creation_digest):
+    """Add an op creation data object.
+
+    Args:
+      graph_op_creation_digest: A GraphOpCreationDigest data object describing
+        the creation of an op inside this graph.
+    """
+    assert graph_op_creation_digest.op_name not in self._op_by_name
+    self._op_by_name[
+        graph_op_creation_digest.op_name] = graph_op_creation_digest
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def graph_id(self):
+    return self._graph_id
+
+  @property
+  def outer_graph_id(self):
+    return self._outer_graph_id
+
+  @property
+  def inner_graph_ids(self):
+    return self._inner_graph_ids
+
+  def get_op_type(self, op_name):
+    return self._op_by_name[op_name].op_type
+
+  def get_tensor_id(self, op_name, output_slot):
+    """Get the ID of a symbolic tensor in this graph."""
+    return self._op_by_name[op_name].output_tensor_ids[output_slot]
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphOpCreationDigest(BaseDigest):
+  """Data object describing the creation of an op inside a graph.
+
+  For size efficiency, this digest object does not contain any stack frames or
+  any references to them. To obtain the stack frames, use
+  `DataReader.read_graph_op_creation_stack_trace()`.
+
+  Properties (beyond the base class):
+    graph_id: Debugger-generated ID of the immediately-enclosing graph.
+    op_type: Type name of the op (e.g., "MatMul").
+    op_name: Name of the op (e.g., "dense_1/MatMul").
+    output_tensor_ids: Debugger-generated IDs for the output(s) of the op.
+    input_names: Names of the input tensors to the op.
+    device_name: The name of the device that the op is placed on (if available).
+  """
+
+  def __init__(self,
+               wall_time,
+               offset,
+               graph_id,
+               op_type,
+               op_name,
+               output_tensor_ids,
+               input_names=None,
+               device_name=None):
+    super(GraphOpCreationDigest, self).__init__(wall_time, offset)
+    self._graph_id = graph_id
+    self._op_type = op_type
+    self._op_name = op_name
+    self._output_tensor_ids = output_tensor_ids
+    self._input_names = input_names
+    self._device_name = device_name
+
+  @property
+  def graph_id(self):
+    return self._graph_id
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  @property
+  def op_name(self):
+    return self._op_name
+
+  @property
+  def output_tensor_ids(self):
+    return self._output_tensor_ids
+
+  @property
+  def num_outputs(self):
+    return len(self._output_tensor_ids)
+
+  @property
+  def input_names(self):
+    return self._input_names
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphExecutionTraceDigest(BaseDigest):
+  """Light-weight summary of a intra-graph tensor execution event.
+
+  Use `DebugDataReader.read_graph_execution_trace()` on this object to read more
+  detailed data (`GraphExecutionTrace`).
+
+  Properties (beyond the base class):
+    op_type: Type name of the executed op (e.g., "Conv2D").
+    op_name: Name of the op (e.g., "conv_2d_3/Conv2D").
+    output_slot: Output slot index of the tensor.
+  """
+
+  def __init__(self,
+               wall_time,
+               offset,
+               op_type,
+               op_name,
+               output_slot):
+    super(GraphExecutionTraceDigest, self).__init__(wall_time, offset)
+    self._op_type = op_type
+    self._op_name = op_name
+    self._output_slot = output_slot
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  @property
+  def op_name(self):
+    return self._op_name
+
+  @property
+  def output_slot(self):
+    return self._output_slot
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphExecutionTrace(GraphExecutionTraceDigest):
+  """Detailed data object describing an intra-graph tensor execution.
+
+  Attributes (in addition to GraphExecutionTraceDigest):
+    graph_ids: The debugger-generated IDs of the graphs that enclose the
+      executed op (tensor), ordered from the outermost to the innermost.
+    graph_id: The debugger-generated ID of the innermost (immediately-enclosing)
+      graph.
+    tensor_debug_mode: TensorDebugMode enum value.
+    debug_tensor_value: Debug tensor values (only for non-FULL_TENSOR
+      tensor_debug_mode). A list of numbers. See the documentation of the
+      TensorDebugModes for the semantics of the numbers.
+    device_name: Device on which the tensor resides (if available)
+  """
+
+  def __init__(self,
+               graph_execution_trace_digest,
+               graph_ids,
+               tensor_debug_mode,
+               debug_tensor_value=None,
+               device_name=None):
+    super(GraphExecutionTrace, self).__init__(
+        graph_execution_trace_digest.wall_time,
+        graph_execution_trace_digest.offset,
+        graph_execution_trace_digest.op_type,
+        graph_execution_trace_digest.op_name,
+        graph_execution_trace_digest.output_slot)
+    self._graph_ids = graph_ids
+    self._tensor_debug_mode = tensor_debug_mode
+    self._debug_tensor_value = debug_tensor_value
+    self._device_name = device_name
+
+  @property
+  def graph_ids(self):
+    return self._graph_ids
+
+  @property
+  def graph_id(self):
+    return self._graph_ids[-1]
+
+  @property
+  def tensor_debug_mode(self):
+    return self._tensor_debug_mode
+
+  @property
+  def debug_tensor_value(self):
+    return self._debug_tensor_value
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  # TODO(cais): Implement to_json().
+
+
+def _parse_tensor_value(tensor_proto, return_list=False):
+  """Helper method for reading a tensor value from a tensor proto.
+
+  The rationale for the distinction between `True` and `False value of
+  `return_list` is as follows:
+  - `return_list=True` is used for TensorDebugMode values other than
+    FULL_TENSOR, e.g., CONCISE_HEALTH, SHAPE and FULL_HEATLH. Under
+    those modes, the value is guaranteed (by contract) to be a 1D float64
+    tensor.
+  - `return_list=False` is used for the FULL_HEALTH TensorDebugMode
+    specifically. Instead, we use `numpy.ndarray` to maximally preserve
+    the shape, dtype and value information regarding the underlying tensor
+    value. Under that mode, we don't use a python list to represent the
+    tensor value because that can lead to loss of information (e.g., both
+    float16 and float32 dtypes get mapped to Python floats).
+
+  Args:
+    tensor_proto: The TensorProto instance from which the tensor value will be
+      loaded.
+    return_list: Whether the return value will be a nested Python list that
+      comes out from `numpy.ndarray.tolist()`.
+
+  Returns:
+    If parsing is successful, the tensor value as a `numpy.ndarray` or the
+      nested Python list converted from it.
+    If parsing fails, `None`.
+  """
+  try:
+    ndarray = tensor_util.MakeNdarray(tensor_proto)
+    return ndarray.tolist() if return_list else ndarray
+  except TypeError:
+    # Depending on tensor_debug_mode, certain dtype of tensors don't
+    # have logged debug tensor values.
+    return None
+
+
+class DebugDataReader(object):
+  """A reader that reads structured debugging data in the tfdbg v2 format.
+
+  The set of data read by an object of this class concerns the execution history
+  of a tfdbg2-instrumented TensorFlow program.
+
+  Note:
+    - An object of this class incrementally reads data from files that belong to
+      the tfdbg v2 DebugEvent file set. Calling `update()` triggers the reading
+      from the last-successful reading positions in the files.
+    - This object can be used as a context manager. Its `__exit__()` call
+      closes the file readers cleanly.
+  """
+
+  def __init__(self, dump_root):
+    self._reader = DebugEventsReader(dump_root)
+    # TODO(cais): Implement pagination for memory constraints.
+    self._execution_digests = []
+
+    # A list of (host_name, file_path) tuples.
+    self._host_name_file_paths = []
+    # A dict mapping id to (host_name, file_path, lineno, func) tuple.
+    self._stack_frame_by_id = dict()
+    # Stores unprocessed stack frame IDs. This is necessary to handle the
+    # case in which reading of the .stack_frames file gets ahead of the reading
+    # of the .source_files file.
+    self._unprocessed_stack_frames = dict()
+    # A dict mapping id to DebuggedGraph objects.
+    self._graph_by_id = dict()
+    self._graph_op_digests = []
+    # TODO(cais): Implement pagination for memory constraints.
+    self._graph_execution_trace_digests = []
+
+    # The following timestamps keep track where we've reached in each
+    # file of the DebugEvent source file, so that we don't run into race
+    # conditions with the writer.
+    self._source_files_timestamp = 0
+    # Temporary object used to hold DebugEvent protos with stack_frames
+    # field that has been read beyond max_wall_time.
+    # self._last_successful_stack_frames_offset = -1  # TODO(cais): Fix.
+
+  # TODO(cais): Read metadata.
+  def _load_source_files(self):
+    """Incrementally read the .source_files DebugEvent file."""
+    source_files_iter = self._reader.source_files_iterator()
+    for debug_event, _ in source_files_iter:
+      source_file = debug_event.source_file
+      self._host_name_file_paths.append(
+          (source_file.host_name, source_file.file_path))
+      self._source_file_timestamp = debug_event.wall_time
+
+  def _load_stack_frames(self):
+    """Incrementally read the .stack_frames file.
+
+    This must be called after _load_source_files().
+    It assumes that the following contract is honored by the writer of the tfdbg
+    v2 data file set:
+      - Before a stack frame is written to the .stack_frames file, the
+        corresponding source file information must have been written to the
+        .source_files file first.
+    """
+    stack_frames_iter = self._reader.stack_frames_iterator()
+    for debug_event, _ in stack_frames_iter:
+      stack_frame_with_id = debug_event.stack_frame_with_id
+      file_line_col = stack_frame_with_id.file_line_col
+      self._unprocessed_stack_frames[stack_frame_with_id.id] = file_line_col
+    # We do the processing in a separate stage, because the reading in the
+    # .source_files file may sometimes get ahead of the .source_files file.
+    unprocessed_stack_frame_ids = tuple(self._unprocessed_stack_frames.keys())
+    for stack_frame_id in unprocessed_stack_frame_ids:
+      file_line_col = self._unprocessed_stack_frames[stack_frame_id]
+      if len(self._host_name_file_paths) > file_line_col.file_index:
+        self._stack_frame_by_id[stack_frame_id] = (
+            self._host_name_file_paths[file_line_col.file_index][0],
+            self._host_name_file_paths[file_line_col.file_index][1],
+            file_line_col.line,
+            file_line_col.func)
+      del self._unprocessed_stack_frames[stack_frame_id]
+
+  def _load_graphs(self):
+    """Incrementally read the .graphs file.
+
+    Compiles the DebuggedGraph and GraphOpCreation data.
+    """
+    graphs_iter = self._reader.graphs_iterator()
+    for debug_event, offset in graphs_iter:
+      if debug_event.graph_op_creation.ByteSize():
+        op_creation_proto = debug_event.graph_op_creation
+        op_digest = GraphOpCreationDigest(
+            debug_event.wall_time,
+            offset,
+            op_creation_proto.graph_id,
+            op_creation_proto.op_type,
+            op_creation_proto.op_name,
+            tuple(op_creation_proto.output_tensor_ids),
+            input_names=tuple(op_creation_proto.input_names))
+        self._graph_op_digests.append(op_digest)
+        self._graph_by_id[op_creation_proto.graph_id].add_op(op_digest)
+      elif debug_event.debugged_graph.ByteSize():
+        graph_proto = debug_event.debugged_graph
+        graph = DebuggedGraph(
+            graph_proto.graph_name or None,
+            graph_proto.graph_id,
+            outer_graph_id=graph_proto.outer_context_id or None)
+        self._graph_by_id[graph_proto.graph_id] = graph
+        if graph_proto.outer_context_id:
+          self._graph_by_id[
+              graph_proto.outer_context_id].add_inner_graph_id(graph.graph_id)
+
+  def _load_graph_execution_traces(self):
+    """Incrementally load the .graph_execution_traces file."""
+    traces_iter = self._reader.graph_execution_traces_iterator()
+    for debug_event, offset in traces_iter:
+      trace_proto = debug_event.graph_execution_trace
+      op_name = trace_proto.op_name
+      op_type = self._lookup_op_type(trace_proto.tfdbg_context_id, op_name)
+      digest = GraphExecutionTraceDigest(
+          debug_event.wall_time,
+          offset,
+          op_type,
+          op_name,
+          trace_proto.output_slot)
+      self._graph_execution_trace_digests.append(digest)
+
+  def _lookup_op_type(self, graph_id, op_name):
+    """Lookup the type of an op by name and the immediately enclosing graph.
+
+    Args:
+      graph_id: Debugger-generated ID of the immediately-enclosing graph.
+      op_name: Name of the op.
+
+    Returns:
+      Op type as a str.
+    """
+    return self._graph_by_id[graph_id].get_op_type(op_name)
+
+  def _load_execution(self):
+    """Incrementally read the .execution file."""
+    execution_iter = self._reader.execution_iterator()
+    for debug_event, offset in execution_iter:
+      self._execution_digests.append(ExecutionDigest(
+          debug_event.wall_time,
+          offset,
+          debug_event.execution.op_type))
+
+  def update(self):
+    """Perform incremental read of the file set."""
+    self._load_source_files()
+    self._load_stack_frames()
+    self._load_graphs()
+    self._load_graph_execution_traces()
+    self._load_execution()
+
+  def outermost_graphs(self):
+    """Get the number of outer most graphs read so far."""
+    return [graph for graph in self._graph_by_id.values()
+            if not graph.outer_graph_id]
+
+  def graph_by_id(self, graph_id):
+    """Get a DebuggedGraph object by its ID."""
+    return self._graph_by_id[graph_id]
+
+  def graph_op_digests(self, op_type=None):
+    """Get the list of the digests for graph-op creation so far.
+
+    Args:
+      op_type: Optional op type to filter the creation events with.
+
+    Returns:
+      A list of `GraphOpCreationDigest` objects.
+    """
+    if op_type is not None:
+      return [digest for digest in self._graph_op_digests
+              if digest.op_type == op_type]
+    else:
+      return self._graph_op_digests
+
+  def graph_execution_traces(self, digest=False):
+    """Get all the intra-graph execution tensor traces read so far.
+
+    TODO(cais): Support begin and end to enable partial loading.
+
+    Args:
+      digest: Whether the results will be returned in the more light-weight
+        digest form.
+
+    Returns:
+      If `digest`: a `list` of `GraphExecutionTraceDigest` objects.
+      Else: a `list` of `GraphExecutionTrace` objects.
+    """
+    if digest:
+      return self._graph_execution_trace_digests
+    else:
+      return [self.read_graph_execution_trace(digest)
+              for digest in self._graph_execution_trace_digests]
+
+  def num_graph_execution_traces(self):
+    """Get the number of graph execution traces read so far."""
+    return len(self._graph_execution_trace_digests)
+
+  def executions(self, digest=False):
+    """Get `Execution`s or `ExecutionDigest`s this reader has read so far.
+
+    # TODO(cais): Support begin index and end index to support partial loading.
+
+    Args:
+      digest: Whether the results are returned in a digest form, i.e.,
+        `ExecutionDigest` format, instead of the more detailed `Execution`
+        format.
+
+    Returns:
+      If `digest`: a `list` of `ExecutionDigest` objects.
+      Else: a `list` of `Execution` objects.
+    """
+    if digest:
+      return self._execution_digests
+    else:
+      # TODO(cais): Optimizer performance removing repeated file open/close.
+      return [self.read_execution(digest) for digest in self._execution_digests]
+
+  def num_executions(self):
+    """Get the number of execution events read so far."""
+    return len(self._execution_digests)
+
+  def read_execution(self, execution_digest):
+    """Read a detailed Execution object."""
+    debug_event = self._reader.read_execution_debug_event(
+        execution_digest.offset)
+    execution_proto = debug_event.execution
+
+    debug_tensor_values = None
+    if (execution_proto.tensor_debug_mode ==
+        debug_event_pb2.TensorDebugMode.FULL_TENSOR):
+      pass  # TODO(cais): Build tensor store.
+    elif (execution_proto.tensor_debug_mode !=
+          debug_event_pb2.TensorDebugMode.NO_TENSOR):
+      debug_tensor_values = []
+      for tensor_proto in execution_proto.tensor_protos:
+        # TODO(cais): Refactor into a helper method.
+        debug_tensor_values.append(
+            _parse_tensor_value(tensor_proto, return_list=True))
+    return Execution(
+        execution_digest,
+        tuple(execution_proto.code_location.stack_frame_ids),
+        execution_proto.tensor_debug_mode,
+        graph_id=execution_proto.graph_id,
+        input_tensor_ids=tuple(execution_proto.input_tensor_ids),
+        output_tensor_ids=tuple(execution_proto.output_tensor_ids),
+        debug_tensor_values=tuple(
+            debug_tensor_values) if debug_tensor_values else None)
+
+  def read_graph_execution_trace(self, graph_execution_trace_digest):
+    """Read the detailed graph execution trace.
+
+    Args:
+      graph_execution_trace_digest: A `GraphExecutionTraceDigest` object.
+
+    Returns:
+      The corresponding `GraphExecutionTrace` object.
+    """
+    debug_event = self._reader.read_graph_execution_traces_event(
+        graph_execution_trace_digest.offset)
+    trace_proto = debug_event.graph_execution_trace
+
+    graph_ids = [trace_proto.tfdbg_context_id]
+    # Exhaust the outer contexts (graphs).
+    while True:
+      graph = self.graph_by_id(graph_ids[0])
+      if graph.outer_graph_id:
+        graph_ids.insert(0, graph.outer_graph_id)
+      else:
+        break
+
+    debug_tensor_value = None
+    if (trace_proto.tensor_debug_mode ==
+        debug_event_pb2.TensorDebugMode.FULL_TENSOR):
+      pass  # TODO(cais): Build tensor store.
+    else:
+      debug_tensor_value = _parse_tensor_value(
+          trace_proto.tensor_proto, return_list=True)
+    return GraphExecutionTrace(
+        graph_execution_trace_digest,
+        graph_ids=graph_ids,
+        tensor_debug_mode=trace_proto.tensor_debug_mode,
+        debug_tensor_value=debug_tensor_value,
+        device_name=trace_proto.device_name or None)
+
+  def read_execution_stack_trace(self, execution):
+    """Read the stack trace of a given Execution object.
+
+    Args:
+      execution: The Execution object of interest.
+
+    Returns:
+      A tuple consisting of:
+        1. The host name.
+        2. The stack trace, as a list of (file_path, lineno, func) tuples.
+    """
+    host_name = self._stack_frame_by_id[execution.stack_frame_ids[0]][0]
+    return (host_name, [
+        self._stack_frame_by_id[frame_id][1:]
+        for frame_id in execution.stack_frame_ids])
+
+  def read_graph_op_creation_stack_trace(self, graph_op_creation_digest):
+    """Read the stack trace of a given graph op creation object.
+
+    Args:
+      graph_op_creation_digest: The GraphOpCreationDigest object of interest.
+
+    Returns:
+      A tuple consisting of:
+        1. The host name.
+        2. The stack trace, as a list of (file_path, lineno, func) tuples.
+    """
+    debug_event = self._reader.read_graphs_event(
+        graph_op_creation_digest.offset)
+    graph_op_creation = debug_event.graph_op_creation
+    host_name = graph_op_creation.code_location.host_name
+    return host_name, [
+        self._stack_frame_by_id[frame_id][1:]
+        for frame_id in graph_op_creation.code_location.stack_frame_ids]
+
+  # TODO(cais): Add graph_execution_digests() with an ExecutionDigest
+  #   as a kwarg, to establish the association between top-level and intra-graph
+  #   execution events.
+
+  def execution_to_tensor_values(self, execution):
+    """Read the full tensor values from an Execution or ExecutionDigest.
+
+    Args:
+      execution: An `ExecutionDigest` or `ExeuctionDigest` object.
+
+    Returns:
+      A list of numpy arrays representing the output tensor values of the
+        execution event.
+    """
+    debug_event = self._reader.read_execution_debug_event(execution.offset)
+    return [_parse_tensor_value(tensor_proto)
+            for tensor_proto in debug_event.execution.tensor_protos]
+
+  def graph_execution_trace_to_tensor_value(self, trace):
+    """Read full tensor values from an Execution or ExecutionDigest.
+
+    Args:
+      trace: An `GraphExecutionTraceDigest` or `GraphExecutionTrace` object.
+
+    Returns:
+      A numpy array representing the output tensor value of the intra-graph
+        tensor execution event.
+    """
+    debug_event = self._reader.read_graph_execution_traces_event(trace.offset)
+    return _parse_tensor_value(debug_event.graph_execution_trace.tensor_proto)
+
+  def symbolic_tensor_id(self, graph_id, op_name, output_slot):
+    """Get the ID of a symbolic tensor.
+
+    Args:
+      graph_id: The ID of the immediately-enclosing graph.
+      op_name: Name of the op.
+      output_slot: Output slot as an int.
+
+    Returns:
+      The ID of the symbolic tensor as an int.
+    """
+    return self._graph_by_id[graph_id].get_tensor_id(op_name, output_slot)
+
+  def graph_execution_trace_to_tensor_id(self, trace):
+    """Get symbolic tensor ID from a GraphExecutoinTraceDigest object."""
+    return self.symbolic_tensor_id(
+        trace.graph_id, trace.op_name, trace.output_slot)
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exception_type, exception_value, traceback):
+    del exception_type, exception_value, traceback  # Unused
+    self._reader.close()
diff --git a/tensorflow/python/debug/lib/debug_events_writer_test.py b/tensorflow/python/debug/lib/debug_events_writer_test.py
index f6e973befed..b62fc9b3f9f 100644
--- a/tensorflow/python/debug/lib/debug_events_writer_test.py
+++ b/tensorflow/python/debug/lib/debug_events_writer_test.py
@@ -76,20 +76,20 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
     writer.FlushNonExecutionFiles()
 
     with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      actuals = list(reader.source_files_iterator())
+      actuals = list(item.debug_event.source_file
+                     for item in reader.source_files_iterator())
       self.assertLen(actuals, num_protos)
       for i in range(num_protos):
-        self.assertEqual(actuals[i].source_file.file_path,
-                         "/home/tf2user/main.py")
-        self.assertEqual(actuals[i].source_file.host_name, "machine.cluster")
-        self.assertEqual(actuals[i].source_file.lines, ["print(%d)" % i])
+        self.assertEqual(actuals[i].file_path, "/home/tf2user/main.py")
+        self.assertEqual(actuals[i].host_name, "machine.cluster")
+        self.assertEqual(actuals[i].lines, ["print(%d)" % i])
 
-      actuals = list(reader.stack_frames_iterator())
+      actuals = list(item.debug_event.stack_frame_with_id
+                     for item in reader.stack_frames_iterator())
       self.assertLen(actuals, num_protos)
       for i in range(num_protos):
-        self.assertEqual(actuals[i].stack_frame_with_id.id, "stack_%d" % i)
-        self.assertEqual(
-            actuals[i].stack_frame_with_id.file_line_col.file_index, i * 10)
+        self.assertEqual(actuals[i].id, "stack_%d" % i)
+        self.assertEqual(actuals[i].file_line_col.file_index, i * 10)
 
   def testWriteGraphOpCreationAndDebuggedGraphs(self):
     writer = debug_events_writer.DebugEventsWriter(self.dump_root)
@@ -106,7 +106,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
     writer.FlushNonExecutionFiles()
 
     reader = debug_events_reader.DebugEventsReader(self.dump_root)
-    actuals = list(reader.graphs_iterator())
+    actuals = list(item.debug_event for item in reader.graphs_iterator())
     self.assertLen(actuals, num_op_creations + 1)
     for i in range(num_op_creations):
       self.assertEqual(actuals[i].graph_op_creation.op_type, "Conv2D")
@@ -172,24 +172,24 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
     # Verify the content of the .source_files file.
     with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
       source_files_iter = reader.source_files_iterator()
-      actuals = list(source_files_iter)
-      file_paths = sorted([actual.source_file.file_path for actual in actuals])
+      actuals = list(item.debug_event.source_file for item in source_files_iter)
+      file_paths = sorted([actual.file_path for actual in actuals])
       self.assertEqual(file_paths, [
           "/home/tf2user/file_0.py", "/home/tf2user/file_1.py",
           "/home/tf2user/file_2.py"
       ])
 
     # Verify the content of the .stack_frames file.
-    actuals = list(reader.stack_frames_iterator())
-    stack_frame_ids = sorted(
-        [actual.stack_frame_with_id.id for actual in actuals])
+    actuals = list(item.debug_event.stack_frame_with_id
+                   for item in reader.stack_frames_iterator())
+    stack_frame_ids = sorted([actual.id for actual in actuals])
     self.assertEqual(stack_frame_ids,
                      ["stack_frame_0", "stack_frame_1", "stack_frame_2"])
 
     # Verify the content of the .graphs file.
-    actuals = list(reader.graphs_iterator())
-    graph_op_names = sorted(
-        [actual.graph_op_creation.op_name for actual in actuals])
+    actuals = list(item.debug_event.graph_op_creation
+                   for item in reader.graphs_iterator())
+    graph_op_names = sorted([actual.op_name for actual in actuals])
     self.assertEqual(graph_op_names, ["Op0", "Op1", "Op2"])
 
   def testWriteExecutionEventsWithCircularBuffer(self):
@@ -242,11 +242,12 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
       self.assertEqual(len(actuals), 0)
 
       writer.FlushExecutionFiles()
-      actuals = list(reader.graph_execution_traces_iterator())
+      actuals = list(item.debug_event.graph_execution_trace
+                     for item in reader.graph_execution_traces_iterator())
       self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)
       for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE):
         self.assertEqual(
-            actuals[i].graph_execution_trace.op_name,
+            actuals[i].op_name,
             "Op%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE))
 
   def testWriteGraphExecutionTraceEventsWithoutCircularBufferBehavior(self):
@@ -260,10 +261,11 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
     writer.FlushExecutionFiles()
 
     with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      actuals = list(reader.graph_execution_traces_iterator())
+      actuals = list(item.debug_event.graph_execution_trace
+                     for item in reader.graph_execution_traces_iterator())
     self.assertLen(actuals, num_execution_events)
     for i in range(num_execution_events):
-      self.assertEqual(actuals[i].graph_execution_trace.op_name, "Op%d" % i)
+      self.assertEqual(actuals[i].op_name, "Op%d" % i)
 
   def testConcurrentWritesToExecutionFiles(self):
     circular_buffer_size = 5
@@ -308,9 +310,9 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase):
 
     # Verify the content of the .execution file.
     with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      actuals = list(reader.graph_execution_traces_iterator())
-      op_names = sorted(
-          [actual.graph_execution_trace.op_name for actual in actuals])
+      actuals = list(item.debug_event.graph_execution_trace
+                     for item in reader.graph_execution_traces_iterator())
+      op_names = sorted([actual.op_name for actual in actuals])
       self.assertLen(op_names, circular_buffer_size)
       self.assertLen(op_names, len(set(op_names)))
 
diff --git a/tensorflow/python/debug/lib/debug_v2_ops_test.py b/tensorflow/python/debug/lib/debug_v2_ops_test.py
index c665da7132d..d6f0d4310a2 100644
--- a/tensorflow/python/debug/lib/debug_v2_ops_test.py
+++ b/tensorflow/python/debug/lib/debug_v2_ops_test.py
@@ -88,7 +88,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase):
       metadata_iter = reader.metadata_iterator()
       # Check that the .metadata DebugEvents data file has been created, even
       # before FlushExecutionFiles() is called.
-      debug_event = next(metadata_iter)
+      debug_event = next(metadata_iter).debug_event
       self.assertGreater(debug_event.wall_time, 0)
       self.assertTrue(debug_event.debug_metadata.tensorflow_version)
       self.assertTrue(
@@ -107,7 +107,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase):
       # The circular buffer has a size of 4. So only the data from the
       # last two iterations should have been written to self.dump_root.
       for _ in range(2):
-        debug_event = next(graph_trace_iter)
+        debug_event = next(graph_trace_iter).debug_event
         self.assertGreater(debug_event.wall_time, 0)
         trace = debug_event.graph_execution_trace
         self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
@@ -118,7 +118,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase):
         tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
         self.assertAllClose(tensor_value, [9.0, 16.0])
 
-        debug_event = next(graph_trace_iter)
+        debug_event = next(graph_trace_iter).debug_event
         self.assertGreater(debug_event.wall_time, 0)
         trace = debug_event.graph_execution_trace
         self.assertEqual(trace.tfdbg_context_id, "beafdead")
@@ -165,7 +165,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase):
         x_values = []
         timestamp = 0
         while True:
-          debug_event = next(graph_trace_iter)
+          debug_event = next(graph_trace_iter).debug_event
           self.assertGreater(debug_event.wall_time, timestamp)
           timestamp = debug_event.wall_time
           trace = debug_event.graph_execution_trace
@@ -210,7 +210,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase):
       with debug_events_reader.DebugEventsReader(debug_root) as reader:
         graph_trace_iter = reader.graph_execution_traces_iterator()
 
-        debug_event = next(graph_trace_iter)
+        debug_event = next(graph_trace_iter).debug_event
         trace = debug_event.graph_execution_trace
         self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
         self.assertEqual(trace.op_name, "")
diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py
index b7e90f3179c..061cb001639 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import collections
 import os
 import shutil
+import socket
 import tempfile
 import threading
 
@@ -36,7 +37,6 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import models
 from tensorflow.python.keras.applications import mobilenet_v2
@@ -61,6 +61,10 @@ def _create_simple_recurrent_keras_model(input_shape):
   return model
 
 
+_host_name = socket.gethostname()
+_current_file_full_path = os.path.abspath(__file__)
+
+
 class TracingCallbackTest(
     dumping_callback_test_lib.DumpingCallbackTestBase, parameterized.TestCase):
 
@@ -74,6 +78,19 @@ class TracingCallbackTest(
     dumping_callback.disable_dump_debug_info()
     super(TracingCallbackTest, self).tearDown()
 
+  def _verifyStackFrames(self, stack_frames):
+    """Verify the correctness of the stack frames.
+
+    Currently, it simply asserts that the current file is found in the stack
+    frames.
+    TODO(cais): Perhaps implement a stricter check later.
+
+    Args:
+      stack_frames: The stack frames to verify.
+    """
+    self.assertTrue([
+        frame for frame in stack_frames if frame[0] == _current_file_full_path])
+
   def testInvalidTensorDebugModeCausesError(self):
     with self.assertRaisesRegexp(
         ValueError,
@@ -111,73 +128,66 @@ class TracingCallbackTest(
 
     writer.FlushNonExecutionFiles()
     self._readAndCheckMetadataFile()
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
 
-    # Before FlushExecutionFiles() is called, the .execution file should be
-    # empty.
-    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      execution_iter = reader.execution_iterator()
-      with self.assertRaises(StopIteration):
-        next(execution_iter)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      # Before FlushExecutionFiles() is called, the .execution file should be
+      # empty.
+      self.assertFalse(reader.executions())
 
       # After the flushing, the .execution file should hold the appropriate
       # contents.
       writer.FlushExecutionFiles()
-      execution_iter = reader.execution_iterator()
+      reader.update()
+      executions = reader.executions()
       prev_wall_time = 1
       executed_op_types = []
       tensor_values = collections.defaultdict(lambda: [])
-      for debug_event in execution_iter:
-        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-        prev_wall_time = debug_event.wall_time
-        execution = debug_event.execution
+      for execution in executions:
+        self.assertGreaterEqual(execution.wall_time, prev_wall_time)
+        prev_wall_time = execution.wall_time
         executed_op_types.append(execution.op_type)
         # No graph IDs should have been logged for eager op executions.
         self.assertFalse(execution.graph_id)
         self.assertTrue(execution.input_tensor_ids)
         self.assertTrue(execution.output_tensor_ids)
+        self.assertEqual(
+            debug_event_pb2.TensorDebugMode.keys()[execution.tensor_debug_mode],
+            tensor_debug_mode)
         if tensor_debug_mode == "NO_TENSOR":
           # Due to the NO_TENSOR tensor debug mode, tensor_protos ought to
           # be empty.
-          self.assertFalse(execution.tensor_protos)
+          self.assertFalse(execution.debug_tensor_values)
         elif tensor_debug_mode == "CURT_HEALTH":
-          self.assertLen(execution.tensor_protos, 1)
+          self.assertLen(execution.debug_tensor_values, 1)
           if execution.op_type in ("AddV2", "Mul", "RealDiv"):
             # 1st element: -1 is the unset tensor_id for eager op execution.
             # 2nd element: 0 means there is no inf or nan.
-            self.assertAllClose(
-                tensor_util.MakeNdarray(execution.tensor_protos[0]),
-                [-1.0, 0.0])
+            self.assertAllClose(execution.debug_tensor_values, [[-1.0, 0.0]])
         elif tensor_debug_mode == "CONCISE_HEALTH":
-          self.assertLen(execution.tensor_protos, 1)
           if execution.op_type in ("AddV2", "Mul", "RealDiv"):
             # 1st element: -1 is the unset tensor_id for eager op execution.
             # 2nd element: each scalar tensor has 1 element.
             # Remaining elements: no -inf, inf or nan in these
             self.assertAllClose(
-                tensor_util.MakeNdarray(execution.tensor_protos[0]),
-                [-1, 1, 0, 0, 0])
+                execution.debug_tensor_values, [[-1, 1, 0, 0, 0]])
         elif tensor_debug_mode == "SHAPE":
-          self.assertLen(execution.tensor_protos, 1)
           if execution.op_type in ("AddV2", "Mul", "RealDiv"):
             # 1st element: -1 is the unset tensor_id for eager op execution.
             # 2nd element: dtype enum value (float32).
             # 3rd element: rank (scalar).
             # 4th element: element count (4).
             # Remaining elements: shape at fixed length (6).
-            self.assertAllClose(
-                tensor_util.MakeNdarray(execution.tensor_protos[0]),
-                [-1, 1, 0, 1, 0, 0, 0, 0, 0, 0])
+            self.assertAllClose(execution.debug_tensor_values,
+                                [[-1, 1, 0, 1, 0, 0, 0, 0, 0, 0]])
         elif tensor_debug_mode == "FULL_TENSOR":
-          # Under the FULL_TENSOR mode, the value of the tensor should be
-          # available through `tensor_protos`.
-          tensor_value = float(
-              tensor_util.MakeNdarray(execution.tensor_protos[0]))
-          tensor_values[execution.op_type].append(tensor_value)
-        # Verify the code_location field.
-        self.assertTrue(execution.code_location.stack_frame_ids)
-        for stack_frame_id in execution.code_location.stack_frame_ids:
-          self.assertIn(stack_frame_id, stack_frame_by_id)
+          tensor_values[execution.op_type].append(
+              reader.execution_to_tensor_values(execution)[0])
+
+        host_name, stack_frames = reader.read_execution_stack_trace(execution)
+        self.assertEqual(host_name, _host_name)
+        self._verifyStackFrames(stack_frames)
+
       if tensor_debug_mode == "FULL_TENSOR":
         self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0])
         self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1])
@@ -217,12 +227,8 @@ class TracingCallbackTest(
 
       # Due to the pure eager op execution, the .graph file and the
       # .graph_execution_traces file ought to be empty.
-      graphs_iterator = reader.graphs_iterator()
-      with self.assertRaises(StopIteration):
-        next(graphs_iterator)
-      graph_trace_iter = reader.graph_execution_traces_iterator()
-      with self.assertRaises(StopIteration):
-        next(graph_trace_iter)
+      self.assertFalse(reader.outermost_graphs())
+      self.assertEqual(reader.num_graph_execution_traces(), 0)
 
   @parameterized.named_parameters(
       ("CurtHealth", "CURT_HEALTH"),
@@ -242,60 +248,48 @@ class TracingCallbackTest(
     y = np.array([2, -1, 0, 0, 1, 1, 1, 3], dtype=np.float16)
     # (x + y) / (x - y) = [0.2, -inf, nan, nan, inf, inf, inf, -5].
     self.evaluate(func(x, y))
-
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids,
-     _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"])
-
-    if tensor_debug_mode == "CURT_HEALTH":
-      for op_type, tensor_value in zip(executed_op_types, tensor_values):
-        self.assertLen(tensor_value, 2)
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
-        # 2nd element: 0 means there is no inf or nan.
-        if op_type == "RealDiv":
-          self.assertEqual(tensor_value[1], 1)
-        else:
-          self.assertEqual(tensor_value[1], 0)
-    elif tensor_debug_mode == "CONCISE_HEALTH":
-      for op_type, tensor_value in zip(executed_op_types, tensor_values):
-        self.assertLen(tensor_value, 5)
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
-        # 2nd element: element count.
-        self.assertEqual(tensor_value[1], 8)
-        # Remaining 3 elements: The counts of -inf, inf and nan.
-        if op_type == "RealDiv":
-          self.assertAllClose(tensor_value[2:], [1, 3, 2])
-        else:
-          self.assertAllClose(tensor_value[2:], [0, 0, 0])
-    else:  # SHAPE.
-      for op_type, tensor_value in zip(executed_op_types, tensor_values):
-        self.assertLen(tensor_value, 10)
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
-        # 2nd element: dtype enum value (float16).
-        self.assertEqual(tensor_value[1], 19)
-        # 3rd element: rank (1)
-        self.assertEqual(tensor_value[2], 1)
-        # 4th element: element count.
-        self.assertEqual(tensor_value[3], 8)
-        # Remaining elements: shape at fixed length.
-        self.assertAllClose(tensor_value[4:], [8, 0, 0, 0, 0, 0])
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      graph_exec_traces = reader.graph_execution_traces()
+      executed_op_types = [trace.op_type for trace in graph_exec_traces]
+      self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"])
+      if tensor_debug_mode == "CURT_HEALTH":
+        for trace in graph_exec_traces:
+          # 1st element: tensor_id, should be >= 0.
+          # 2nd element: indicates if there is any inf or nan.
+          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+          self.assertGreaterEqual(tensor_id, 0)
+          if trace.op_type == "RealDiv":
+            self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1])
+          else:
+            self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
+      elif tensor_debug_mode == "CONCISE_HEALTH":
+        for trace in graph_exec_traces:
+          # 1st element: tensor_id, should be >= 0.
+          # 2nd element: element count (8).
+          # Remaining 3 elements: The counts of -inf, inf and nan.
+          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+          self.assertGreaterEqual(tensor_id, 0)
+          if trace.op_type == "RealDiv":
+            self.assertAllClose(trace.debug_tensor_value,
+                                [tensor_id, 8, 1, 3, 2])
+          else:
+            self.assertAllClose(trace.debug_tensor_value,
+                                [tensor_id, 8, 0, 0, 0])
+      else:  # SHAPE.
+        for trace in graph_exec_traces:
+          # 1st element: tensor_id, should be >= 0.
+          # 2nd element: dtype enum value (float16 = 19).
+          # 3rd element: rank (1)
+          # 4th element: element count (8).
+          # Remaining elements: shape at fixed length (6).
+          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+          self.assertGreaterEqual(tensor_id, 0)
+          self.assertAllClose(trace.debug_tensor_value,
+                              [tensor_id, 19, 1, 8, 8, 0, 0, 0, 0, 0])
 
   @parameterized.named_parameters(
       ("Shape", "SHAPE"),
@@ -317,28 +311,21 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids,
-     _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"])
-
-    for tensor_value in tensor_values:
-      # 1st element: tensor_id, should be >= 0.
-      # TODO(cais): Assert on detailed value once Function-graph association
-      # is in place.
-      self.assertGreaterEqual(tensor_value[0], 0)
-      # 2nd element: dtype enum value (bool).
-      self.assertEqual(tensor_value[1], 10)
-      # 3rd element: rank (2)
-      self.assertEqual(tensor_value[2], 2)
-      # 4th element: element count.
-      self.assertEqual(tensor_value[3], 4)
-      # Remaining elements: shape at fixed length.
-      self.assertAllClose(tensor_value[4:], [2, 2, 0, 0, 0, 0])
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      graph_exec_traces = reader.graph_execution_traces()
+      executed_op_types = [trace.op_type for trace in graph_exec_traces]
+      self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"])
+      for trace in graph_exec_traces:
+        tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+        self.assertGreaterEqual(tensor_id, 0)
+        # 1st element: tensor_id, should be >= 0.
+        # 2nd element: dtype enum value (bool).
+        # 3rd element: rank (2).
+        # 4th element: element count (4).
+        # Remaining elements: shape at fixed length.
+        self.assertAllClose(
+            trace.debug_tensor_value, [tensor_id, 10, 2, 4, 2, 2, 0, 0, 0, 0])
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
@@ -366,86 +353,151 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    if context.executing_eagerly():
-      # NOTE(b/142486213): Execution of the TF function happens with
-      # Session.run() in v1 graph mode, so doesn't get logged to the
-      # .execution file.
-      (executed_op_types, executed_graph_ids,
-       _, _, _, _) = self._readAndCheckExecutionFile()
-      executed_op_types = [op_type for op_type in executed_op_types
-                           if "sin1p_log_sum" in op_type]
-      self.assertLen(executed_op_types, 1)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      outermost_graphs = reader.outermost_graphs()
+      self.assertLen(outermost_graphs, 1)
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids, op_types, op_name_to_op_type,
-     op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
+      if context.executing_eagerly():
+        # NOTE(b/142486213): Execution of the TF function happens with
+        # Session.run() in v1 graph mode, so doesn't get logged to the
+        # .execution file.
+        executions = reader.executions()
+        self.assertLen(executions, 1)
+        self.assertIn("sin1p_log_sum", executions[0].op_type)
+        # Get the executed graph and verify its identity and inner graph.
+        graph = reader.graph_by_id(executions[0].graph_id)
+        self.assertEqual(graph.name, "sin1p_log_sum")
+        self.assertLen(graph.inner_graph_ids, 1)
+        inner_graph = reader.graph_by_id(graph.inner_graph_ids[0])
+        self.assertEqual(inner_graph.name, "log_sum")
 
-    self.assertIn("AddV2", op_types)
-    self.assertIn("Log", op_types)
-    self.assertIn("Sin", op_types)
-    if context.executing_eagerly():
-      # Check the correctness of the ID of the executed graph ID.
-      sin_op_name = [op_name for op_name in op_name_to_op_type
-                     if op_name_to_op_type[op_name] == "Sin"]
-      self.assertLen(sin_op_name, 1)
-      sin_context_id = op_name_to_context_id[sin_op_name[0]]
-      # The executed "op" is a FuncGraph, and its graph ID should have been
-      # recorded properly and be the ID of the graph that the Sin op belongs to.
-      executed_graph_ids = [
-          executed_graph_ids[i] for i, op_type
-          in enumerate(executed_op_types) if "sin1p_log_sum" in op_type]
-      self.assertEqual(executed_graph_ids[0], sin_context_id)
+      # Verify the recorded graph-building history.
+      add_op_digests = reader.graph_op_digests(op_type="AddV2")
+      self.assertLen(add_op_digests, 2)
+      self.assertEqual(
+          reader.graph_by_id(add_op_digests[0].graph_id).name, "log_sum")
+      self.assertEqual(
+          reader.graph_by_id(add_op_digests[1].graph_id).name, "sin1p_log_sum")
+      log_op_digests = reader.graph_op_digests(op_type="Log")
+      self.assertLen(log_op_digests, 1)
+      self.assertEqual(
+          reader.graph_by_id(log_op_digests[0].graph_id).name, "log_sum")
+      sin_op_digests = reader.graph_op_digests(op_type="Sin")
+      self.assertLen(sin_op_digests, 1)
+      self.assertEqual(
+          reader.graph_by_id(sin_op_digests[0].graph_id).name, "sin1p_log_sum")
 
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
+      # Verify the output tensor IDs and the stack traces.
+      for op_digest in add_op_digests + log_op_digests + sin_op_digests:
+        # These are all single-output ops.
+        self.assertLen(op_digest.output_tensor_ids, 1)
+        self.assertGreaterEqual(op_digest.output_tensor_ids[0], 0)
+        _, stack_frames = reader.read_graph_op_creation_stack_trace(op_digest)
+        self._verifyStackFrames(stack_frames)
 
-    if tensor_debug_mode == "NO_TENSOR":
-      # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
-      # be an empty float32 tensor.
-      for tensor_value in tensor_values:
-        self.assertEqual(tensor_value.dtype, np.float32)
-        self.assertEqual(tensor_value.shape, (0,))
-    elif tensor_debug_mode == "CURT_HEALTH":
-      for tensor_value in tensor_values:
-        self.assertLen(tensor_value, 2)
+      graph_exec_traces = reader.graph_execution_traces()
+      executed_op_types = [digest.op_type for digest in graph_exec_traces]
+      self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
+
+      # Verify the graph ID stack of each op.
+      # 1st AddV2 op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[0].graph_ids[-1]).name,
+          "log_sum")
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[0].graph_ids[-2]).name,
+          "sin1p_log_sum")
+      # Log op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[1].graph_ids[-1]).name,
+          "log_sum")
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[1].graph_ids[-2]).name,
+          "sin1p_log_sum")
+      # 2nd AddV2 op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[2].graph_ids[-1]).name,
+          "sin1p_log_sum")
+      # Sin op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[3].graph_ids[-1]).name,
+          "sin1p_log_sum")
+
+      if tensor_debug_mode == "NO_TENSOR":
+        # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
+        # to be an empty float32 tensor.
+        for trace in graph_exec_traces:
+          self.assertEqual(trace.debug_tensor_value, [])
+      elif tensor_debug_mode == "CURT_HEALTH":
+        # Test the association between graph exec and prior graph building.
+        # In each case, the 1st element of debug_tensor_value is the ID of the
+        # symbolic tenosr and the 2nd element is a zero indicating there is no
+        # inf or nan.
+        self.assertAllClose(
+            graph_exec_traces[0].debug_tensor_value,
+            [add_op_digests[0].output_tensor_ids[0], 0.0])  # 1st AddV2 op.
+        self.assertAllClose(
+            graph_exec_traces[1].debug_tensor_value,
+            [log_op_digests[0].output_tensor_ids[0], 0.0])  # Log op.
+        self.assertAllClose(
+            graph_exec_traces[2].debug_tensor_value,
+            [add_op_digests[1].output_tensor_ids[0], 0.0])  # 2nd AddV2 op.
+        self.assertAllClose(
+            graph_exec_traces[3].debug_tensor_value,
+            [sin_op_digests[0].output_tensor_ids[0], 0.0])  # Sin op.
+      elif tensor_debug_mode == "CONCISE_HEALTH":
         # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
-        # 2nd element: 0 means there is no inf or nan.
-        self.assertEqual(tensor_value[1], 0)
-    elif tensor_debug_mode == "CONCISE_HEALTH":
-      for tensor_value in tensor_values:
-        self.assertLen(tensor_value, 5)
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
         # 2nd element: element count. Remaining elements: all zero because there
         # is no -inf, inf or nan.
-        self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
-    elif tensor_debug_mode == "SHAPE":
-      for tensor_value in tensor_values:
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
+        # 1st AddV2 op.
+        self.assertAllClose(
+            graph_exec_traces[0].debug_tensor_value,
+            [add_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+        # Log op.
+        self.assertAllClose(
+            graph_exec_traces[1].debug_tensor_value,
+            [log_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+        # 2nd AddV2 op.
+        self.assertAllClose(
+            graph_exec_traces[2].debug_tensor_value,
+            [add_op_digests[1].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+        # Sin op.
+        self.assertAllClose(
+            graph_exec_traces[3].debug_tensor_value,
+            [sin_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+      elif tensor_debug_mode == "SHAPE":
+        # 1st element: tensor_id.
         # 2nd element: dtype (float32).
-        self.assertGreaterEqual(tensor_value[1], 1)
         # 3rd element: rank (scalar).
-        self.assertGreaterEqual(tensor_value[2], 0)
-        # 4th element: element count.
-        self.assertGreaterEqual(tensor_value[3], 1)
-        # Remaining elements: shape padded to fixed length.
-        self.assertAllClose(tensor_value[4:], [0, 0, 0, 0, 0, 0])
-    elif tensor_debug_mode == "FULL_TENSOR":
-      self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
-      self.assertAllClose(tensor_values[1], np.log(5.0))  # Log op.
-      self.assertAllClose(tensor_values[2], np.log(5.0) + 1.0)  # 2nd AddV2 op.
-      self.assertAllClose(tensor_values[3],
-                          np.sin(np.log(5.0) + 1.0))  # Sin op.
+        # 4th element: element count (1).
+        # Remaining elements: shape padded to fixed length (6).
+        # 1st AddV2 op.
+        self.assertAllClose(
+            graph_exec_traces[0].debug_tensor_value,
+            [add_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+        # Log op.
+        self.assertAllClose(
+            graph_exec_traces[1].debug_tensor_value,
+            [log_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+        # 2nd AddV2 op.
+        self.assertAllClose(
+            graph_exec_traces[2].debug_tensor_value,
+            [add_op_digests[1].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+        # Sin op.
+        self.assertAllClose(
+            graph_exec_traces[3].debug_tensor_value,
+            [sin_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+      else:  # FULL_TENSOR.
+        full_tensor_values = [
+            reader.graph_execution_trace_to_tensor_value(trace)
+            for trace in graph_exec_traces]
+        self.assertAllClose(full_tensor_values[0], 5.0)  # 1st AddV2 op.
+        self.assertAllClose(full_tensor_values[1], np.log(5.0))  # Log op.
+        self.assertAllClose(
+            full_tensor_values[2], np.log(5.0) + 1.0)  # 2nd AddV2 op.
+        self.assertAllClose(
+            full_tensor_values[3], np.sin(np.log(5.0) + 1.0))  # Sin op.
 
   def testCapturingExecutedGraphIdsOfTwoCompilationsOfSameFunction(self):
     """Test correct executed IDs of two FuncGraphs from the same Py function."""
@@ -467,15 +519,21 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    (executed_op_types, executed_graph_ids,
-     _, _, _, _) = self._readAndCheckExecutionFile()
-    self.assertLen(executed_op_types, 4)
-    for executed_op_type in executed_op_types:
-      self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
-    self.assertLen(executed_graph_ids, 4)
-    self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
-    self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
-    self.assertLen(set(executed_graph_ids), 2)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+
+      executions = reader.executions()
+      self.assertLen(executions, 4)
+      for execution in executions:
+        self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_")
+      executed_graph_ids = [execution.graph_id for execution in executions]
+      self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+      self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+      self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1])
+      self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3])
+      for executed_graph_id in executed_graph_ids:
+        self.assertEqual(
+            reader.graph_by_id(executed_graph_id).name, "ceil_times_two")
 
   def testCapturingExecutedGraphIdsOfDuplicateFunctionNames(self):
     """Two FuncGraphs compiled from Python functions with identical names."""
@@ -503,15 +561,20 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    (executed_op_types, executed_graph_ids,
-     _, _, _, _) = self._readAndCheckExecutionFile()
-    self.assertLen(executed_op_types, 4)
-    for executed_op_type in executed_op_types:
-      self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
-    self.assertLen(executed_graph_ids, 4)
-    self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
-    self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
-    self.assertLen(set(executed_graph_ids), 2)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      executions = reader.executions()
+      self.assertLen(executions, 4)
+      for execution in executions:
+        self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_")
+      executed_graph_ids = [execution.graph_id for execution in executions]
+      self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+      self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+      self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1])
+      self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3])
+      for executed_graph_id in executed_graph_ids:
+        self.assertEqual(
+            reader.graph_by_id(executed_graph_id).name, "ceil_times_two")
 
   @parameterized.named_parameters(
       ("AddV2", "AddV2"),
@@ -539,32 +602,35 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids, op_types,
-     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-    self.assertIn("AddV2", op_types)
-    self.assertIn("Log", op_types)
-    self.assertIn("Sin", op_types)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      graph_op_digests = reader.graph_op_digests()
+      op_types = [digest.op_type for digest in graph_op_digests]
+      self.assertIn("AddV2", op_types)
+      self.assertIn("Log", op_types)
+      self.assertIn("Sin", op_types)
 
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-
-    if op_regex == "AddV2":
-      self.assertEqual(executed_op_types, ["AddV2", "AddV2"])
-      self.assertLen(tensor_values, 2)
-      self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
-      self.assertAllClose(tensor_values[1], np.log(5.0) + 1.0)  # 2nd AddV2 op.
-    elif op_regex == "Log":
-      self.assertEqual(executed_op_types, ["Log"])
-      self.assertLen(tensor_values, 1)
-      self.assertAllClose(tensor_values[0], np.log(5.0))  # Log op.
-    else:  # "(AddV2|Log)"
-      self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2"])
-      self.assertLen(tensor_values, 3)
-      self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
-      self.assertAllClose(tensor_values[1], np.log(5.0))  # Log op.
-      self.assertAllClose(tensor_values[2], np.log(5.0) + 1.0)  # 2nd AddV2 op.
+      graph_exec_digests = reader.graph_execution_traces(digest=True)
+      executed_op_types = [digest.op_type for digest in graph_exec_digests]
+      tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+                       for digest in graph_exec_digests]
+      if op_regex == "AddV2":
+        self.assertEqual(executed_op_types, ["AddV2", "AddV2"])
+        self.assertLen(tensor_values, 2)
+        self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
+        self.assertAllClose(
+            tensor_values[1], np.log(5.0) + 1.0)  # 2nd AddV2 op.
+      elif op_regex == "Log":
+        self.assertEqual(executed_op_types, ["Log"])
+        self.assertLen(tensor_values, 1)
+        self.assertAllClose(tensor_values[0], np.log(5.0))  # Log op.
+      else:  # "(AddV2|Log)"
+        self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2"])
+        self.assertLen(tensor_values, 3)
+        self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
+        self.assertAllClose(tensor_values[1], np.log(5.0))  # Log op.
+        self.assertAllClose(
+            tensor_values[2], np.log(5.0) + 1.0)  # 2nd AddV2 op.
 
   def testIncorrectTensorDTypeArgFormatLeadsToError(self):
     with self.assertRaisesRegexp(
@@ -617,48 +683,54 @@ class TracingCallbackTest(
 
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids, _,
-     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
 
-    if tensor_dtypes == [dtypes.float32] and not op_regex:
-      self.assertEqual(executed_op_types, ["Unique", "Sum"])
-      self.assertLen(tensor_values, 2)
-      self.assertAllClose(tensor_values[0], [2., 6., 8., 1.])  # Unique values.
-      self.assertAllClose(tensor_values[1], 17.)  # Sum.
-    elif tensor_dtypes == ["float32"] and op_regex == "Sum":
-      self.assertEqual(executed_op_types, ["Sum"])
-      self.assertLen(tensor_values, 1)
-      self.assertAllClose(tensor_values[0], 17.)  # Sum.
-    elif tensor_dtypes == (dtypes.float32,) and op_regex == "(?!Sum)":
-      self.assertEqual(executed_op_types, ["Unique"])
-      self.assertLen(tensor_values, 1)
-      self.assertAllClose(tensor_values[0], [2., 6., 8., 1.])  # Unique values.
-    elif tensor_dtypes == [dtypes.int32] and not op_regex:
-      self.assertEqual(executed_op_types, ["Unique"])
-      self.assertLen(tensor_values, 1)
-      self.assertAllEqual(tensor_values[0], [0, 1, 2, 3, 0])  # Unique indices.
-    elif callable(tensor_dtypes) and not op_regex:
-      self.assertEqual(executed_op_types, ["Unique"])
-      self.assertLen(tensor_values, 1)
-      self.assertAllEqual(tensor_values[0], [0, 1, 2, 3, 0])  # Unique indices.
-    elif not tensor_dtypes and op_regex == "(?!Sum)":
-      self.assertEqual(executed_op_types, ["Unique", "Unique"])
-      self.assertLen(tensor_values, 2)
-      self.assertAllClose(tensor_values[0], [2., 6., 8., 1.])  # Unique values.
-      self.assertAllEqual(tensor_values[1], [0, 1, 2, 3, 0])  # Unique indices.
-    else:  # "All".
-      self.assertEqual(executed_op_types, ["Unique", "Unique", "Sum"])
-      self.assertLen(tensor_values, 3)
-      self.assertAllClose(tensor_values[0], [2., 6., 8., 1.])  # Unique values.
-      self.assertAllEqual(tensor_values[1], [0, 1, 2, 3, 0])  # Unique indices.
-      self.assertAllClose(tensor_values[2], 17.)  # Sum.
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      graph_exec_digests = reader.graph_execution_traces(digest=True)
+      executed_op_types = [digest.op_type for digest in graph_exec_digests]
+      tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+                       for digest in graph_exec_digests]
+
+      if tensor_dtypes == [dtypes.float32] and not op_regex:
+        self.assertEqual(executed_op_types, ["Unique", "Sum"])
+        self.assertLen(tensor_values, 2)
+        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
+        self.assertAllClose(tensor_values[1], 17.)  # Sum.
+      elif tensor_dtypes == ["float32"] and op_regex == "Sum":
+        self.assertEqual(executed_op_types, ["Sum"])
+        self.assertLen(tensor_values, 1)
+        self.assertAllClose(tensor_values[0], 17.)  # Sum.
+      elif tensor_dtypes == (dtypes.float32,) and op_regex == "(?!Sum)":
+        self.assertEqual(executed_op_types, ["Unique"])
+        self.assertLen(tensor_values, 1)
+        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
+      elif tensor_dtypes == [dtypes.int32] and not op_regex:
+        self.assertEqual(executed_op_types, ["Unique"])
+        self.assertLen(tensor_values, 1)
+        self.assertAllEqual(
+            tensor_values[0], [0, 1, 2, 3, 0])  # Unique indices.
+      elif callable(tensor_dtypes) and not op_regex:
+        self.assertEqual(executed_op_types, ["Unique"])
+        self.assertLen(tensor_values, 1)
+        self.assertAllEqual(
+            tensor_values[0], [0, 1, 2, 3, 0])  # Unique indices.
+      elif not tensor_dtypes and op_regex == "(?!Sum)":
+        self.assertEqual(executed_op_types, ["Unique", "Unique"])
+        self.assertLen(tensor_values, 2)
+        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
+        self.assertAllEqual(
+            tensor_values[1], [0, 1, 2, 3, 0])  # Unique indices.
+      else:  # "All".
+        self.assertEqual(executed_op_types, ["Unique", "Unique", "Sum"])
+        self.assertLen(tensor_values, 3)
+        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
+        self.assertAllEqual(
+            tensor_values[1], [0, 1, 2, 3, 0])  # Unique indices.
+        self.assertAllClose(tensor_values[2], 17)  # Sum.
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
+      ("CurtHealth", "CURT_HEALTH"),
       ("FullTensor", "FULL_TENSOR"),
   )
   @test_util.run_in_graph_and_eager_modes
@@ -679,86 +751,78 @@ class TracingCallbackTest(
     self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 8.0)
 
     writer.FlushNonExecutionFiles()
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      graph_op_digests = reader.graph_op_digests()
+      op_types = [digest.op_type for digest in graph_op_digests]
+      self.assertIn("Less", op_types)
+      self.assertIn("Mul", op_types)
+      self.assertIn("AddV2", op_types)
 
-    # Verify the content of the .graphs file.
-    context_ids, op_types, op_name_to_op_type, _ = (
-        self._readAndCheckGraphsFile(stack_frame_by_id))
-    self.assertIn("Less", op_types)
-    self.assertIn("Mul", op_types)
-    self.assertIn("AddV2", op_types)
-
-    # Before FlushExecutionFiles() is called, the .execution and
-    # .graph_execution_traces files should be both empty.
-    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      execution_iter = reader.execution_iterator()
-      graph_execution_traces_iter = reader.graph_execution_traces_iterator()
-      with self.assertRaises(StopIteration):
-        next(execution_iter)
-      with self.assertRaises(StopIteration):
-        next(graph_execution_traces_iter)
+      # Before FlushExecutionFiles() is called, the .execution and
+      # .graph_execution_traces files should be both empty.
+      self.assertEqual(reader.num_executions(), 0)
+      self.assertEqual(reader.num_graph_execution_traces(), 0)
 
       # TODO(cais): Backport execution instrumentation to tf.Session.
       writer.FlushExecutionFiles()
       # After the flushing, the .execution file should hold the appropriate
       # contents.
+      reader.update()
       if context.executing_eagerly():
-        (executed_op_types, _, input_tensor_ids, output_tensor_ids,
-         tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile()
         # NOTE(b/142486213): Execution of the TF function happens with
         # Session.run() in v1 graph mode, hence it doesn't get logged to the
-        # .execution file.
-        self.assertLen(executed_op_types, 1)
-        self.assertIn("iterative_doubling", executed_op_types[0])
-        self.assertLen(input_tensor_ids[0], 2)
-        self.assertLen(output_tensor_ids[0], 1)
+        executions = reader.executions()
+        self.assertLen(executions, 1)
+        executed_op_types = [execution.op_type for execution in executions]
+        self.assertIn("iterative_doubling", executions[0].op_type)
+        execution = executions[0]
+        self.assertLen(execution.input_tensor_ids, 2)
+        self.assertLen(execution.output_tensor_ids, 1)
         self.assertEqual(
-            tensor_debug_modes[0],
-            debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode))
+            debug_event_pb2.TensorDebugMode.keys()[execution.tensor_debug_mode],
+            tensor_debug_mode)
         if tensor_debug_mode == "FULL_TENSOR":
-          self.assertAllClose(tensor_values, [[8.0]])
+          tensor_values = reader.execution_to_tensor_values(execution)
+          self.assertAllClose(tensor_values, [8.0])
 
-      (op_names, _, output_slots,
-       tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-      executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-      # The Less op should have been executed 5 times.
-      self.assertEqual(executed_op_types.count("Less"), 5)
-      # The last executed op should be Less.
-      self.assertEqual(executed_op_types[-1], "Less")
+      graph_exec_traces = reader.graph_execution_traces()
+      executed_op_types = [trace.op_type for trace in graph_exec_traces]
+      if tensor_debug_mode != "CURT_HEALTH":
+        # Less outputs a boolean tensor, which is not tracked under CURT_HEALTH.
+        # The Less op should have been executed 5 times.
+        self.assertEqual(executed_op_types.count("Less"), 5)
+        # The last executed op should be Less.
+        self.assertEqual(executed_op_types[-1], "Less")
+        # AddV2 produces an int tensor, which is not tracked under CURT_HEALTH.
+        # The AddV2 op should have been run, but we refrain from asserting on
+        # how many times it's executed.
+        self.assertIn("AddV2", executed_op_types)
+        for trace in graph_exec_traces:
+          self.assertEqual(trace.output_slot, 0)
       # The Mul op should have been executed 4 times.
       self.assertEqual(executed_op_types.count("Mul"), 4)
-      # The AddV2 op should have been run, but we refrain from asserting on how
-      # many times it's executed.
-      self.assertIn("AddV2", executed_op_types)
-      for output_slot in output_slots:
-        self.assertEqual(output_slot, 0)
+
+      tensor_values = [reader.graph_execution_trace_to_tensor_value(trace)
+                       for trace in graph_exec_traces]
       if tensor_debug_mode == "NO_TENSOR":
         # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
         # to be an empty float32 tensor.
         for tensor_value in tensor_values:
-          self.assertEqual(tensor_value.dtype, np.float32)
-          self.assertEqual(tensor_value.shape, (0,))
-      elif tensor_debug_mode == "CURT_TENSOR":
-        for tensor_value in tensor_values:
-          self.assertLen(tensor_value, 2)
-          # 1st element: tensor_id, should be >= 0.
-          # TODO(cais): Assert on detailed value once Function-graph association
-          # is in place.
-          self.assertGreaterEqual(tensor_value[0], 0)
-          # 2nd element: 0 means there is no inf or nan.
-          self.assertEqual(tensor_value[1], 0)
+          self.assertAllEqual(tensor_value, [])
+      elif tensor_debug_mode == "CURT_HEALTH":
+        for trace in graph_exec_traces:
+          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+          # 1st element: tensor_id; 2nd element: 0 indicating no inf or nan.
+          self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0.0])
       elif tensor_debug_mode == "FULL_TENSOR":
         less_values = [
-            tensor_values[i]
-            for i, op_type in enumerate(executed_op_types)
-            if op_type == "Less"
-        ]
-        self.assertAllClose(less_values, [True, True, True, True, False])
+            reader.graph_execution_trace_to_tensor_value(trace)
+            for trace in graph_exec_traces if trace.op_type == "Less"]
+        self.assertAllEqual(less_values, [True, True, True, True, False])
         mul_values = [
-            tensor_values[i]
-            for i, op_type in enumerate(executed_op_types)
-            if op_type == "Mul"
-        ]
+            reader.graph_execution_trace_to_tensor_value(trace)
+            for trace in graph_exec_traces if trace.op_type == "Mul"]
         self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0])
 
   def testCallingEnableTracingTwiceWithTheSameDumpRootIsIdempotent(self):
@@ -772,17 +836,16 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      execution_iter = reader.execution_iterator()
-      for _ in range(2):
-        debug_event = next(execution_iter)
-        self.assertGreater(debug_event.wall_time, 0)
-        execution = debug_event.execution
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      executions = reader.executions()
+      self.assertLen(executions, 2)
+      for execution in executions:
+        self.assertGreater(execution.wall_time, 0)
         self.assertEqual(execution.op_type, "Unique")
         self.assertEqual(execution.num_outputs, 2)
-        self.assertTrue(execution.code_location)
-      with self.assertRaises(StopIteration):
-        next(execution_iter)
+        _, stack_frames = reader.read_execution_stack_trace(execution)
+        self._verifyStackFrames(stack_frames)
 
   def testCallingEnableTracingTwiceWithDifferentDumpRootsOverwrites(self):
     dumping_callback.enable_dump_debug_info(self.dump_root)
@@ -796,27 +859,26 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    with debug_events_reader.DebugEventsReader(new_dump_root) as reader:
-      execution_iter = reader.execution_iterator()
-      for _ in range(2):
-        debug_event = next(execution_iter)
-        self.assertGreater(debug_event.wall_time, 0)
-        execution = debug_event.execution
+    with debug_events_reader.DebugDataReader(new_dump_root) as reader:
+      reader.update()
+      executions = reader.executions()
+      self.assertLen(executions, 2)
+      for execution in executions:
+        self.assertGreater(execution.wall_time, 0)
         self.assertEqual(execution.op_type, "Unique")
         self.assertEqual(execution.num_outputs, 2)
-        self.assertTrue(execution.code_location)
-      with self.assertRaises(StopIteration):
-        next(execution_iter)
+        _, stack_frames = reader.read_execution_stack_trace(execution)
+        self._verifyStackFrames(stack_frames)
 
-      with debug_events_reader.DebugEventsReader(
-          self.dump_root) as old_dump_root_reader:
-        execution_iter = old_dump_root_reader.execution_iterator()
-        # The old dump root shouldn't have been written to.
-        with self.assertRaises(StopIteration):
-          next(execution_iter)
+    with debug_events_reader.DebugDataReader(
+        self.dump_root) as old_dump_root_reader:
+      old_dump_root_reader.update()
+      # The old dump root shouldn't have been written to.
+      self.assertEqual(old_dump_root_reader.num_executions(), 0)
+      self.assertFalse(old_dump_root_reader.outermost_graphs())
 
   def testCallingEnableRepeatedlyWithDifferentTensorDebugMode(self):
-    """Assert that calling enable_dump_debug_info() with different tensor-debug modes.
+    """Assert calling enable_dump_debug_info() with two tensor-debug modes.
 
     It should lead to overwriting of the previously-configured mode.
     """
@@ -830,16 +892,16 @@ class TracingCallbackTest(
     self.assertAllClose(add_1_divide_by_2(constant_op.constant(4.0)), 2.5)
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    context_ids, _, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id)
-    _, _, _, _, _, tensor_values = self._readAndCheckExecutionFile()
-    self.assertEqual(tensor_values, [[]])
-    (_, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    self.assertLen(tensor_values, 2)
-    for tensor_value in tensor_values:
-      self.assertEqual(tensor_value.dtype, np.float32)
-      self.assertEqual(tensor_value.shape, (0,))
+
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      graph_exec_digests = reader.graph_execution_traces(digest=True)
+      tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+                       for digest in graph_exec_digests]
+      for tensor_value in tensor_values:
+        # Under NO_TENSOR mode, each tensor is summarized as an empty float32
+        # array.
+        self.assertAllEqual(tensor_value, [])
 
     with self.assertRaisesRegexp(
         ValueError, r"already.*NO_TENSOR.*FULL_TENSOR.*not be honored"):
@@ -862,17 +924,11 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      source_files_iter = reader.source_files_iterator()
-      stack_frames_iter = reader.stack_frames_iterator()
-      execution_iter = reader.execution_iterator()
-      # No source-file, stack-frame or execution data should have been dumped.
-      with self.assertRaises(StopIteration):
-        next(source_files_iter)
-      with self.assertRaises(StopIteration):
-        next(stack_frames_iter)
-      with self.assertRaises(StopIteration):
-        next(execution_iter)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      self.assertEqual(reader.num_executions(), 0)
+      self.assertEqual(reader.num_graph_execution_traces(), 0)
+      self.assertFalse(reader.outermost_graphs())
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
@@ -908,73 +964,54 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
-      execution_iter = reader.execution_iterator()
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      exec_digests = reader.executions(digest=True)
       prev_wall_time = 1
-      for debug_event in execution_iter:
-        self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
-        prev_wall_time = debug_event.wall_time
+      for exec_digest in exec_digests:
+        self.assertGreaterEqual(exec_digest.wall_time, prev_wall_time)
+        prev_wall_time = exec_digest.wall_time
 
-    (context_ids, _,
-     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
+      graph_exec_traces = reader.graph_execution_traces()
+      executed_op_types = [trace.op_type for trace in graph_exec_traces]
+      self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
+      self.assertEqual(
+          executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads))
+      for trace in graph_exec_traces:
+        # These are all single-output tensors.
+        self.assertEqual(trace.output_slot, 0)
 
-    (op_names, _, output_slots,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
-    self.assertEqual(
-        executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads))
-    for output_slot in output_slots:
-      self.assertEqual(output_slot, 0)
+    tensor_values = [reader.graph_execution_trace_to_tensor_value(trace)
+                     for trace in graph_exec_traces]
     if tensor_debug_mode == "NO_TENSOR":
       for tensor_value in tensor_values:
-        self.assertEqual(tensor_value.dtype, np.float32)
-        self.assertEqual(tensor_value.shape, (0,))
+        self.assertAllEqual(tensor_value, [])
     elif tensor_debug_mode == "CURT_HEALTH":
-      for tensor_value in tensor_values:
-        self.assertLen(tensor_value, 2)
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
-        # 2nd element: 0 means there is no inf or nan.
-        self.assertEqual(tensor_value[1], 0)
+      for trace in graph_exec_traces:
+        tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+        # 1st element: tensor ID; 2nd element: 0 indicating no inf or nan.
+        self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
     elif tensor_debug_mode == "CONCISE_HEALTH":
       for tensor_value in tensor_values:
-        self.assertLen(tensor_value, 5)
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(tensor_value[0], 0)
+        tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+        # 1st element: tensor ID.
         # 2nd element: element count. Remaining elements: all zero because there
         # is no -inf, inf or nan.
-        self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
+        self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1, 0, 0, 0])
     elif tensor_debug_mode == "SHAPE":
-      mul_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Mul"
-      ]
-      for mul_value in mul_values:
-        # 1st element: tensor_id, should be >= 0.
-        # TODO(cais): Assert on detailed value once Function-graph association
-        # is in place.
-        self.assertGreaterEqual(mul_value[0], 0)
-        # 2nd element: dtype enum value (float32).
-        self.assertEqual(mul_value[1], 1)
-        # 3rd element: rank.
-        self.assertEqual(mul_value[2], 0)
-        # 3rd element: element count.
-        self.assertEqual(mul_value[3], 1)
-        # Remaining elements: shape padded to a fixed length.
-        self.assertAllClose(mul_value[4:], [0, 0, 0, 0, 0, 0])
+      for trace in graph_exec_traces:
+        if trace.op_type == "Mul":
+          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+          mul_value = reader.graph_execution_trace_to_tensor_value(trace)
+          # 1st element: tensor_id, should be >= 0.
+          # 2nd element: dtype enum value (float32).
+          # 3rd element: rank.
+          # 4th element: element count.
+          self.assertAllClose(mul_value, [tensor_id, 1, 0, 1, 0, 0, 0, 0, 0, 0])
     elif tensor_debug_mode == "FULL_TENSOR":
       mul_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Mul"
-      ]
+          reader.graph_execution_trace_to_tensor_value(trace)
+          for trace in graph_exec_traces if trace.op_type == "Mul"]
       self.assertAllClose(mul_values, [6.0, 6.0, 6.0, 6.0])
 
   def testMultiThreadedDumpingWithDifferentSettings(self):
@@ -1017,23 +1054,28 @@ class TracingCallbackTest(
     self.assertAllClose(v1.read_value(), -67084290.0)
     self.assertAllClose(v2.read_value(), -6.0)
 
-    (executed_op_types, _, _, _, _,
-     tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_1)
-    v1_squared_values = [
-        tensor_values[i] for i, op_type in enumerate(executed_op_types)
-        if op_type == "Pow"]
-    negative_v1_squared_values = [
-        tensor_values[i] for i, op_type in enumerate(executed_op_types)
-        if op_type == "Neg"]
-    self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
-    self.assertAllClose(
-        negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
+    with debug_events_reader.DebugDataReader(dump_root_1) as reader:
+      reader.update()
+      exec_digests = reader.executions(digest=True)
+      v1_squared_values = [
+          reader.execution_to_tensor_values(digest)
+          for digest in exec_digests if digest.op_type == "Pow"]
+      negative_v1_squared_values = [
+          reader.execution_to_tensor_values(digest)
+          for digest in exec_digests if digest.op_type == "Neg"]
+      self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
+      self.assertAllClose(
+          negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
 
-    (executed_op_types, _, _, _, _,
-     tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_2)
-    self.assertNotIn("Neg", executed_op_types)
-    v2_squared_values = tensor_values[executed_op_types.index("Pow")]
-    self.assertAllClose(v2_squared_values, [9.0])
+    with debug_events_reader.DebugDataReader(dump_root_2) as reader:
+      reader.update()
+      exec_digests = reader.executions(digest=True)
+      executed_op_types = [digest.op_type for digest in exec_digests]
+      self.assertNotIn("Neg", executed_op_types)
+      v2_squared_values = [
+          reader.execution_to_tensor_values(digest)
+          for digest in exec_digests if digest.op_type == "Pow"]
+      self.assertAllClose(v2_squared_values, [[9.0]])
 
   @test_util.run_in_graph_and_eager_modes
   def testNestedContextIsCapturedByGraphOpCreationHistory(self):
@@ -1055,36 +1097,18 @@ class TracingCallbackTest(
 
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
-
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (_, _, op_name_to_op_type,
-     op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
-
-    less_op_names = [op_name for op_name in op_name_to_op_type
-                     if op_name_to_op_type[op_name] == "Less"]
-    less_context_ids = [op_name_to_context_id[op_name]
-                        for op_name in less_op_names]
-    mul_op_names = [op_name for op_name in op_name_to_op_type
-                    if op_name_to_op_type[op_name] == "Mul"]
-    mul_context_ids = [op_name_to_context_id[op_name]
-                       for op_name in mul_op_names]
-    sub_op_names = [op_name for op_name in op_name_to_op_type
-                    if op_name_to_op_type[op_name] == "Sub"]
-    sub_context_ids = [op_name_to_context_id[op_name]
-                       for op_name in sub_op_names]
-    self.assertLen(less_context_ids, 1)
-    self.assertLen(mul_context_ids, 1)
-    self.assertLen(sub_context_ids, 1)
-    self.assertTrue(less_context_ids[0])
-    self.assertTrue(mul_context_ids[0])
-    self.assertTrue(sub_context_ids[0])
-    # The Less op is from the while-loop cond context and hence should have
-    # a different innermost context ID from the mul and sub ops, which are both
-    # from the while-loop body context.
-    self.assertNotEqual(less_context_ids[0], mul_context_ids[0])
-    self.assertNotEqual(less_context_ids[0], sub_context_ids[0])
-    # The Mul and Sub ops are from the same innermost context.
-    self.assertEqual(mul_context_ids[0], sub_context_ids[0])
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      less_op_digest = reader.graph_op_digests(op_type="Less")[-1]
+      mul_op_digest = reader.graph_op_digests(op_type="Mul")[-1]
+      sub_op_digest = reader.graph_op_digests(op_type="Sub")[-1]
+      # The Less op is from the while-loop cond context and hence should have
+      # a different innermost context ID from the mul and sub ops, which are
+      # both from the while-loop body context.
+      self.assertNotEqual(less_op_digest.graph_id, mul_op_digest.graph_id)
+      self.assertNotEqual(less_op_digest.graph_id, sub_op_digest.graph_id)
+      # The Mul and Sub ops are from the same innermost context.
+      self.assertEqual(mul_op_digest.graph_id, sub_op_digest.graph_id)
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
@@ -1102,53 +1126,38 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids, op_types,
-     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-    # Simply assert that graph are recorded and refrain from asserting on the
-    # internal details of the Keras model.
-    self.assertTrue(context_ids)
-    self.assertTrue(op_types)
-    self.assertTrue(op_name_to_op_type)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      if context.executing_eagerly():
+        # NOTE(b/142486213): Execution of the TF function happens with
+        # Session.run() in v1 graph mode, hence it doesn't get logged to the
+        # .execution file.
+        self.assertTrue(reader.executions(digest=True))
 
-    if context.executing_eagerly():
-      # NOTE(b/142486213): Execution of the TF function happens with
-      # Session.run() in v1 graph mode, hence it doesn't get logged to the
-      # .execution file.
-      (executed_op_types, _, _, _, _,
-       tensor_values) = self._readAndCheckExecutionFile()
-      self.assertTrue(executed_op_types)
+      graph_exec_digests = reader.graph_execution_traces(digest=True)
+      executed_op_types = [digest.op_type for digest in graph_exec_digests]
+      # These are the ops that we can safely assume to have been executed during
+      # the model prediction.
+      self.assertIn("MatMul", executed_op_types)
+      self.assertIn("BiasAdd", executed_op_types)
+      # On the GPU, CudnnRNN is used in lieu of the default op-by-op
+      # implementation.
+      self.assertTrue(
+          ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
+           "CudnnRNN" in executed_op_types))
 
-      for value_list in tensor_values:
-        if tensor_debug_mode == "NO_TENSOR":
-          self.assertFalse(value_list)
-
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    # These are the ops that we can safely assume to have been executed during
-    # the model prediction.
-    self.assertIn("MatMul", executed_op_types)
-    self.assertIn("BiasAdd", executed_op_types)
-    # On the GPU, CudnnRNN is used in lieu of the default op-by-op
-    # implementation.
-    self.assertTrue(
-        ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
-         "CudnnRNN" in executed_op_types))
-    # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
-    # be an empty float32 tensor.
-    if tensor_debug_mode == "NO_TENSOR":
-      for tensor_value in tensor_values:
-        self.assertEqual(tensor_value.dtype, np.float32)
-        self.assertEqual(tensor_value.shape, (0,))
-    else:
-      # Refrain from asserting the internal implementation details of the LSTM
-      # layer.
-      concrete_tensor_values = [
-          value for value in tensor_values
-          if value is not None and value.size > 0
-      ]
-      self.assertTrue(concrete_tensor_values)
+      # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
+      # be an empty float32 tensor.
+      tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+                       for digest in graph_exec_digests]
+      if tensor_debug_mode == "NO_TENSOR":
+        for tensor_value in tensor_values:
+          self.assertAllEqual(tensor_value, [])
+      else:
+        # Refrain from asserting the internal implementation details of the LSTM
+        # layer.
+        self.assertTrue(any(
+            bool(tensor_value.size) for tensor_value in tensor_values))
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
@@ -1169,48 +1178,38 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids, op_types,
-     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-    # Simply assert that graph are recorded and refrain from asserting on the
-    # internal details of the Keras model.
-    self.assertTrue(context_ids)
-    self.assertTrue(op_types)
-    self.assertTrue(op_name_to_op_type)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      if context.executing_eagerly():
+        exec_digests = reader.executions(digest=True)
+        self.assertTrue(exec_digests)
+        if tensor_debug_mode == "NO_TENSOR":
+          for digest in exec_digests:
+            tensor_values = reader.execution_to_tensor_values(digest)
+            for tensor_value in tensor_values:
+              self.assertEqual(tensor_value, [])
 
-    if context.executing_eagerly():
-      # NOTE(b/142486213): Execution of the TF function happens with
-      # Session.run() in v1 graph mode, hence it doesn't get logged to the
-      # .execution file.
-      (executed_op_types, _, _, _, _,
-       tensor_values) = self._readAndCheckExecutionFile()
-      self.assertTrue(executed_op_types)
+      graph_exec_digests = reader.graph_execution_traces(digest=True)
+      executed_op_types = [digest.op_type for digest in graph_exec_digests]
+      # These are the ops that we can safely assume to have been executed during
+      # the recurrent model's fit() call.
+      self.assertIn("MatMul", executed_op_types)
+      self.assertIn("BiasAdd", executed_op_types)
+
+      # On the GPU, CudnnRNN is used in lieu of the default op-by-op
+      # implementation.
+      self.assertTrue(
+          ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
+           "CudnnRNN" in executed_op_types))
+      self.assertTrue(
+          ("SigmoidGrad" in executed_op_types and
+           "TanhGrad" in executed_op_types or
+           "CudnnRNNBackprop" in executed_op_types))
       if tensor_debug_mode == "NO_TENSOR":
-        for value_list in tensor_values:
-          self.assertFalse(value_list)
-
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    # These are the ops that we can safely assume to have been executed during
-    # the recurrent model's fit() call.
-    self.assertIn("MatMul", executed_op_types)
-    self.assertIn("BiasAdd", executed_op_types)
-    # On the GPU, CudnnRNN is used in lieu of the default op-by-op
-    # implementation.
-    self.assertTrue(
-        ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
-         "CudnnRNN" in executed_op_types))
-    self.assertTrue(
-        ("SigmoidGrad" in executed_op_types and
-         "TanhGrad" in executed_op_types or
-         "CudnnRNNBackprop" in executed_op_types))
-    if tensor_debug_mode == "NO_TENSOR":
-      # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
-      # to be an empty float32 tensor.
-      for tensor_value in tensor_values:
-        self.assertEqual(tensor_value.dtype, np.float32)
-        self.assertEqual(tensor_value.shape, (0,))
+        for digest in graph_exec_digests:
+          tensor_values = reader.graph_execution_trace_to_tensor_value(digest)
+          for tensor_value in tensor_values:
+            self.assertEqual(tensor_value, [])
 
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
@@ -1242,72 +1241,60 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
 
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (context_ids, op_types,
-     op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-    # Simply assert that graph are recorded and refrain from asserting on the
-    # internal details of the Keras model.
-    self.assertTrue(context_ids)
-    self.assertTrue(op_types)
-    self.assertTrue(op_name_to_op_type)
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      if context.executing_eagerly():
+        # NOTE(b/142486213): Execution of the TF function happens with
+        # Session.run() in v1 graph mode, hence it doesn't get logged to the
+        # .execution file.
+        exec_digests = reader.executions(digest=True)
+        self.assertTrue(exec_digests)
 
-    if context.executing_eagerly():
-      # NOTE(b/142486213): Execution of the TF function happens with
-      # Session.run() in v1 graph mode, hence it doesn't get logged to the
-      # .execution file.
-      executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
-      self.assertTrue(executed_op_types)
+      graph_exec_digests = reader.graph_execution_traces()
+      executed_op_types = [digest.op_type for digest in graph_exec_digests]
+      # These are the ops that we can safely assume to have been executed during
+      # the model's fit() call.
+      self.assertIn("Conv2D", executed_op_types)
+      self.assertIn("Relu6", executed_op_types)
+      self.assertIn("Conv2DBackpropFilter", executed_op_types)
+      self.assertIn("Relu6Grad", executed_op_types)
 
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    # These are the ops that we can safely assume to have been executed during
-    # the model's fit() call.
-    self.assertIn("Conv2D", executed_op_types)
-    self.assertIn("Relu6", executed_op_types)
-    self.assertIn("Conv2DBackpropFilter", executed_op_types)
-    self.assertIn("Relu6Grad", executed_op_types)
-    if tensor_debug_mode == "NO_TENSOR":
-      # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
-      # be an empty float32 tensor.
-      for tensor_value in tensor_values:
-        self.assertEqual(tensor_value.dtype, np.float32)
-        self.assertEqual(tensor_value.shape, (0,))
-    elif tensor_debug_mode == "FULL_TENSOR":
-      conv2d_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Conv2D"
-      ]
-      self.assertTrue(conv2d_values)
-      for conv2d_value in conv2d_values:
-        self.assertGreater(len(conv2d_value.shape), 1)
-        self.assertEqual(conv2d_value.shape[0], batch_size)
-      relu6_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Relu6"
-      ]
-      self.assertTrue(relu6_values)
-      for relu6_value in relu6_values:
-        self.assertGreater(len(relu6_value.shape), 1)
-        self.assertEqual(relu6_value.shape[0], batch_size)
-      conv2d_bp_filter_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Conv2DBackpropFilter"
-      ]
-      self.assertTrue(conv2d_bp_filter_values)
-      for conv2d_bp_filter_value in conv2d_bp_filter_values:
-        self.assertGreater(len(conv2d_bp_filter_value.shape), 1)
-      relu6_grad_values = [
-          tensor_values[i]
-          for i, op_type in enumerate(executed_op_types)
-          if op_type == "Relu6Grad"
-      ]
-      self.assertTrue(relu6_grad_values)
-      for relu6_grad_value in relu6_grad_values:
-        self.assertGreater(len(relu6_grad_value.shape), 1)
+      if tensor_debug_mode == "NO_TENSOR":
+        # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
+        # to be an empty float32 tensor.
+        tensor_values = [
+            reader.graph_execution_trace_to_tensor_value(digest)
+            for digest in graph_exec_digests]
+        for tensor_value in tensor_values:
+          self.assertAllEqual(tensor_value, [])
+      elif tensor_debug_mode == "FULL_TENSOR":
+        conv2d_values = [
+            reader.graph_execution_trace_to_tensor_value(digest)
+            for digest in graph_exec_digests if digest.op_type == "Conv2D"]
+        self.assertTrue(conv2d_values)
+        for conv2d_value in conv2d_values:
+          self.assertGreater(len(conv2d_value.shape), 1)
+          self.assertEqual(conv2d_value.shape[0], batch_size)
+        relu6_values = [
+            reader.graph_execution_trace_to_tensor_value(digest)
+            for digest in graph_exec_digests if digest.op_type == "Relu6"]
+        self.assertTrue(relu6_values)
+        for relu6_value in relu6_values:
+          self.assertGreater(len(relu6_value.shape), 1)
+          self.assertEqual(relu6_value.shape[0], batch_size)
+        conv2d_bp_filter_values = [
+            reader.graph_execution_trace_to_tensor_value(digest)
+            for digest in graph_exec_digests
+            if digest.op_type == "Conv2DBackpropFilter"]
+        self.assertTrue(conv2d_bp_filter_values)
+        for conv2d_bp_filter_value in conv2d_bp_filter_values:
+          self.assertGreater(len(conv2d_bp_filter_value.shape), 1)
+        relu6_grad_values = [
+            reader.graph_execution_trace_to_tensor_value(digest)
+            for digest in graph_exec_digests if digest.op_type == "Relu6Grad"]
+        self.assertTrue(relu6_grad_values)
+        for relu6_grad_value in relu6_grad_values:
+          self.assertGreater(len(relu6_grad_value.shape), 1)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/debug/lib/dumping_callback_test_lib.py b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
index 6144f2ba9cc..1d449f61e0b 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test_lib.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
@@ -52,7 +52,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
     """Read and check the .metadata debug-events file."""
     with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
       metadata_iter = reader.metadata_iterator()
-      metadata = next(metadata_iter).debug_metadata
+      metadata = next(metadata_iter).debug_event.debug_metadata
       self.assertEqual(metadata.tensorflow_version, versions.__version__)
       self.assertTrue(metadata.file_version.startswith("debug.Event"))
 
@@ -67,7 +67,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       source_files_iter = reader.source_files_iterator()
       source_file_paths = []
       prev_wall_time = 1
-      for debug_event in source_files_iter:
+      for debug_event, _ in source_files_iter:
         self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
         prev_wall_time = debug_event.wall_time
         source_file = debug_event.source_file
@@ -84,7 +84,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       stack_frame_by_id = collections.OrderedDict()
       stack_frames_iter = reader.stack_frames_iterator()
       prev_wall_time = 0
-      for debug_event in stack_frames_iter:
+      for debug_event, _ in stack_frames_iter:
         self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
         prev_wall_time = debug_event.wall_time
         stack_frame_with_id = debug_event.stack_frame_with_id
@@ -133,7 +133,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       # outermost contexts).
       context_id_to_outer_id = dict()
 
-      for debug_event in graphs_iter:
+      for debug_event, _ in graphs_iter:
         self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
         prev_wall_time = debug_event.wall_time
         # A DebugEvent in the .graphs file contains either of the two fields:
@@ -219,7 +219,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       output_tensor_ids = []
       tensor_debug_modes = []
       tensor_values = []
-      for debug_event in execution_iter:
+      for debug_event, _ in execution_iter:
         self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
         prev_wall_time = debug_event.wall_time
         execution = debug_event.execution
@@ -260,7 +260,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
       device_names = []
       output_slots = []
       tensor_values = []
-      for debug_event in graph_execution_traces_iter:
+      for debug_event, _ in graph_execution_traces_iter:
         self.assertGreaterEqual(debug_event.wall_time, 0)
         graph_execution_trace = debug_event.graph_execution_trace
         op_names.append(graph_execution_trace.op_name)