From e724d9e162b92c8e039dfbaf976762a725681584 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Mon, 23 Dec 2019 16:04:06 -0800
Subject: [PATCH] [tfdbg] Implement random file read & DebugDataReader;
 Simplify & improve tests.

- Support offset-based random read access in DebugEventsReader
- Support yielding offsets from the iterators of DebugEventsReader
  to enable subsequent random-access reading
- Check the tensor ID in the debug tensor values under the CURT_HEALTH,
  CONCISE_HEALTH and SHAPE modes: Tackling multiple TODO items.
- Use new DebugDataReader in tests to simplify code.

Per design for scalable reading of large tfdbg v2 datasets:
- Create light-weight digest classes: ExecutionDigest and
  GraphExecutionTraceDigest
  - Loaded by DebugDataReader.executions() and .graph_execution_traces()
    with kwarg digest=True.
- Corresponding detailed data classes: Execution and GraphExecutionTrace.
- Other data classes:
  - DebuggedGraph
  - GraphOpCreationDigest

PiperOrigin-RevId: 286955104
Change-Id: I750fc085fd75a7df11637413389b68dd0a6733c6
---
 tensorflow/core/protobuf/debug_event.proto    |    1 +
 tensorflow/python/debug/BUILD                 |    1 +
 .../python/debug/lib/debug_events_reader.py   |  883 +++++++++++-
 .../debug/lib/debug_events_writer_test.py     |   52 +-
 .../python/debug/lib/debug_v2_ops_test.py     |   10 +-
 .../python/debug/lib/dumping_callback_test.py | 1225 ++++++++---------
 .../debug/lib/dumping_callback_test_lib.py    |   12 +-
 7 files changed, 1522 insertions(+), 662 deletions(-)

diff --git a/tensorflow/core/protobuf/debug_event.proto b/tensorflow/core/protobuf/debug_event.proto
index 8f9680f38d9..ebbb93ee049 100644
--- a/tensorflow/core/protobuf/debug_event.proto
+++ b/tensorflow/core/protobuf/debug_event.proto
@@ -162,6 +162,7 @@ message GraphOpCreation {
   string graph_name = 3;
 
   // Unique ID of the graph (generated by debugger).
+  // This is the ID of the immediately-enclosing graph.
   string graph_id = 4;
 
   // Name of the device that the op is assigned to (if available).
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 4ee84e512bc..58bebfa6bd9 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -120,6 +120,7 @@ py_library(
     deps = [
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework",
+        "@six_archive//:six",
     ],
 )
 
diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py
index c6142c6e309..1594b5f27f8 100644
--- a/tensorflow/python/debug/lib/debug_events_reader.py
+++ b/tensorflow/python/debug/lib/debug_events_reader.py
@@ -18,17 +18,24 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import glob
 import os
 import threading
 
-from six.moves import map
+import six
 
 from tensorflow.core.protobuf import debug_event_pb2
-from tensorflow.python.lib.io import tf_record
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.util import compat
 
 
+DebugEventWithOffset = collections.namedtuple(
+    "DebugEventWithOffset", "debug_event offset")
+
+
 class DebugEventsReader(object):
   """Reader class for a tfdbg v2 DebugEvents directory."""
 
@@ -56,6 +63,8 @@ class DebugEventsReader(object):
     self._readers = dict()  # A map from file path to reader.
    self._readers_lock = threading.Lock()
+    self._offsets = dict()
+
   def __enter__(self):
     return self
@@ -64,15 +73,48 @@ class DebugEventsReader(object):
     self.close()
 
   def _generic_iterator(self, file_path):
-    """A helper method that makes an iterator given a debug-events file path."""
+    """A helper method that makes an iterator given a debug-events file path.
+
+    Repeated calls to this method create iterators that remember the last
+    successful reading position (offset) for each given `file_path`. So the
+    iterators are meant for incremental reading of the file.
+
+    Args:
+      file_path: Path to the file to create the iterator for.
+
+    Yields:
+      A `DebugEventWithOffset` namedtuple of (debug_event, offset) on each
+      `next()` call.
+    """
     # The following code uses the double-checked locking pattern to optimize
     # the common case (where the reader is already initialized).
     if file_path not in self._readers:  # 1st check, without lock.
       with self._readers_lock:
         if file_path not in self._readers:  # 2nd check, with lock.
-          self._readers[file_path] = tf_record.tf_record_iterator(file_path)
+          with errors.raise_exception_on_not_ok_status() as status:
+            # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
+            # supports offset.
+            self._readers[file_path] = pywrap_tensorflow.PyRecordReader_New(
+                compat.as_bytes(file_path), 0, b"", status)
+    reader = self._readers[file_path]
+    while True:
+      offset = reader.offset()
+      try:
+        reader.GetNext()
+      except (errors.DataLossError, errors.OutOfRangeError):
+        # We ignore partial read exceptions, because a record may be truncated.
+        # PyRecordReader holds the offset prior to the failed read, so retrying
+        # will succeed.
+        break
+      yield DebugEventWithOffset(
+          debug_event=debug_event_pb2.DebugEvent.FromString(reader.record()),
+          offset=offset)
 
-    return map(debug_event_pb2.DebugEvent.FromString, self._readers[file_path])
+  def _create_offset_reader(self, file_path, offset):
+    with errors.raise_exception_on_not_ok_status() as status:
+      # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
+      # supports offset.
+      return pywrap_tensorflow.PyRecordReader_New(
+          file_path, offset, b"", status)
 
   def metadata_iterator(self):
     return self._generic_iterator(self._metadata_path)
@@ -86,12 +128,839 @@ class DebugEventsReader(object):
   def graphs_iterator(self):
     return self._generic_iterator(self._graphs_path)
 
+  def read_graphs_event(self, offset):
+    """Read a DebugEvent proto at a given offset from the .graphs file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEvent proto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+    """
+    # TODO(cais): After switching to new Python wrapper of tfrecord reader,
+    # use seeking instead of repeated file opening. Same below.
+    reader = self._create_offset_reader(self._graphs_path, offset)
+    reader.GetNext()
+    debug_event = debug_event_pb2.DebugEvent.FromString(reader.record())
+    reader.Close()
+    return debug_event
+
   def execution_iterator(self):
     return self._generic_iterator(self._execution_path)
 
+  def read_execution_debug_event(self, offset):
+    """Read a DebugEvent proto at a given offset from the .execution file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEvent proto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+ """ + reader = self._create_offset_reader(self._execution_path, offset) + reader.GetNext() + debug_event = debug_event_pb2.DebugEvent.FromString(reader.record()) + reader.Close() + return debug_event + def graph_execution_traces_iterator(self): return self._generic_iterator(self._graph_execution_traces_path) + def read_graph_execution_traces_event(self, offset): + """Read DebugEvent at given offset from .graph_execution_traces file. + + Args: + offset: Offset to read the DebugEvent proto from. + + Returns: + A DebugEventProto. + + Raises: + `errors.DataLossError` if offset is at a wrong location. + `errors.OutOfRangeError` if offset is out of range of the file. + """ + reader = self._create_offset_reader( + self._graph_execution_traces_path, offset) + reader.GetNext() + debug_event = debug_event_pb2.DebugEvent.FromString(reader.record()) + reader.Close() + return debug_event + def close(self): - with self._readers_lock: - self._readers.clear() + for reader in self._readers.values(): + reader.Close() + + +class BaseDigest(object): + """Base class for digest. + + Properties: + wall_time: A timestamp for the digest (unit: s). + offset: A offset number in the corresponding file that can be used for + fast random read access. + """ + + def __init__(self, wall_time, offset): + self._wall_time = wall_time + self._offset = offset + + @property + def wall_time(self): + return self._wall_time + + @property + def offset(self): + return self._offset + + +class ExecutionDigest(BaseDigest): + """Light-weight digest summarizing top-level execution event. + + Use `DebugDataReader.read_execution(execution_digest)` to load the more + detailed data object concerning the execution event (`Execution`). + + Properties: + op_type: Type name of the executed op. In the case of the eager execution of + an individual op, it is the name of the op (e.g., "MatMul"). + In the case of the execution of a tf.function (FuncGraph), this is the + internally-generated name of the function (e.g., + "__inference_my_func_123"). + """ + + def __init__(self, + wall_time, + offset, + op_type): + super(ExecutionDigest, self).__init__(wall_time, offset) + self._op_type = op_type + + @property + def op_type(self): + return self._op_type + + # TODO(cais): Implement to_json(). + + +class Execution(ExecutionDigest): + """Detailed data relating to a top-level execution event. + + The execution is of an individual op or a tf.function, which may have any + number of output tensors. + + Properties (beyond the base class `ExecutionDigest`): + stack_frame_ids: Reference IDs for stack frames, ordered from bottommost to + topmost. Use `DebugDataReader.read_execution_stack_trace()` to load the + detailed stack frames (filepath, lineno and function name). + tensor_debug_mode: TensorDebugMode enum value, as an `int`. + graph_id: ID of the executed FuncGraph (applicable only the execution of a + tf.function). `None` for the eager execution of an individual op. + input_tensor_ids: IDs of the input (eager) tensor(s) for this execution, if + any. + output_tensor_ids: IDs of the output (eager) tensor(s) from this execution, + if any. + debug_tensor_values: Values of the debug tensor(s), applicable only to + non-FULL_TENSOR tensor debug mode. A tuple of list of numbers. Each + element of the tuple corresponds to an output tensor of the execution. + See documentation of the various TensorDebugModes for the semantics of the + numbers. 
+ """ + + def __init__(self, + execution_digest, + stack_frame_ids, + tensor_debug_mode, + graph_id=None, + input_tensor_ids=None, + output_tensor_ids=None, + debug_tensor_values=None): + super(Execution, self).__init__( + execution_digest.wall_time, + execution_digest.offset, + execution_digest.op_type) + self._stack_frame_ids = stack_frame_ids + self._tensor_debug_mode = tensor_debug_mode + self._graph_id = graph_id + self._input_tensor_ids = input_tensor_ids + self._output_tensor_ids = output_tensor_ids + self._debug_tensor_values = debug_tensor_values + + @property + def stack_frame_ids(self): + return self._stack_frame_ids + + @property + def tensor_debug_mode(self): + return self._tensor_debug_mode + + @property + def graph_id(self): + return self._graph_id + + @property + def input_tensor_ids(self): + return self._input_tensor_ids + + @property + def num_outputs(self): + return len(self._output_tensor_ids) + + @property + def output_tensor_ids(self): + return self._output_tensor_ids + + @property + def debug_tensor_values(self): + return self._debug_tensor_values + + # TODO(cais): Implement to_json(). + + +class DebuggedGraph(object): + """Data object representing debugging information about a tf.Graph. + + Includes `FuncGraph`s. + + Properties: + name: Name of the graph (if any). May be `None` for non-function graphs. + graph_id: Debugger-generated ID for the graph. + inner_graph_ids: A list of the debugger-generated IDs for the graphs + enclosed by this graph. + outer_graph_id: If this graph is nested within an outer graph, ID of the + outer graph. If this is an outermost graph, `None`. + """ + + def __init__(self, + name, + graph_id, + outer_graph_id=None): + self._name = name + self._graph_id = graph_id + self._outer_graph_id = outer_graph_id + self._inner_graph_ids = [] + # A dictionary from op name to GraphOpCreationDigest. + self._op_by_name = dict() + + def add_inner_graph_id(self, inner_graph_id): + """Add the debugger-generated ID of a graph nested within this graph. + + Args: + inner_graph_id: The debugger-generated ID of the nested inner graph. + """ + assert isinstance(inner_graph_id, six.string_types) + self._inner_graph_ids.append(inner_graph_id) + + def add_op(self, graph_op_creation_digest): + """Add an op creation data object. + + Args: + graph_op_creation_digest: A GraphOpCreationDigest data object describing + the creation of an op inside this graph. + """ + assert graph_op_creation_digest.op_name not in self._op_by_name + self._op_by_name[ + graph_op_creation_digest.op_name] = graph_op_creation_digest + + @property + def name(self): + return self._name + + @property + def graph_id(self): + return self._graph_id + + @property + def outer_graph_id(self): + return self._outer_graph_id + + @property + def inner_graph_ids(self): + return self._inner_graph_ids + + def get_op_type(self, op_name): + return self._op_by_name[op_name].op_type + + def get_tensor_id(self, op_name, output_slot): + """Get the ID of a symbolic tensor in this graph.""" + return self._op_by_name[op_name].output_tensor_ids[output_slot] + + # TODO(cais): Implement to_json(). + + +class GraphOpCreationDigest(BaseDigest): + """Data object describing the creation of an op inside a graph. + + For size efficiency, this digest object does not contain any stack frames or + any references to them. To obtain the stack frames, use + `DataReader.read_graph_op_creation_stack_trace()`. + + Properties (beyond the base class): + graph_id: Debugger-generated ID of the immediately-enclosing graph. 
+    op_type: Type name of the op (e.g., "MatMul").
+    op_name: Name of the op (e.g., "dense_1/MatMul").
+    output_tensor_ids: Debugger-generated IDs for the output(s) of the op.
+    input_names: Names of the input tensors to the op.
+    device_name: The name of the device that the op is placed on (if
+      available).
+  """
+
+  def __init__(self,
+               wall_time,
+               offset,
+               graph_id,
+               op_type,
+               op_name,
+               output_tensor_ids,
+               input_names=None,
+               device_name=None):
+    super(GraphOpCreationDigest, self).__init__(wall_time, offset)
+    self._graph_id = graph_id
+    self._op_type = op_type
+    self._op_name = op_name
+    self._output_tensor_ids = output_tensor_ids
+    self._input_names = input_names
+    self._device_name = device_name
+
+  @property
+  def graph_id(self):
+    return self._graph_id
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  @property
+  def op_name(self):
+    return self._op_name
+
+  @property
+  def output_tensor_ids(self):
+    return self._output_tensor_ids
+
+  @property
+  def num_outputs(self):
+    return len(self._output_tensor_ids)
+
+  @property
+  def input_names(self):
+    return self._input_names
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphExecutionTraceDigest(BaseDigest):
+  """Light-weight summary of an intra-graph tensor execution event.
+
+  Use `DebugDataReader.read_graph_execution_trace()` on this object to read
+  more detailed data (`GraphExecutionTrace`).
+
+  Properties (beyond the base class):
+    op_type: Type name of the executed op (e.g., "Conv2D").
+    op_name: Name of the op (e.g., "conv_2d_3/Conv2D").
+    output_slot: Output slot index of the tensor.
+  """
+
+  def __init__(self,
+               wall_time,
+               offset,
+               op_type,
+               op_name,
+               output_slot):
+    super(GraphExecutionTraceDigest, self).__init__(wall_time, offset)
+    self._op_type = op_type
+    self._op_name = op_name
+    self._output_slot = output_slot
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  @property
+  def op_name(self):
+    return self._op_name
+
+  @property
+  def output_slot(self):
+    return self._output_slot
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphExecutionTrace(GraphExecutionTraceDigest):
+  """Detailed data object describing an intra-graph tensor execution.
+
+  Attributes (in addition to GraphExecutionTraceDigest):
+    graph_ids: The debugger-generated IDs of the graphs that enclose the
+      executed op (tensor), ordered from the outermost to the innermost.
+    graph_id: The debugger-generated ID of the innermost
+      (immediately-enclosing) graph.
+    tensor_debug_mode: TensorDebugMode enum value.
+    debug_tensor_value: Debug tensor values (only for non-FULL_TENSOR
+      tensor_debug_mode). A list of numbers. See the documentation of the
+      TensorDebugModes for the semantics of the numbers.
+    device_name: Device on which the tensor resides (if available).
+  """
+
+  def __init__(self,
+               graph_execution_trace_digest,
+               graph_ids,
+               tensor_debug_mode,
+               debug_tensor_value=None,
+               device_name=None):
+    super(GraphExecutionTrace, self).__init__(
+        graph_execution_trace_digest.wall_time,
+        graph_execution_trace_digest.offset,
+        graph_execution_trace_digest.op_type,
+        graph_execution_trace_digest.op_name,
+        graph_execution_trace_digest.output_slot)
+    self._graph_ids = graph_ids
+    self._tensor_debug_mode = tensor_debug_mode
+    self._debug_tensor_value = debug_tensor_value
+    self._device_name = device_name
+
+  @property
+  def graph_ids(self):
+    return self._graph_ids
+
+  @property
+  def graph_id(self):
+    return self._graph_ids[-1]
+
+  @property
+  def tensor_debug_mode(self):
+    return self._tensor_debug_mode
+
+  @property
+  def debug_tensor_value(self):
+    return self._debug_tensor_value
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  # TODO(cais): Implement to_json().
+
+
+def _parse_tensor_value(tensor_proto, return_list=False):
+  """Helper method for reading a tensor value from a tensor proto.
+
+  The rationale for the distinction between the `True` and `False` values of
+  `return_list` is as follows:
+  - `return_list=True` is used for TensorDebugMode values other than
+    FULL_TENSOR, e.g., CONCISE_HEALTH, SHAPE and FULL_HEALTH. Under
+    those modes, the value is guaranteed (by contract) to be a 1D float64
+    tensor.
+  - `return_list=False` is used for the FULL_TENSOR TensorDebugMode
+    specifically. In that case, we use `numpy.ndarray` to maximally preserve
+    the shape, dtype and value information regarding the underlying tensor
+    value. Under that mode, we don't use a Python list to represent the
+    tensor value because that can lead to loss of information (e.g., both
+    float16 and float32 dtypes get mapped to Python floats).
+
+  Args:
+    tensor_proto: The TensorProto instance from which the tensor value will
+      be loaded.
+    return_list: Whether the return value will be a nested Python list that
+      comes out from `numpy.ndarray.tolist()`.
+
+  Returns:
+    If parsing is successful, the tensor value as a `numpy.ndarray` or the
+    nested Python list converted from it.
+    If parsing fails, `None`.
+  """
+  try:
+    ndarray = tensor_util.MakeNdarray(tensor_proto)
+    return ndarray.tolist() if return_list else ndarray
+  except TypeError:
+    # Depending on tensor_debug_mode, certain dtypes of tensors don't
+    # have logged debug tensor values.
+    return None
+
+
+class DebugDataReader(object):
+  """A reader that reads structured debugging data in the tfdbg v2 format.
+
+  The set of data read by an object of this class concerns the execution
+  history of a tfdbg2-instrumented TensorFlow program.
+
+  Note:
+  - An object of this class incrementally reads data from files that belong
+    to the tfdbg v2 DebugEvent file set. Calling `update()` triggers the
+    reading from the last-successful reading positions in the files.
+  - This object can be used as a context manager. Its `__exit__()` call
+    closes the file readers cleanly.
+  """
+
+  def __init__(self, dump_root):
+    self._reader = DebugEventsReader(dump_root)
+    # TODO(cais): Implement pagination for memory constraints.
+    self._execution_digests = []
+
+    # A list of (host_name, file_path) tuples.
+    self._host_name_file_paths = []
+    # A dict mapping id to (host_name, file_path, lineno, func) tuple.
+    self._stack_frame_by_id = dict()
+    # Stores unprocessed stack frame IDs. This is necessary to handle the
+    # case in which reading of the .stack_frames file gets ahead of the
+    # reading of the .source_files file.
+    self._unprocessed_stack_frames = dict()
+    # A dict mapping id to DebuggedGraph objects.
+    self._graph_by_id = dict()
+    self._graph_op_digests = []
+    # TODO(cais): Implement pagination for memory constraints.
+    self._graph_execution_trace_digests = []
+
+    # The following timestamps keep track of where we've reached in each
+    # file of the DebugEvent file set, so that we don't run into race
+    # conditions with the writer.
+    self._source_files_timestamp = 0
+    # Temporary object used to hold DebugEvent protos with stack_frames
+    # field that has been read beyond max_wall_time.
+    # self._last_successful_stack_frames_offset = -1  # TODO(cais): Fix.
+
+  # TODO(cais): Read metadata.
+
+  def _load_source_files(self):
+    """Incrementally read the .source_files DebugEvent file."""
+    source_files_iter = self._reader.source_files_iterator()
+    for debug_event, _ in source_files_iter:
+      source_file = debug_event.source_file
+      self._host_name_file_paths.append(
+          (source_file.host_name, source_file.file_path))
+      self._source_files_timestamp = debug_event.wall_time
+
+  def _load_stack_frames(self):
+    """Incrementally read the .stack_frames file.
+
+    This must be called after _load_source_files().
+    It assumes that the following contract is honored by the writer of the
+    tfdbg v2 data file set:
+    - Before a stack frame is written to the .stack_frames file, the
+      corresponding source file information must have been written to the
+      .source_files file first.
+    """
+    stack_frames_iter = self._reader.stack_frames_iterator()
+    for debug_event, _ in stack_frames_iter:
+      stack_frame_with_id = debug_event.stack_frame_with_id
+      file_line_col = stack_frame_with_id.file_line_col
+      self._unprocessed_stack_frames[stack_frame_with_id.id] = file_line_col
+    # We do the processing in a separate stage, because the reading of the
+    # .stack_frames file may sometimes get ahead of the .source_files file.
+    unprocessed_stack_frame_ids = tuple(self._unprocessed_stack_frames.keys())
+    for stack_frame_id in unprocessed_stack_frame_ids:
+      file_line_col = self._unprocessed_stack_frames[stack_frame_id]
+      if len(self._host_name_file_paths) > file_line_col.file_index:
+        self._stack_frame_by_id[stack_frame_id] = (
+            self._host_name_file_paths[file_line_col.file_index][0],
+            self._host_name_file_paths[file_line_col.file_index][1],
+            file_line_col.line,
+            file_line_col.func)
+        del self._unprocessed_stack_frames[stack_frame_id]
+
+  def _load_graphs(self):
+    """Incrementally read the .graphs file.
+
+    Compiles the DebuggedGraph and GraphOpCreation data.
+ """ + graphs_iter = self._reader.graphs_iterator() + for debug_event, offset in graphs_iter: + if debug_event.graph_op_creation.ByteSize(): + op_creation_proto = debug_event.graph_op_creation + op_digest = GraphOpCreationDigest( + debug_event.wall_time, + offset, + op_creation_proto.graph_id, + op_creation_proto.op_type, + op_creation_proto.op_name, + tuple(op_creation_proto.output_tensor_ids), + input_names=tuple(op_creation_proto.input_names)) + self._graph_op_digests.append(op_digest) + self._graph_by_id[op_creation_proto.graph_id].add_op(op_digest) + elif debug_event.debugged_graph.ByteSize(): + graph_proto = debug_event.debugged_graph + graph = DebuggedGraph( + graph_proto.graph_name or None, + graph_proto.graph_id, + outer_graph_id=graph_proto.outer_context_id or None) + self._graph_by_id[graph_proto.graph_id] = graph + if graph_proto.outer_context_id: + self._graph_by_id[ + graph_proto.outer_context_id].add_inner_graph_id(graph.graph_id) + + def _load_graph_execution_traces(self): + """Incrementally load the .graph_execution_traces file.""" + traces_iter = self._reader.graph_execution_traces_iterator() + for debug_event, offset in traces_iter: + trace_proto = debug_event.graph_execution_trace + op_name = trace_proto.op_name + op_type = self._lookup_op_type(trace_proto.tfdbg_context_id, op_name) + digest = GraphExecutionTraceDigest( + debug_event.wall_time, + offset, + op_type, + op_name, + trace_proto.output_slot) + self._graph_execution_trace_digests.append(digest) + + def _lookup_op_type(self, graph_id, op_name): + """Lookup the type of an op by name and the immediately enclosing graph. + + Args: + graph_id: Debugger-generated ID of the immediately-enclosing graph. + op_name: Name of the op. + + Returns: + Op type as a str. + """ + return self._graph_by_id[graph_id].get_op_type(op_name) + + def _load_execution(self): + """Incrementally read the .execution file.""" + execution_iter = self._reader.execution_iterator() + for debug_event, offset in execution_iter: + self._execution_digests.append(ExecutionDigest( + debug_event.wall_time, + offset, + debug_event.execution.op_type)) + + def update(self): + """Perform incremental read of the file set.""" + self._load_source_files() + self._load_stack_frames() + self._load_graphs() + self._load_graph_execution_traces() + self._load_execution() + + def outermost_graphs(self): + """Get the number of outer most graphs read so far.""" + return [graph for graph in self._graph_by_id.values() + if not graph.outer_graph_id] + + def graph_by_id(self, graph_id): + """Get a DebuggedGraph object by its ID.""" + return self._graph_by_id[graph_id] + + def graph_op_digests(self, op_type=None): + """Get the list of the digests for graph-op creation so far. + + Args: + op_type: Optional op type to filter the creation events with. + + Returns: + A list of `GraphOpCreationDigest` objects. + """ + if op_type is not None: + return [digest for digest in self._graph_op_digests + if digest.op_type == op_type] + else: + return self._graph_op_digests + + def graph_execution_traces(self, digest=False): + """Get all the intra-graph execution tensor traces read so far. + + TODO(cais): Support begin and end to enable partial loading. + + Args: + digest: Whether the results will be returned in the more light-weight + digest form. + + Returns: + If `digest`: a `list` of `GraphExecutionTraceDigest` objects. + Else: a `list` of `GraphExecutionTrace` objects. 
+ """ + if digest: + return self._graph_execution_trace_digests + else: + return [self.read_graph_execution_trace(digest) + for digest in self._graph_execution_trace_digests] + + def num_graph_execution_traces(self): + """Get the number of graph execution traces read so far.""" + return len(self._graph_execution_trace_digests) + + def executions(self, digest=False): + """Get `Execution`s or `ExecutionDigest`s this reader has read so far. + + # TODO(cais): Support begin index and end index to support partial loading. + + Args: + digest: Whether the results are returned in a digest form, i.e., + `ExecutionDigest` format, instead of the more detailed `Execution` + format. + + Returns: + If `digest`: a `list` of `ExecutionDigest` objects. + Else: a `list` of `Execution` objects. + """ + if digest: + return self._execution_digests + else: + # TODO(cais): Optimizer performance removing repeated file open/close. + return [self.read_execution(digest) for digest in self._execution_digests] + + def num_executions(self): + """Get the number of execution events read so far.""" + return len(self._execution_digests) + + def read_execution(self, execution_digest): + """Read a detailed Execution object.""" + debug_event = self._reader.read_execution_debug_event( + execution_digest.offset) + execution_proto = debug_event.execution + + debug_tensor_values = None + if (execution_proto.tensor_debug_mode == + debug_event_pb2.TensorDebugMode.FULL_TENSOR): + pass # TODO(cais): Build tensor store. + elif (execution_proto.tensor_debug_mode != + debug_event_pb2.TensorDebugMode.NO_TENSOR): + debug_tensor_values = [] + for tensor_proto in execution_proto.tensor_protos: + # TODO(cais): Refactor into a helper method. + debug_tensor_values.append( + _parse_tensor_value(tensor_proto, return_list=True)) + return Execution( + execution_digest, + tuple(execution_proto.code_location.stack_frame_ids), + execution_proto.tensor_debug_mode, + graph_id=execution_proto.graph_id, + input_tensor_ids=tuple(execution_proto.input_tensor_ids), + output_tensor_ids=tuple(execution_proto.output_tensor_ids), + debug_tensor_values=tuple( + debug_tensor_values) if debug_tensor_values else None) + + def read_graph_execution_trace(self, graph_execution_trace_digest): + """Read the detailed graph execution trace. + + Args: + graph_execution_trace_digest: A `GraphExecutionTraceDigest` object. + + Returns: + The corresponding `GraphExecutionTrace` object. + """ + debug_event = self._reader.read_graph_execution_traces_event( + graph_execution_trace_digest.offset) + trace_proto = debug_event.graph_execution_trace + + graph_ids = [trace_proto.tfdbg_context_id] + # Exhaust the outer contexts (graphs). + while True: + graph = self.graph_by_id(graph_ids[0]) + if graph.outer_graph_id: + graph_ids.insert(0, graph.outer_graph_id) + else: + break + + debug_tensor_value = None + if (trace_proto.tensor_debug_mode == + debug_event_pb2.TensorDebugMode.FULL_TENSOR): + pass # TODO(cais): Build tensor store. + else: + debug_tensor_value = _parse_tensor_value( + trace_proto.tensor_proto, return_list=True) + return GraphExecutionTrace( + graph_execution_trace_digest, + graph_ids=graph_ids, + tensor_debug_mode=trace_proto.tensor_debug_mode, + debug_tensor_value=debug_tensor_value, + device_name=trace_proto.device_name or None) + + def read_execution_stack_trace(self, execution): + """Read the stack trace of a given Execution object. + + Args: + execution: The Execution object of interest. + + Returns: + A tuple consisting of: + 1. The host name. + 2. 
+        2. The stack trace, as a list of (file_path, lineno, func) tuples.
+    """
+    host_name = self._stack_frame_by_id[execution.stack_frame_ids[0]][0]
+    return (host_name, [
+        self._stack_frame_by_id[frame_id][1:]
+        for frame_id in execution.stack_frame_ids])
+
+  def read_graph_op_creation_stack_trace(self, graph_op_creation_digest):
+    """Read the stack trace of a given graph op creation object.
+
+    Args:
+      graph_op_creation_digest: The GraphOpCreationDigest object of interest.
+
+    Returns:
+      A tuple consisting of:
+        1. The host name.
+        2. The stack trace, as a list of (file_path, lineno, func) tuples.
+    """
+    debug_event = self._reader.read_graphs_event(
+        graph_op_creation_digest.offset)
+    graph_op_creation = debug_event.graph_op_creation
+    host_name = graph_op_creation.code_location.host_name
+    return host_name, [
+        self._stack_frame_by_id[frame_id][1:]
+        for frame_id in graph_op_creation.code_location.stack_frame_ids]
+
+  # TODO(cais): Add graph_execution_digests() with an ExecutionDigest
+  # as a kwarg, to establish the association between top-level and
+  # intra-graph execution events.
+
+  def execution_to_tensor_values(self, execution):
+    """Read the full tensor values from an Execution or ExecutionDigest.
+
+    Args:
+      execution: An `Execution` or `ExecutionDigest` object.
+
+    Returns:
+      A list of numpy arrays representing the output tensor values of the
+      execution event.
+    """
+    debug_event = self._reader.read_execution_debug_event(execution.offset)
+    return [_parse_tensor_value(tensor_proto)
+            for tensor_proto in debug_event.execution.tensor_protos]
+
+  def graph_execution_trace_to_tensor_value(self, trace):
+    """Read the full tensor value from a GraphExecutionTrace or its digest.
+
+    Args:
+      trace: A `GraphExecutionTraceDigest` or `GraphExecutionTrace` object.
+
+    Returns:
+      A numpy array representing the output tensor value of the intra-graph
+      tensor execution event.
+    """
+    debug_event = self._reader.read_graph_execution_traces_event(trace.offset)
+    return _parse_tensor_value(debug_event.graph_execution_trace.tensor_proto)
+
+  def symbolic_tensor_id(self, graph_id, op_name, output_slot):
+    """Get the ID of a symbolic tensor.
+
+    Args:
+      graph_id: The ID of the immediately-enclosing graph.
+      op_name: Name of the op.
+      output_slot: Output slot as an int.
+
+    Returns:
+      The ID of the symbolic tensor as an int.
+ """ + return self._graph_by_id[graph_id].get_tensor_id(op_name, output_slot) + + def graph_execution_trace_to_tensor_id(self, trace): + """Get symbolic tensor ID from a GraphExecutoinTraceDigest object.""" + return self.symbolic_tensor_id( + trace.graph_id, trace.op_name, trace.output_slot) + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + del exception_type, exception_value, traceback # Unused + self._reader.close() diff --git a/tensorflow/python/debug/lib/debug_events_writer_test.py b/tensorflow/python/debug/lib/debug_events_writer_test.py index f6e973befed..b62fc9b3f9f 100644 --- a/tensorflow/python/debug/lib/debug_events_writer_test.py +++ b/tensorflow/python/debug/lib/debug_events_writer_test.py @@ -76,20 +76,20 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): writer.FlushNonExecutionFiles() with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - actuals = list(reader.source_files_iterator()) + actuals = list(item.debug_event.source_file + for item in reader.source_files_iterator()) self.assertLen(actuals, num_protos) for i in range(num_protos): - self.assertEqual(actuals[i].source_file.file_path, - "/home/tf2user/main.py") - self.assertEqual(actuals[i].source_file.host_name, "machine.cluster") - self.assertEqual(actuals[i].source_file.lines, ["print(%d)" % i]) + self.assertEqual(actuals[i].file_path, "/home/tf2user/main.py") + self.assertEqual(actuals[i].host_name, "machine.cluster") + self.assertEqual(actuals[i].lines, ["print(%d)" % i]) - actuals = list(reader.stack_frames_iterator()) + actuals = list(item.debug_event.stack_frame_with_id + for item in reader.stack_frames_iterator()) self.assertLen(actuals, num_protos) for i in range(num_protos): - self.assertEqual(actuals[i].stack_frame_with_id.id, "stack_%d" % i) - self.assertEqual( - actuals[i].stack_frame_with_id.file_line_col.file_index, i * 10) + self.assertEqual(actuals[i].id, "stack_%d" % i) + self.assertEqual(actuals[i].file_line_col.file_index, i * 10) def testWriteGraphOpCreationAndDebuggedGraphs(self): writer = debug_events_writer.DebugEventsWriter(self.dump_root) @@ -106,7 +106,7 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): writer.FlushNonExecutionFiles() reader = debug_events_reader.DebugEventsReader(self.dump_root) - actuals = list(reader.graphs_iterator()) + actuals = list(item.debug_event for item in reader.graphs_iterator()) self.assertLen(actuals, num_op_creations + 1) for i in range(num_op_creations): self.assertEqual(actuals[i].graph_op_creation.op_type, "Conv2D") @@ -172,24 +172,24 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): # Verify the content of the .source_files file. with debug_events_reader.DebugEventsReader(self.dump_root) as reader: source_files_iter = reader.source_files_iterator() - actuals = list(source_files_iter) - file_paths = sorted([actual.source_file.file_path for actual in actuals]) + actuals = list(item.debug_event.source_file for item in source_files_iter) + file_paths = sorted([actual.file_path for actual in actuals]) self.assertEqual(file_paths, [ "/home/tf2user/file_0.py", "/home/tf2user/file_1.py", "/home/tf2user/file_2.py" ]) # Verify the content of the .stack_frames file. 
- actuals = list(reader.stack_frames_iterator()) - stack_frame_ids = sorted( - [actual.stack_frame_with_id.id for actual in actuals]) + actuals = list(item.debug_event.stack_frame_with_id + for item in reader.stack_frames_iterator()) + stack_frame_ids = sorted([actual.id for actual in actuals]) self.assertEqual(stack_frame_ids, ["stack_frame_0", "stack_frame_1", "stack_frame_2"]) # Verify the content of the .graphs file. - actuals = list(reader.graphs_iterator()) - graph_op_names = sorted( - [actual.graph_op_creation.op_name for actual in actuals]) + actuals = list(item.debug_event.graph_op_creation + for item in reader.graphs_iterator()) + graph_op_names = sorted([actual.op_name for actual in actuals]) self.assertEqual(graph_op_names, ["Op0", "Op1", "Op2"]) def testWriteExecutionEventsWithCircularBuffer(self): @@ -242,11 +242,12 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): self.assertEqual(len(actuals), 0) writer.FlushExecutionFiles() - actuals = list(reader.graph_execution_traces_iterator()) + actuals = list(item.debug_event.graph_execution_trace + for item in reader.graph_execution_traces_iterator()) self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE) for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE): self.assertEqual( - actuals[i].graph_execution_trace.op_name, + actuals[i].op_name, "Op%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)) def testWriteGraphExecutionTraceEventsWithoutCircularBufferBehavior(self): @@ -260,10 +261,11 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): writer.FlushExecutionFiles() with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - actuals = list(reader.graph_execution_traces_iterator()) + actuals = list(item.debug_event.graph_execution_trace + for item in reader.graph_execution_traces_iterator()) self.assertLen(actuals, num_execution_events) for i in range(num_execution_events): - self.assertEqual(actuals[i].graph_execution_trace.op_name, "Op%d" % i) + self.assertEqual(actuals[i].op_name, "Op%d" % i) def testConcurrentWritesToExecutionFiles(self): circular_buffer_size = 5 @@ -308,9 +310,9 @@ class DebugEventsWriterTest(dumping_callback_test_lib.DumpingCallbackTestBase): # Verify the content of the .execution file. with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - actuals = list(reader.graph_execution_traces_iterator()) - op_names = sorted( - [actual.graph_execution_trace.op_name for actual in actuals]) + actuals = list(item.debug_event.graph_execution_trace + for item in reader.graph_execution_traces_iterator()) + op_names = sorted([actual.op_name for actual in actuals]) self.assertLen(op_names, circular_buffer_size) self.assertLen(op_names, len(set(op_names))) diff --git a/tensorflow/python/debug/lib/debug_v2_ops_test.py b/tensorflow/python/debug/lib/debug_v2_ops_test.py index c665da7132d..d6f0d4310a2 100644 --- a/tensorflow/python/debug/lib/debug_v2_ops_test.py +++ b/tensorflow/python/debug/lib/debug_v2_ops_test.py @@ -88,7 +88,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase): metadata_iter = reader.metadata_iterator() # Check that the .metadata DebugEvents data file has been created, even # before FlushExecutionFiles() is called. 
- debug_event = next(metadata_iter) + debug_event = next(metadata_iter).debug_event self.assertGreater(debug_event.wall_time, 0) self.assertTrue(debug_event.debug_metadata.tensorflow_version) self.assertTrue( @@ -107,7 +107,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase): # The circular buffer has a size of 4. So only the data from the # last two iterations should have been written to self.dump_root. for _ in range(2): - debug_event = next(graph_trace_iter) + debug_event = next(graph_trace_iter).debug_event self.assertGreater(debug_event.wall_time, 0) trace = debug_event.graph_execution_trace self.assertEqual(trace.tfdbg_context_id, "deadbeaf") @@ -118,7 +118,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase): tensor_value = tensor_util.MakeNdarray(trace.tensor_proto) self.assertAllClose(tensor_value, [9.0, 16.0]) - debug_event = next(graph_trace_iter) + debug_event = next(graph_trace_iter).debug_event self.assertGreater(debug_event.wall_time, 0) trace = debug_event.graph_execution_trace self.assertEqual(trace.tfdbg_context_id, "beafdead") @@ -165,7 +165,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase): x_values = [] timestamp = 0 while True: - debug_event = next(graph_trace_iter) + debug_event = next(graph_trace_iter).debug_event self.assertGreater(debug_event.wall_time, timestamp) timestamp = debug_event.wall_time trace = debug_event.graph_execution_trace @@ -210,7 +210,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase): with debug_events_reader.DebugEventsReader(debug_root) as reader: graph_trace_iter = reader.graph_execution_traces_iterator() - debug_event = next(graph_trace_iter) + debug_event = next(graph_trace_iter).debug_event trace = debug_event.graph_execution_trace self.assertEqual(trace.tfdbg_context_id, "deadbeaf") self.assertEqual(trace.op_name, "") diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py index b7e90f3179c..061cb001639 100644 --- a/tensorflow/python/debug/lib/dumping_callback_test.py +++ b/tensorflow/python/debug/lib/dumping_callback_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import os import shutil +import socket import tempfile import threading @@ -36,7 +37,6 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.keras import models from tensorflow.python.keras.applications import mobilenet_v2 @@ -61,6 +61,10 @@ def _create_simple_recurrent_keras_model(input_shape): return model +_host_name = socket.gethostname() +_current_file_full_path = os.path.abspath(__file__) + + class TracingCallbackTest( dumping_callback_test_lib.DumpingCallbackTestBase, parameterized.TestCase): @@ -74,6 +78,19 @@ class TracingCallbackTest( dumping_callback.disable_dump_debug_info() super(TracingCallbackTest, self).tearDown() + def _verifyStackFrames(self, stack_frames): + """Verify the correctness of the stack frames. + + Currently, it simply asserts that the current file is found in the stack + frames. + TODO(cais): Perhaps implement a stricter check later. + + Args: + stack_frames: The stack frames to verify. 
+ """ + self.assertTrue([ + frame for frame in stack_frames if frame[0] == _current_file_full_path]) + def testInvalidTensorDebugModeCausesError(self): with self.assertRaisesRegexp( ValueError, @@ -111,73 +128,66 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() self._readAndCheckMetadataFile() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - # Before FlushExecutionFiles() is called, the .execution file should be - # empty. - with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - execution_iter = reader.execution_iterator() - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + # Before FlushExecutionFiles() is called, the .execution file should be + # empty. + self.assertFalse(reader.executions()) # After the flushing, the .execution file should hold the appropriate # contents. writer.FlushExecutionFiles() - execution_iter = reader.execution_iterator() + reader.update() + executions = reader.executions() prev_wall_time = 1 executed_op_types = [] tensor_values = collections.defaultdict(lambda: []) - for debug_event in execution_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time - execution = debug_event.execution + for execution in executions: + self.assertGreaterEqual(execution.wall_time, prev_wall_time) + prev_wall_time = execution.wall_time executed_op_types.append(execution.op_type) # No graph IDs should have been logged for eager op executions. self.assertFalse(execution.graph_id) self.assertTrue(execution.input_tensor_ids) self.assertTrue(execution.output_tensor_ids) + self.assertEqual( + debug_event_pb2.TensorDebugMode.keys()[execution.tensor_debug_mode], + tensor_debug_mode) if tensor_debug_mode == "NO_TENSOR": # Due to the NO_TENSOR tensor debug mode, tensor_protos ought to # be empty. - self.assertFalse(execution.tensor_protos) + self.assertFalse(execution.debug_tensor_values) elif tensor_debug_mode == "CURT_HEALTH": - self.assertLen(execution.tensor_protos, 1) + self.assertLen(execution.debug_tensor_values, 1) if execution.op_type in ("AddV2", "Mul", "RealDiv"): # 1st element: -1 is the unset tensor_id for eager op execution. # 2nd element: 0 means there is no inf or nan. - self.assertAllClose( - tensor_util.MakeNdarray(execution.tensor_protos[0]), - [-1.0, 0.0]) + self.assertAllClose(execution.debug_tensor_values, [[-1.0, 0.0]]) elif tensor_debug_mode == "CONCISE_HEALTH": - self.assertLen(execution.tensor_protos, 1) if execution.op_type in ("AddV2", "Mul", "RealDiv"): # 1st element: -1 is the unset tensor_id for eager op execution. # 2nd element: each scalar tensor has 1 element. # Remaining elements: no -inf, inf or nan in these self.assertAllClose( - tensor_util.MakeNdarray(execution.tensor_protos[0]), - [-1, 1, 0, 0, 0]) + execution.debug_tensor_values, [[-1, 1, 0, 0, 0]]) elif tensor_debug_mode == "SHAPE": - self.assertLen(execution.tensor_protos, 1) if execution.op_type in ("AddV2", "Mul", "RealDiv"): # 1st element: -1 is the unset tensor_id for eager op execution. # 2nd element: dtype enum value (float32). # 3rd element: rank (scalar). # 4th element: element count (4). # Remaining elements: shape at fixed length (6). 
- self.assertAllClose( - tensor_util.MakeNdarray(execution.tensor_protos[0]), - [-1, 1, 0, 1, 0, 0, 0, 0, 0, 0]) + self.assertAllClose(execution.debug_tensor_values, + [[-1, 1, 0, 1, 0, 0, 0, 0, 0, 0]]) elif tensor_debug_mode == "FULL_TENSOR": - # Under the FULL_TENSOR mode, the value of the tensor should be - # available through `tensor_protos`. - tensor_value = float( - tensor_util.MakeNdarray(execution.tensor_protos[0])) - tensor_values[execution.op_type].append(tensor_value) - # Verify the code_location field. - self.assertTrue(execution.code_location.stack_frame_ids) - for stack_frame_id in execution.code_location.stack_frame_ids: - self.assertIn(stack_frame_id, stack_frame_by_id) + tensor_values[execution.op_type].append( + reader.execution_to_tensor_values(execution)[0]) + + host_name, stack_frames = reader.read_execution_stack_trace(execution) + self.assertEqual(host_name, _host_name) + self._verifyStackFrames(stack_frames) + if tensor_debug_mode == "FULL_TENSOR": self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0]) self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1]) @@ -217,12 +227,8 @@ class TracingCallbackTest( # Due to the pure eager op execution, the .graph file and the # .graph_execution_traces file ought to be empty. - graphs_iterator = reader.graphs_iterator() - with self.assertRaises(StopIteration): - next(graphs_iterator) - graph_trace_iter = reader.graph_execution_traces_iterator() - with self.assertRaises(StopIteration): - next(graph_trace_iter) + self.assertFalse(reader.outermost_graphs()) + self.assertEqual(reader.num_graph_execution_traces(), 0) @parameterized.named_parameters( ("CurtHealth", "CURT_HEALTH"), @@ -242,60 +248,48 @@ class TracingCallbackTest( y = np.array([2, -1, 0, 0, 1, 1, 1, 3], dtype=np.float16) # (x + y) / (x - y) = [0.2, -inf, nan, nan, inf, inf, inf, -5]. self.evaluate(func(x, y)) - writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, - _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) - - (op_names, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"]) - - if tensor_debug_mode == "CURT_HEALTH": - for op_type, tensor_value in zip(executed_op_types, tensor_values): - self.assertLen(tensor_value, 2) - # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) - # 2nd element: 0 means there is no inf or nan. - if op_type == "RealDiv": - self.assertEqual(tensor_value[1], 1) - else: - self.assertEqual(tensor_value[1], 0) - elif tensor_debug_mode == "CONCISE_HEALTH": - for op_type, tensor_value in zip(executed_op_types, tensor_values): - self.assertLen(tensor_value, 5) - # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) - # 2nd element: element count. - self.assertEqual(tensor_value[1], 8) - # Remaining 3 elements: The counts of -inf, inf and nan. - if op_type == "RealDiv": - self.assertAllClose(tensor_value[2:], [1, 3, 2]) - else: - self.assertAllClose(tensor_value[2:], [0, 0, 0]) - else: # SHAPE. 
- for op_type, tensor_value in zip(executed_op_types, tensor_values): - self.assertLen(tensor_value, 10) - # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) - # 2nd element: dtype enum value (float16). - self.assertEqual(tensor_value[1], 19) - # 3rd element: rank (1) - self.assertEqual(tensor_value[2], 1) - # 4th element: element count. - self.assertEqual(tensor_value[3], 8) - # Remaining elements: shape at fixed length. - self.assertAllClose(tensor_value[4:], [8, 0, 0, 0, 0, 0]) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + graph_exec_traces = reader.graph_execution_traces() + executed_op_types = [trace.op_type for trace in graph_exec_traces] + self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"]) + if tensor_debug_mode == "CURT_HEALTH": + for trace in graph_exec_traces: + # 1st element: tensor_id, should be >= 0. + # 2nd element: indicates if there is any inf or nan. + tensor_id = reader.graph_execution_trace_to_tensor_id(trace) + self.assertGreaterEqual(tensor_id, 0) + if trace.op_type == "RealDiv": + self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1]) + else: + self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0]) + elif tensor_debug_mode == "CONCISE_HEALTH": + for trace in graph_exec_traces: + # 1st element: tensor_id, should be >= 0. + # 2nd element: element count (8). + # Remaining 3 elements: The counts of -inf, inf and nan. + tensor_id = reader.graph_execution_trace_to_tensor_id(trace) + self.assertGreaterEqual(tensor_id, 0) + if trace.op_type == "RealDiv": + self.assertAllClose(trace.debug_tensor_value, + [tensor_id, 8, 1, 3, 2]) + else: + self.assertAllClose(trace.debug_tensor_value, + [tensor_id, 8, 0, 0, 0]) + else: # SHAPE. + for trace in graph_exec_traces: + # 1st element: tensor_id, should be >= 0. + # 2nd element: dtype enum value (float16 = 19). + # 3rd element: rank (1) + # 4th element: element count (8). + # Remaining elements: shape at fixed length (6). + tensor_id = reader.graph_execution_trace_to_tensor_id(trace) + self.assertGreaterEqual(tensor_id, 0) + self.assertAllClose(trace.debug_tensor_value, + [tensor_id, 19, 1, 8, 8, 0, 0, 0, 0, 0]) @parameterized.named_parameters( ("Shape", "SHAPE"), @@ -317,28 +311,21 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, - _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) - - (op_names, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"]) - - for tensor_value in tensor_values: - # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) - # 2nd element: dtype enum value (bool). - self.assertEqual(tensor_value[1], 10) - # 3rd element: rank (2) - self.assertEqual(tensor_value[2], 2) - # 4th element: element count. - self.assertEqual(tensor_value[3], 4) - # Remaining elements: shape at fixed length. 
- self.assertAllClose(tensor_value[4:], [2, 2, 0, 0, 0, 0]) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + graph_exec_traces = reader.graph_execution_traces() + executed_op_types = [trace.op_type for trace in graph_exec_traces] + self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"]) + for trace in graph_exec_traces: + tensor_id = reader.graph_execution_trace_to_tensor_id(trace) + self.assertGreaterEqual(tensor_id, 0) + # 1st element: tensor_id, should be >= 0. + # 2nd element: dtype enum value (bool). + # 3rd element: rank (2). + # 4th element: element count (4). + # Remaining elements: shape at fixed length. + self.assertAllClose( + trace.debug_tensor_value, [tensor_id, 10, 2, 4, 2, 2, 0, 0, 0, 0]) @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), @@ -366,86 +353,151 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - if context.executing_eagerly(): - # NOTE(b/142486213): Execution of the TF function happens with - # Session.run() in v1 graph mode, so doesn't get logged to the - # .execution file. - (executed_op_types, executed_graph_ids, - _, _, _, _) = self._readAndCheckExecutionFile() - executed_op_types = [op_type for op_type in executed_op_types - if "sin1p_log_sum" in op_type] - self.assertLen(executed_op_types, 1) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + outermost_graphs = reader.outermost_graphs() + self.assertLen(outermost_graphs, 1) - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, op_types, op_name_to_op_type, - op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id) + if context.executing_eagerly(): + # NOTE(b/142486213): Execution of the TF function happens with + # Session.run() in v1 graph mode, so doesn't get logged to the + # .execution file. + executions = reader.executions() + self.assertLen(executions, 1) + self.assertIn("sin1p_log_sum", executions[0].op_type) + # Get the executed graph and verify its identity and inner graph. + graph = reader.graph_by_id(executions[0].graph_id) + self.assertEqual(graph.name, "sin1p_log_sum") + self.assertLen(graph.inner_graph_ids, 1) + inner_graph = reader.graph_by_id(graph.inner_graph_ids[0]) + self.assertEqual(inner_graph.name, "log_sum") - self.assertIn("AddV2", op_types) - self.assertIn("Log", op_types) - self.assertIn("Sin", op_types) - if context.executing_eagerly(): - # Check the correctness of the ID of the executed graph ID. - sin_op_name = [op_name for op_name in op_name_to_op_type - if op_name_to_op_type[op_name] == "Sin"] - self.assertLen(sin_op_name, 1) - sin_context_id = op_name_to_context_id[sin_op_name[0]] - # The executed "op" is a FuncGraph, and its graph ID should have been - # recorded properly and be the ID of the graph that the Sin op belongs to. - executed_graph_ids = [ - executed_graph_ids[i] for i, op_type - in enumerate(executed_op_types) if "sin1p_log_sum" in op_type] - self.assertEqual(executed_graph_ids[0], sin_context_id) + # Verify the recorded graph-building history. 
+      add_op_digests = reader.graph_op_digests(op_type="AddV2")
+      self.assertLen(add_op_digests, 2)
+      self.assertEqual(
+          reader.graph_by_id(add_op_digests[0].graph_id).name, "log_sum")
+      self.assertEqual(
+          reader.graph_by_id(add_op_digests[1].graph_id).name, "sin1p_log_sum")
+      log_op_digests = reader.graph_op_digests(op_type="Log")
+      self.assertLen(log_op_digests, 1)
+      self.assertEqual(
+          reader.graph_by_id(log_op_digests[0].graph_id).name, "log_sum")
+      sin_op_digests = reader.graph_op_digests(op_type="Sin")
+      self.assertLen(sin_op_digests, 1)
+      self.assertEqual(
+          reader.graph_by_id(sin_op_digests[0].graph_id).name, "sin1p_log_sum")
 
-    (op_names, _, _,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
+      # Verify the output tensor IDs and the stack traces.
+      for op_digest in add_op_digests + log_op_digests + sin_op_digests:
+        # These are all single-output ops.
+        self.assertLen(op_digest.output_tensor_ids, 1)
+        self.assertGreaterEqual(op_digest.output_tensor_ids[0], 0)
+        _, stack_frames = reader.read_graph_op_creation_stack_trace(op_digest)
+        self._verifyStackFrames(stack_frames)
 
-    if tensor_debug_mode == "NO_TENSOR":
-      # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
-      # be an empty float32 tensor.
-      for tensor_value in tensor_values:
-        self.assertEqual(tensor_value.dtype, np.float32)
-        self.assertEqual(tensor_value.shape, (0,))
-    elif tensor_debug_mode == "CURT_HEALTH":
-      for tensor_value in tensor_values:
-        self.assertLen(tensor_value, 2)
+      graph_exec_traces = reader.graph_execution_traces()
+      executed_op_types = [digest.op_type for digest in graph_exec_traces]
+      self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
+
+      # Verify the graph ID stack of each op.
+      # 1st AddV2 op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[0].graph_ids[-1]).name,
+          "log_sum")
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[0].graph_ids[-2]).name,
+          "sin1p_log_sum")
+      # Log op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[1].graph_ids[-1]).name,
+          "log_sum")
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[1].graph_ids[-2]).name,
+          "sin1p_log_sum")
+      # 2nd AddV2 op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[2].graph_ids[-1]).name,
+          "sin1p_log_sum")
+      # Sin op.
+      self.assertEqual(
+          reader.graph_by_id(graph_exec_traces[3].graph_ids[-1]).name,
+          "sin1p_log_sum")
+
+      if tensor_debug_mode == "NO_TENSOR":
+        # Under the default NO_TENSOR tensor-debug mode, the tensor_proto
+        # ought to be an empty float32 tensor.
+        for trace in graph_exec_traces:
+          self.assertEqual(trace.debug_tensor_value, [])
+      elif tensor_debug_mode == "CURT_HEALTH":
+        # Test the association between graph exec and prior graph building.
+        # In each case, the 1st element of debug_tensor_value is the ID of the
+        # symbolic tensor and the 2nd element is a zero indicating there is no
+        # inf or nan.
+        self.assertAllClose(
+            graph_exec_traces[0].debug_tensor_value,
+            [add_op_digests[0].output_tensor_ids[0], 0.0])  # 1st AddV2 op.
+        self.assertAllClose(
+            graph_exec_traces[1].debug_tensor_value,
+            [log_op_digests[0].output_tensor_ids[0], 0.0])  # Log op.
+        self.assertAllClose(
+            graph_exec_traces[2].debug_tensor_value,
+            [add_op_digests[1].output_tensor_ids[0], 0.0])  # 2nd AddV2 op.
+ self.assertAllClose( + graph_exec_traces[3].debug_tensor_value, + [sin_op_digests[0].output_tensor_ids[0], 0.0]) # Sin op. + elif tensor_debug_mode == "CONCISE_HEALTH": # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) - # 2nd element: 0 means there is no inf or nan. - self.assertEqual(tensor_value[1], 0) - elif tensor_debug_mode == "CONCISE_HEALTH": - for tensor_value in tensor_values: - self.assertLen(tensor_value, 5) - # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) # 2nd element: element count. Remaining elements: all zero because there # is no -inf, inf or nan. - self.assertAllClose(tensor_value[1:], [1, 0, 0, 0]) - elif tensor_debug_mode == "SHAPE": - for tensor_value in tensor_values: - # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) + # 1st AddV2 op. + self.assertAllClose( + graph_exec_traces[0].debug_tensor_value, + [add_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0]) + # Log op. + self.assertAllClose( + graph_exec_traces[1].debug_tensor_value, + [log_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0]) + # 2nd AddV2 op. + self.assertAllClose( + graph_exec_traces[2].debug_tensor_value, + [add_op_digests[1].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0]) + # Sin op. + self.assertAllClose( + graph_exec_traces[3].debug_tensor_value, + [sin_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0]) + elif tensor_debug_mode == "SHAPE": + # 1st element: tensor_id. # 2nd element: dtype (float32). - self.assertGreaterEqual(tensor_value[1], 1) # 3rd element: rank (scalar). - self.assertGreaterEqual(tensor_value[2], 0) - # 4th element: element count. - self.assertGreaterEqual(tensor_value[3], 1) - # Remaining elements: shape padded to fixed length. - self.assertAllClose(tensor_value[4:], [0, 0, 0, 0, 0, 0]) - elif tensor_debug_mode == "FULL_TENSOR": - self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op. - self.assertAllClose(tensor_values[1], np.log(5.0)) # Log op. - self.assertAllClose(tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op. - self.assertAllClose(tensor_values[3], - np.sin(np.log(5.0) + 1.0)) # Sin op. + # 4th element: element count (1). + # Remaining elements: shape padded to fixed length (6). + # 1st AddV2 op. + self.assertAllClose( + graph_exec_traces[0].debug_tensor_value, + [add_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0]) + # Log op. + self.assertAllClose( + graph_exec_traces[1].debug_tensor_value, + [log_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0]) + # 2nd AddV2 op. + self.assertAllClose( + graph_exec_traces[2].debug_tensor_value, + [add_op_digests[1].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0]) + # Sin op. + self.assertAllClose( + graph_exec_traces[3].debug_tensor_value, + [sin_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0]) + else: # FULL_TENSOR. + full_tensor_values = [ + reader.graph_execution_trace_to_tensor_value(trace) + for trace in graph_exec_traces] + self.assertAllClose(full_tensor_values[0], 5.0) # 1st AddV2 op. + self.assertAllClose(full_tensor_values[1], np.log(5.0)) # Log op. + self.assertAllClose( + full_tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op. 
+ self.assertAllClose( + full_tensor_values[3], np.sin(np.log(5.0) + 1.0)) # Sin op. def testCapturingExecutedGraphIdsOfTwoCompilationsOfSameFunction(self): """Test correct executed IDs of two FuncGraphs from the same Py function.""" @@ -467,15 +519,21 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - (executed_op_types, executed_graph_ids, - _, _, _, _) = self._readAndCheckExecutionFile() - self.assertLen(executed_op_types, 4) - for executed_op_type in executed_op_types: - self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_") - self.assertLen(executed_graph_ids, 4) - self.assertEqual(executed_graph_ids[0], executed_graph_ids[2]) - self.assertEqual(executed_graph_ids[1], executed_graph_ids[3]) - self.assertLen(set(executed_graph_ids), 2) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + + executions = reader.executions() + self.assertLen(executions, 4) + for execution in executions: + self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_") + executed_graph_ids = [execution.graph_id for execution in executions] + self.assertEqual(executed_graph_ids[0], executed_graph_ids[2]) + self.assertEqual(executed_graph_ids[1], executed_graph_ids[3]) + self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1]) + self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3]) + for executed_graph_id in executed_graph_ids: + self.assertEqual( + reader.graph_by_id(executed_graph_id).name, "ceil_times_two") def testCapturingExecutedGraphIdsOfDuplicateFunctionNames(self): """Two FuncGraphs compiled from Python functions with identical names.""" @@ -503,15 +561,20 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - (executed_op_types, executed_graph_ids, - _, _, _, _) = self._readAndCheckExecutionFile() - self.assertLen(executed_op_types, 4) - for executed_op_type in executed_op_types: - self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_") - self.assertLen(executed_graph_ids, 4) - self.assertEqual(executed_graph_ids[0], executed_graph_ids[2]) - self.assertEqual(executed_graph_ids[1], executed_graph_ids[3]) - self.assertLen(set(executed_graph_ids), 2) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + executions = reader.executions() + self.assertLen(executions, 4) + for execution in executions: + self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_") + executed_graph_ids = [execution.graph_id for execution in executions] + self.assertEqual(executed_graph_ids[0], executed_graph_ids[2]) + self.assertEqual(executed_graph_ids[1], executed_graph_ids[3]) + self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1]) + self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3]) + for executed_graph_id in executed_graph_ids: + self.assertEqual( + reader.graph_by_id(executed_graph_id).name, "ceil_times_two") @parameterized.named_parameters( ("AddV2", "AddV2"), @@ -539,32 +602,35 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, op_types, - op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) - self.assertIn("AddV2", op_types) - self.assertIn("Log", op_types) - self.assertIn("Sin", op_types) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + graph_op_digests = reader.graph_op_digests() + 
op_types = [digest.op_type for digest in graph_op_digests] + self.assertIn("AddV2", op_types) + self.assertIn("Log", op_types) + self.assertIn("Sin", op_types) - (op_names, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - - if op_regex == "AddV2": - self.assertEqual(executed_op_types, ["AddV2", "AddV2"]) - self.assertLen(tensor_values, 2) - self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op. - self.assertAllClose(tensor_values[1], np.log(5.0) + 1.0) # 2nd AddV2 op. - elif op_regex == "Log": - self.assertEqual(executed_op_types, ["Log"]) - self.assertLen(tensor_values, 1) - self.assertAllClose(tensor_values[0], np.log(5.0)) # Log op. - else: # "(AddV2|Log)" - self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2"]) - self.assertLen(tensor_values, 3) - self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op. - self.assertAllClose(tensor_values[1], np.log(5.0)) # Log op. - self.assertAllClose(tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op. + graph_exec_digests = reader.graph_execution_traces(digest=True) + executed_op_types = [digest.op_type for digest in graph_exec_digests] + tensor_values = [reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests] + if op_regex == "AddV2": + self.assertEqual(executed_op_types, ["AddV2", "AddV2"]) + self.assertLen(tensor_values, 2) + self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op. + self.assertAllClose( + tensor_values[1], np.log(5.0) + 1.0) # 2nd AddV2 op. + elif op_regex == "Log": + self.assertEqual(executed_op_types, ["Log"]) + self.assertLen(tensor_values, 1) + self.assertAllClose(tensor_values[0], np.log(5.0)) # Log op. + else: # "(AddV2|Log)" + self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2"]) + self.assertLen(tensor_values, 3) + self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op. + self.assertAllClose(tensor_values[1], np.log(5.0)) # Log op. + self.assertAllClose( + tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op. def testIncorrectTensorDTypeArgFormatLeadsToError(self): with self.assertRaisesRegexp( @@ -617,48 +683,54 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, _, - op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) - (op_names, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - if tensor_dtypes == [dtypes.float32] and not op_regex: - self.assertEqual(executed_op_types, ["Unique", "Sum"]) - self.assertLen(tensor_values, 2) - self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values. - self.assertAllClose(tensor_values[1], 17.) # Sum. - elif tensor_dtypes == ["float32"] and op_regex == "Sum": - self.assertEqual(executed_op_types, ["Sum"]) - self.assertLen(tensor_values, 1) - self.assertAllClose(tensor_values[0], 17.) # Sum. - elif tensor_dtypes == (dtypes.float32,) and op_regex == "(?!Sum)": - self.assertEqual(executed_op_types, ["Unique"]) - self.assertLen(tensor_values, 1) - self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values. - elif tensor_dtypes == [dtypes.int32] and not op_regex: - self.assertEqual(executed_op_types, ["Unique"]) - self.assertLen(tensor_values, 1) - self.assertAllEqual(tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices. 
- elif callable(tensor_dtypes) and not op_regex: - self.assertEqual(executed_op_types, ["Unique"]) - self.assertLen(tensor_values, 1) - self.assertAllEqual(tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices. - elif not tensor_dtypes and op_regex == "(?!Sum)": - self.assertEqual(executed_op_types, ["Unique", "Unique"]) - self.assertLen(tensor_values, 2) - self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values. - self.assertAllEqual(tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices. - else: # "All". - self.assertEqual(executed_op_types, ["Unique", "Unique", "Sum"]) - self.assertLen(tensor_values, 3) - self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values. - self.assertAllEqual(tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices. - self.assertAllClose(tensor_values[2], 17.) # Sum. + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + graph_exec_digests = reader.graph_execution_traces(digest=True) + executed_op_types = [digest.op_type for digest in graph_exec_digests] + tensor_values = [reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests] + + if tensor_dtypes == [dtypes.float32] and not op_regex: + self.assertEqual(executed_op_types, ["Unique", "Sum"]) + self.assertLen(tensor_values, 2) + self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values. + self.assertAllClose(tensor_values[1], 17.) # Sum. + elif tensor_dtypes == ["float32"] and op_regex == "Sum": + self.assertEqual(executed_op_types, ["Sum"]) + self.assertLen(tensor_values, 1) + self.assertAllClose(tensor_values[0], 17.) # Sum. + elif tensor_dtypes == (dtypes.float32,) and op_regex == "(?!Sum)": + self.assertEqual(executed_op_types, ["Unique"]) + self.assertLen(tensor_values, 1) + self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values. + elif tensor_dtypes == [dtypes.int32] and not op_regex: + self.assertEqual(executed_op_types, ["Unique"]) + self.assertLen(tensor_values, 1) + self.assertAllEqual( + tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices. + elif callable(tensor_dtypes) and not op_regex: + self.assertEqual(executed_op_types, ["Unique"]) + self.assertLen(tensor_values, 1) + self.assertAllEqual( + tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices. + elif not tensor_dtypes and op_regex == "(?!Sum)": + self.assertEqual(executed_op_types, ["Unique", "Unique"]) + self.assertLen(tensor_values, 2) + self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values. + self.assertAllEqual( + tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices. + else: # "All". + self.assertEqual(executed_op_types, ["Unique", "Unique", "Sum"]) + self.assertLen(tensor_values, 3) + self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values. + self.assertAllEqual( + tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices. + self.assertAllClose(tensor_values[2], 17) # Sum. 
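The branches above pin down the filtering contract that `enable_dump_debug_info()` is tested against here: `tensor_dtypes` may arrive as a list or tuple of dtypes or dtype-name strings, as a callable predicate, or as None, and it composes with `op_regex`, which is matched against the op type. The stand-alone sketch below restates that contract; `make_dump_filter` and its normalization details are illustrative assumptions, not the actual TensorFlow implementation.

```python
import re


def make_dump_filter(tensor_dtypes=None, op_regex=None):
  """Sketch of the dtype/op-regex filtering exercised by the test above.

  `tensor_dtypes` may be a list/tuple of dtype names, a callable predicate
  over the dtype name, or None (no dtype filtering). `op_regex` is a regex
  string matched against the op type, or None.
  """
  pattern = re.compile(op_regex) if op_regex else None

  def should_dump(op_type, dtype_name):
    if pattern and not pattern.match(op_type):
      return False  # Rejected by op_regex, e.g. "(?!Sum)" rejects "Sum".
    if tensor_dtypes is None:
      return True  # No dtype constraint.
    if callable(tensor_dtypes):
      return tensor_dtypes(dtype_name)  # User-supplied predicate.
    return dtype_name in list(tensor_dtypes)

  return should_dump


# Mirrors the "(?!Sum)" branch above: Unique's outputs pass, Sum's do not.
should_dump = make_dump_filter(tensor_dtypes=None, op_regex="(?!Sum)")
assert should_dump("Unique", "float32")
assert not should_dump("Sum", "float32")
```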
@parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), + ("CurtHealth", "CURT_HEALTH"), ("FullTensor", "FULL_TENSOR"), ) @test_util.run_in_graph_and_eager_modes @@ -679,86 +751,78 @@ class TracingCallbackTest( self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 8.0) writer.FlushNonExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + graph_op_digests = reader.graph_op_digests() + op_types = [digest.op_type for digest in graph_op_digests] + self.assertIn("Less", op_types) + self.assertIn("Mul", op_types) + self.assertIn("AddV2", op_types) - # Verify the content of the .graphs file. - context_ids, op_types, op_name_to_op_type, _ = ( - self._readAndCheckGraphsFile(stack_frame_by_id)) - self.assertIn("Less", op_types) - self.assertIn("Mul", op_types) - self.assertIn("AddV2", op_types) - - # Before FlushExecutionFiles() is called, the .execution and - # .graph_execution_traces files should be both empty. - with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - execution_iter = reader.execution_iterator() - graph_execution_traces_iter = reader.graph_execution_traces_iterator() - with self.assertRaises(StopIteration): - next(execution_iter) - with self.assertRaises(StopIteration): - next(graph_execution_traces_iter) + # Before FlushExecutionFiles() is called, the .execution and + # .graph_execution_traces files should be both empty. + self.assertEqual(reader.num_executions(), 0) + self.assertEqual(reader.num_graph_execution_traces(), 0) # TODO(cais): Backport execution instrumentation to tf.Session. writer.FlushExecutionFiles() # After the flushing, the .execution file should hold the appropriate # contents. + reader.update() if context.executing_eagerly(): - (executed_op_types, _, input_tensor_ids, output_tensor_ids, - tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile() # NOTE(b/142486213): Execution of the TF function happens with # Session.run() in v1 graph mode, hence it doesn't get logged to the - # .execution file. - self.assertLen(executed_op_types, 1) - self.assertIn("iterative_doubling", executed_op_types[0]) - self.assertLen(input_tensor_ids[0], 2) - self.assertLen(output_tensor_ids[0], 1) + executions = reader.executions() + self.assertLen(executions, 1) + executed_op_types = [execution.op_type for execution in executions] + self.assertIn("iterative_doubling", executions[0].op_type) + execution = executions[0] + self.assertLen(execution.input_tensor_ids, 2) + self.assertLen(execution.output_tensor_ids, 1) self.assertEqual( - tensor_debug_modes[0], - debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode)) + debug_event_pb2.TensorDebugMode.keys()[execution.tensor_debug_mode], + tensor_debug_mode) if tensor_debug_mode == "FULL_TENSOR": - self.assertAllClose(tensor_values, [[8.0]]) + tensor_values = reader.execution_to_tensor_values(execution) + self.assertAllClose(tensor_values, [8.0]) - (op_names, _, output_slots, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - # The Less op should have been executed 5 times. - self.assertEqual(executed_op_types.count("Less"), 5) - # The last executed op should be Less. 
- self.assertEqual(executed_op_types[-1], "Less") + graph_exec_traces = reader.graph_execution_traces() + executed_op_types = [trace.op_type for trace in graph_exec_traces] + if tensor_debug_mode != "CURT_HEALTH": + # Less outputs a boolean tensor, which is not tracked under CURT_HEALTH. + # The Less op should have been executed 5 times. + self.assertEqual(executed_op_types.count("Less"), 5) + # The last executed op should be Less. + self.assertEqual(executed_op_types[-1], "Less") + # AddV2 produces an int tensor, which is not tracked under CURT_HEALTH. + # The AddV2 op should have been run, but we refrain from asserting on + # how many times it's executed. + self.assertIn("AddV2", executed_op_types) + for trace in graph_exec_traces: + self.assertEqual(trace.output_slot, 0) # The Mul op should have been executed 4 times. self.assertEqual(executed_op_types.count("Mul"), 4) - # The AddV2 op should have been run, but we refrain from asserting on how - # many times it's executed. - self.assertIn("AddV2", executed_op_types) - for output_slot in output_slots: - self.assertEqual(output_slot, 0) + + tensor_values = [reader.graph_execution_trace_to_tensor_value(trace) + for trace in graph_exec_traces] if tensor_debug_mode == "NO_TENSOR": # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought # to be an empty float32 tensor. for tensor_value in tensor_values: - self.assertEqual(tensor_value.dtype, np.float32) - self.assertEqual(tensor_value.shape, (0,)) - elif tensor_debug_mode == "CURT_TENSOR": - for tensor_value in tensor_values: - self.assertLen(tensor_value, 2) - # 1st element: tensor_id, should be >= 0. - # TODO(cais): Assert on detailed value once Function-graph association - # is in place. - self.assertGreaterEqual(tensor_value[0], 0) - # 2nd element: 0 means there is no inf or nan. - self.assertEqual(tensor_value[1], 0) + self.assertAllEqual(tensor_value, []) + elif tensor_debug_mode == "CURT_HEALTH": + for trace in graph_exec_traces: + tensor_id = reader.graph_execution_trace_to_tensor_id(trace) + # 1st element: tensor_id; 2nd element: 0 indicating no inf or nan. 
+ self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0.0]) elif tensor_debug_mode == "FULL_TENSOR": less_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Less" - ] - self.assertAllClose(less_values, [True, True, True, True, False]) + reader.graph_execution_trace_to_tensor_value(trace) + for trace in graph_exec_traces if trace.op_type == "Less"] + self.assertAllEqual(less_values, [True, True, True, True, False]) mul_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Mul" - ] + reader.graph_execution_trace_to_tensor_value(trace) + for trace in graph_exec_traces if trace.op_type == "Mul"] self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0]) def testCallingEnableTracingTwiceWithTheSameDumpRootIsIdempotent(self): @@ -772,17 +836,16 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - execution_iter = reader.execution_iterator() - for _ in range(2): - debug_event = next(execution_iter) - self.assertGreater(debug_event.wall_time, 0) - execution = debug_event.execution + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + executions = reader.executions() + self.assertLen(executions, 2) + for execution in executions: + self.assertGreater(execution.wall_time, 0) self.assertEqual(execution.op_type, "Unique") self.assertEqual(execution.num_outputs, 2) - self.assertTrue(execution.code_location) - with self.assertRaises(StopIteration): - next(execution_iter) + _, stack_frames = reader.read_execution_stack_trace(execution) + self._verifyStackFrames(stack_frames) def testCallingEnableTracingTwiceWithDifferentDumpRootsOverwrites(self): dumping_callback.enable_dump_debug_info(self.dump_root) @@ -796,27 +859,26 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - with debug_events_reader.DebugEventsReader(new_dump_root) as reader: - execution_iter = reader.execution_iterator() - for _ in range(2): - debug_event = next(execution_iter) - self.assertGreater(debug_event.wall_time, 0) - execution = debug_event.execution + with debug_events_reader.DebugDataReader(new_dump_root) as reader: + reader.update() + executions = reader.executions() + self.assertLen(executions, 2) + for execution in executions: + self.assertGreater(execution.wall_time, 0) self.assertEqual(execution.op_type, "Unique") self.assertEqual(execution.num_outputs, 2) - self.assertTrue(execution.code_location) - with self.assertRaises(StopIteration): - next(execution_iter) + _, stack_frames = reader.read_execution_stack_trace(execution) + self._verifyStackFrames(stack_frames) - with debug_events_reader.DebugEventsReader( - self.dump_root) as old_dump_root_reader: - execution_iter = old_dump_root_reader.execution_iterator() - # The old dump root shouldn't have been written to. - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugDataReader( + self.dump_root) as old_dump_root_reader: + old_dump_root_reader.update() + # The old dump root shouldn't have been written to. + self.assertEqual(old_dump_root_reader.num_executions(), 0) + self.assertFalse(old_dump_root_reader.outermost_graphs()) def testCallingEnableRepeatedlyWithDifferentTensorDebugMode(self): - """Assert that calling enable_dump_debug_info() with different tensor-debug modes. + """Assert calling enable_dump_debug_info() with two tensor-debug modes. 
It should lead to overwriting of the previously-configured mode. """ @@ -830,16 +892,16 @@ class TracingCallbackTest( self.assertAllClose(add_1_divide_by_2(constant_op.constant(4.0)), 2.5) writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - context_ids, _, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id) - _, _, _, _, _, tensor_values = self._readAndCheckExecutionFile() - self.assertEqual(tensor_values, [[]]) - (_, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - self.assertLen(tensor_values, 2) - for tensor_value in tensor_values: - self.assertEqual(tensor_value.dtype, np.float32) - self.assertEqual(tensor_value.shape, (0,)) + + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + graph_exec_digests = reader.graph_execution_traces(digest=True) + tensor_values = [reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests] + for tensor_value in tensor_values: + # Under NO_TENSOR mode, each tensor is summarized as an empty float32 + # array. + self.assertAllEqual(tensor_value, []) with self.assertRaisesRegexp( ValueError, r"already.*NO_TENSOR.*FULL_TENSOR.*not be honored"): @@ -862,17 +924,11 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - source_files_iter = reader.source_files_iterator() - stack_frames_iter = reader.stack_frames_iterator() - execution_iter = reader.execution_iterator() - # No source-file, stack-frame or execution data should have been dumped. - with self.assertRaises(StopIteration): - next(source_files_iter) - with self.assertRaises(StopIteration): - next(stack_frames_iter) - with self.assertRaises(StopIteration): - next(execution_iter) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + self.assertEqual(reader.num_executions(), 0) + self.assertEqual(reader.num_graph_execution_traces(), 0) + self.assertFalse(reader.outermost_graphs()) @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), @@ -908,73 +964,54 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - with debug_events_reader.DebugEventsReader(self.dump_root) as reader: - execution_iter = reader.execution_iterator() + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + exec_digests = reader.executions(digest=True) prev_wall_time = 1 - for debug_event in execution_iter: - self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) - prev_wall_time = debug_event.wall_time + for exec_digest in exec_digests: + self.assertGreaterEqual(exec_digest.wall_time, prev_wall_time) + prev_wall_time = exec_digest.wall_time - (context_ids, _, - op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) + graph_exec_traces = reader.graph_execution_traces() + executed_op_types = [trace.op_type for trace in graph_exec_traces] + self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads) + self.assertEqual( + executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads)) + for trace in graph_exec_traces: + # These are all single-output tensors. 
+        self.assertEqual(trace.output_slot, 0)
-    (op_names, _, output_slots,
-     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
-    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-    self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
-    self.assertEqual(
-        executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads))
-    for output_slot in output_slots:
-      self.assertEqual(output_slot, 0)
+
+      tensor_values = [reader.graph_execution_trace_to_tensor_value(trace)
+                       for trace in graph_exec_traces]
       if tensor_debug_mode == "NO_TENSOR":
         for tensor_value in tensor_values:
-          self.assertEqual(tensor_value.dtype, np.float32)
-          self.assertEqual(tensor_value.shape, (0,))
+          self.assertAllEqual(tensor_value, [])
       elif tensor_debug_mode == "CURT_HEALTH":
-        for tensor_value in tensor_values:
-          self.assertLen(tensor_value, 2)
-          # 1st element: tensor_id, should be >= 0.
-          # TODO(cais): Assert on detailed value once Function-graph association
-          # is in place.
-          self.assertGreaterEqual(tensor_value[0], 0)
-          # 2nd element: 0 means there is no inf or nan.
-          self.assertEqual(tensor_value[1], 0)
+        for trace in graph_exec_traces:
+          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+          # 1st element: tensor ID; 2nd element: 0 indicating no inf or nan.
+          self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
       elif tensor_debug_mode == "CONCISE_HEALTH":
-        for tensor_value in tensor_values:
-          self.assertLen(tensor_value, 5)
-          # 1st element: tensor_id, should be >= 0.
-          # TODO(cais): Assert on detailed value once Function-graph association
-          # is in place.
-          self.assertGreaterEqual(tensor_value[0], 0)
+        for trace in graph_exec_traces:
+          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+          # 1st element: tensor ID.
           # 2nd element: element count. Remaining elements: all zero because there
           # is no -inf, inf or nan.
-          self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
+          self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1, 0, 0, 0])
       elif tensor_debug_mode == "SHAPE":
-        mul_values = [
-            tensor_values[i]
-            for i, op_type in enumerate(executed_op_types)
-            if op_type == "Mul"
-        ]
-        for mul_value in mul_values:
-          # 1st element: tensor_id, should be >= 0.
-          # TODO(cais): Assert on detailed value once Function-graph association
-          # is in place.
-          self.assertGreaterEqual(mul_value[0], 0)
-          # 2nd element: dtype enum value (float32).
-          self.assertEqual(mul_value[1], 1)
-          # 3rd element: rank.
-          self.assertEqual(mul_value[2], 0)
-          # 3rd element: element count.
-          self.assertEqual(mul_value[3], 1)
-          # Remaining elements: shape padded to a fixed length.
-          self.assertAllClose(mul_value[4:], [0, 0, 0, 0, 0, 0])
+        for trace in graph_exec_traces:
+          if trace.op_type == "Mul":
+            tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+            mul_value = reader.graph_execution_trace_to_tensor_value(trace)
+            # 1st element: tensor_id, should be >= 0.
+            # 2nd element: dtype enum value (float32).
+            # 3rd element: rank.
+            # 4th element: element count.
+            self.assertAllClose(
+                mul_value, [tensor_id, 1, 0, 1, 0, 0, 0, 0, 0, 0])
       elif tensor_debug_mode == "FULL_TENSOR":
         mul_values = [
-            tensor_values[i]
-            for i, op_type in enumerate(executed_op_types)
-            if op_type == "Mul"
-        ]
+            reader.graph_execution_trace_to_tensor_value(trace)
+            for trace in graph_exec_traces if trace.op_type == "Mul"]
         self.assertAllClose(mul_values, [6.0, 6.0, 6.0, 6.0])
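Taken together, the per-mode assertions in this test double as documentation of the flat `debug_tensor_value` layout. The decoder below collects those layouts in one place; the helper function and its dictionary keys are descriptive assumptions made for illustration, with the field meanings read directly off the assertions above.

```python
def decode_debug_tensor_value(mode, values):
  """Sketch: interpret a flat debug_tensor_value list for one trace."""
  if mode == "CURT_HEALTH":
    # [tensor_id, any_inf_nan]; 0 in the 2nd slot means no inf or nan.
    return {"tensor_id": values[0], "any_inf_nan": values[1]}
  if mode == "CONCISE_HEALTH":
    # [tensor_id, element_count, neg_inf_count, pos_inf_count, nan_count].
    keys = ("tensor_id", "element_count",
            "neg_inf_count", "pos_inf_count", "nan_count")
    return dict(zip(keys, values))
  if mode == "SHAPE":
    # [tensor_id, dtype_enum, rank, element_count, shape padded to length 6].
    rank = int(values[2])
    return {"tensor_id": values[0], "dtype_enum": values[1], "rank": rank,
            "element_count": values[3], "shape": values[4:4 + rank]}
  raise ValueError("Unsupported tensor_debug_mode: %s" % mode)


# The 2x2 boolean tensor asserted near the top of this section (dtype enum
# 10 = bool) decodes as follows; the tensor ID 7 is a made-up example.
info = decode_debug_tensor_value("SHAPE", [7, 10, 2, 4, 2, 2, 0, 0, 0, 0])
assert info["shape"] == [2, 2] and info["element_count"] == 4
```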
 
   def testMultiThreadedDumpingWithDifferentSettings(self):
@@ -1017,23 +1054,28 @@ class TracingCallbackTest(
     self.assertAllClose(v1.read_value(), -67084290.0)
     self.assertAllClose(v2.read_value(), -6.0)
 
-    (executed_op_types, _, _, _, _,
-     tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_1)
-    v1_squared_values = [
-        tensor_values[i] for i, op_type in enumerate(executed_op_types)
-        if op_type == "Pow"]
-    negative_v1_squared_values = [
-        tensor_values[i] for i, op_type in enumerate(executed_op_types)
-        if op_type == "Neg"]
-    self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
-    self.assertAllClose(
-        negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
+    with debug_events_reader.DebugDataReader(dump_root_1) as reader:
+      reader.update()
+      exec_digests = reader.executions(digest=True)
+      v1_squared_values = [
+          reader.execution_to_tensor_values(digest)
+          for digest in exec_digests if digest.op_type == "Pow"]
+      negative_v1_squared_values = [
+          reader.execution_to_tensor_values(digest)
+          for digest in exec_digests if digest.op_type == "Neg"]
+      self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
+      self.assertAllClose(
+          negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
 
-    (executed_op_types, _, _, _, _,
-     tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_2)
-    self.assertNotIn("Neg", executed_op_types)
-    v2_squared_values = tensor_values[executed_op_types.index("Pow")]
-    self.assertAllClose(v2_squared_values, [9.0])
+    with debug_events_reader.DebugDataReader(dump_root_2) as reader:
+      reader.update()
+      exec_digests = reader.executions(digest=True)
+      executed_op_types = [digest.op_type for digest in exec_digests]
+      self.assertNotIn("Neg", executed_op_types)
+      v2_squared_values = [
+          reader.execution_to_tensor_values(digest)
+          for digest in exec_digests if digest.op_type == "Pow"]
+      self.assertAllClose(v2_squared_values, [[9.0]])
 
   @test_util.run_in_graph_and_eager_modes
   def testNestedContextIsCapturedByGraphOpCreationHistory(self):
@@ -1055,36 +1097,18 @@ class TracingCallbackTest(
     writer.FlushNonExecutionFiles()
     writer.FlushExecutionFiles()
-
-    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
-    (_, _, op_name_to_op_type,
-     op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
-
-    less_op_names = [op_name for op_name in op_name_to_op_type
-                     if op_name_to_op_type[op_name] == "Less"]
-    less_context_ids = [op_name_to_context_id[op_name]
-                        for op_name in less_op_names]
-    mul_op_names = [op_name for op_name in op_name_to_op_type
-                    if op_name_to_op_type[op_name] == "Mul"]
-    mul_context_ids = [op_name_to_context_id[op_name]
-                       for op_name in mul_op_names]
-    sub_op_names = [op_name for op_name in op_name_to_op_type
-                    if op_name_to_op_type[op_name] == "Sub"]
-    sub_context_ids = [op_name_to_context_id[op_name]
-                       for op_name in sub_op_names]
-    self.assertLen(less_context_ids, 1)
-    self.assertLen(mul_context_ids, 1)
-    self.assertLen(sub_context_ids, 1)
-    self.assertTrue(less_context_ids[0])
-    self.assertTrue(mul_context_ids[0])
-    self.assertTrue(sub_context_ids[0])
-    # The Less op is from
the while-loop cond context and hence should have - # a different innermost context ID from the mul and sub ops, which are both - # from the while-loop body context. - self.assertNotEqual(less_context_ids[0], mul_context_ids[0]) - self.assertNotEqual(less_context_ids[0], sub_context_ids[0]) - # The Mul and Sub ops are from the same innermost context. - self.assertEqual(mul_context_ids[0], sub_context_ids[0]) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + less_op_digest = reader.graph_op_digests(op_type="Less")[-1] + mul_op_digest = reader.graph_op_digests(op_type="Mul")[-1] + sub_op_digest = reader.graph_op_digests(op_type="Sub")[-1] + # The Less op is from the while-loop cond context and hence should have + # a different innermost context ID from the mul and sub ops, which are + # both from the while-loop body context. + self.assertNotEqual(less_op_digest.graph_id, mul_op_digest.graph_id) + self.assertNotEqual(less_op_digest.graph_id, sub_op_digest.graph_id) + # The Mul and Sub ops are from the same innermost context. + self.assertEqual(mul_op_digest.graph_id, sub_op_digest.graph_id) @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), @@ -1102,53 +1126,38 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, op_types, - op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) - # Simply assert that graph are recorded and refrain from asserting on the - # internal details of the Keras model. - self.assertTrue(context_ids) - self.assertTrue(op_types) - self.assertTrue(op_name_to_op_type) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + if context.executing_eagerly(): + # NOTE(b/142486213): Execution of the TF function happens with + # Session.run() in v1 graph mode, hence it doesn't get logged to the + # .execution file. + self.assertTrue(reader.executions(digest=True)) - if context.executing_eagerly(): - # NOTE(b/142486213): Execution of the TF function happens with - # Session.run() in v1 graph mode, hence it doesn't get logged to the - # .execution file. - (executed_op_types, _, _, _, _, - tensor_values) = self._readAndCheckExecutionFile() - self.assertTrue(executed_op_types) + graph_exec_digests = reader.graph_execution_traces(digest=True) + executed_op_types = [digest.op_type for digest in graph_exec_digests] + # These are the ops that we can safely assume to have been executed during + # the model prediction. + self.assertIn("MatMul", executed_op_types) + self.assertIn("BiasAdd", executed_op_types) + # On the GPU, CudnnRNN is used in lieu of the default op-by-op + # implementation. + self.assertTrue( + ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or + "CudnnRNN" in executed_op_types)) - for value_list in tensor_values: - if tensor_debug_mode == "NO_TENSOR": - self.assertFalse(value_list) - - (op_names, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - # These are the ops that we can safely assume to have been executed during - # the model prediction. - self.assertIn("MatMul", executed_op_types) - self.assertIn("BiasAdd", executed_op_types) - # On the GPU, CudnnRNN is used in lieu of the default op-by-op - # implementation. 
- self.assertTrue( - ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or - "CudnnRNN" in executed_op_types)) - # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to - # be an empty float32 tensor. - if tensor_debug_mode == "NO_TENSOR": - for tensor_value in tensor_values: - self.assertEqual(tensor_value.dtype, np.float32) - self.assertEqual(tensor_value.shape, (0,)) - else: - # Refrain from asserting the internal implementation details of the LSTM - # layer. - concrete_tensor_values = [ - value for value in tensor_values - if value is not None and value.size > 0 - ] - self.assertTrue(concrete_tensor_values) + # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to + # be an empty float32 tensor. + tensor_values = [reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests] + if tensor_debug_mode == "NO_TENSOR": + for tensor_value in tensor_values: + self.assertAllEqual(tensor_value, []) + else: + # Refrain from asserting the internal implementation details of the LSTM + # layer. + self.assertTrue(any( + bool(tensor_value.size) for tensor_value in tensor_values)) @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), @@ -1169,48 +1178,38 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, op_types, - op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) - # Simply assert that graph are recorded and refrain from asserting on the - # internal details of the Keras model. - self.assertTrue(context_ids) - self.assertTrue(op_types) - self.assertTrue(op_name_to_op_type) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + if context.executing_eagerly(): + exec_digests = reader.executions(digest=True) + self.assertTrue(exec_digests) + if tensor_debug_mode == "NO_TENSOR": + for digest in exec_digests: + tensor_values = reader.execution_to_tensor_values(digest) + for tensor_value in tensor_values: + self.assertEqual(tensor_value, []) - if context.executing_eagerly(): - # NOTE(b/142486213): Execution of the TF function happens with - # Session.run() in v1 graph mode, hence it doesn't get logged to the - # .execution file. - (executed_op_types, _, _, _, _, - tensor_values) = self._readAndCheckExecutionFile() - self.assertTrue(executed_op_types) + graph_exec_digests = reader.graph_execution_traces(digest=True) + executed_op_types = [digest.op_type for digest in graph_exec_digests] + # These are the ops that we can safely assume to have been executed during + # the recurrent model's fit() call. + self.assertIn("MatMul", executed_op_types) + self.assertIn("BiasAdd", executed_op_types) + + # On the GPU, CudnnRNN is used in lieu of the default op-by-op + # implementation. + self.assertTrue( + ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or + "CudnnRNN" in executed_op_types)) + self.assertTrue( + ("SigmoidGrad" in executed_op_types and + "TanhGrad" in executed_op_types or + "CudnnRNNBackprop" in executed_op_types)) if tensor_debug_mode == "NO_TENSOR": - for value_list in tensor_values: - self.assertFalse(value_list) - - (op_names, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - # These are the ops that we can safely assume to have been executed during - # the recurrent model's fit() call. 
- self.assertIn("MatMul", executed_op_types) - self.assertIn("BiasAdd", executed_op_types) - # On the GPU, CudnnRNN is used in lieu of the default op-by-op - # implementation. - self.assertTrue( - ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or - "CudnnRNN" in executed_op_types)) - self.assertTrue( - ("SigmoidGrad" in executed_op_types and - "TanhGrad" in executed_op_types or - "CudnnRNNBackprop" in executed_op_types)) - if tensor_debug_mode == "NO_TENSOR": - # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought - # to be an empty float32 tensor. - for tensor_value in tensor_values: - self.assertEqual(tensor_value.dtype, np.float32) - self.assertEqual(tensor_value.shape, (0,)) + for digest in graph_exec_digests: + tensor_values = reader.graph_execution_trace_to_tensor_value(digest) + for tensor_value in tensor_values: + self.assertEqual(tensor_value, []) @parameterized.named_parameters( ("NoTensor", "NO_TENSOR"), @@ -1242,72 +1241,60 @@ class TracingCallbackTest( writer.FlushNonExecutionFiles() writer.FlushExecutionFiles() - stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames() - (context_ids, op_types, - op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id) - # Simply assert that graph are recorded and refrain from asserting on the - # internal details of the Keras model. - self.assertTrue(context_ids) - self.assertTrue(op_types) - self.assertTrue(op_name_to_op_type) + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + if context.executing_eagerly(): + # NOTE(b/142486213): Execution of the TF function happens with + # Session.run() in v1 graph mode, hence it doesn't get logged to the + # .execution file. + exec_digests = reader.executions(digest=True) + self.assertTrue(exec_digests) - if context.executing_eagerly(): - # NOTE(b/142486213): Execution of the TF function happens with - # Session.run() in v1 graph mode, hence it doesn't get logged to the - # .execution file. - executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile() - self.assertTrue(executed_op_types) + graph_exec_digests = reader.graph_execution_traces() + executed_op_types = [digest.op_type for digest in graph_exec_digests] + # These are the ops that we can safely assume to have been executed during + # the model's fit() call. + self.assertIn("Conv2D", executed_op_types) + self.assertIn("Relu6", executed_op_types) + self.assertIn("Conv2DBackpropFilter", executed_op_types) + self.assertIn("Relu6Grad", executed_op_types) - (op_names, _, _, - tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids) - executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names] - # These are the ops that we can safely assume to have been executed during - # the model's fit() call. - self.assertIn("Conv2D", executed_op_types) - self.assertIn("Relu6", executed_op_types) - self.assertIn("Conv2DBackpropFilter", executed_op_types) - self.assertIn("Relu6Grad", executed_op_types) - if tensor_debug_mode == "NO_TENSOR": - # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to - # be an empty float32 tensor. 
- for tensor_value in tensor_values: - self.assertEqual(tensor_value.dtype, np.float32) - self.assertEqual(tensor_value.shape, (0,)) - elif tensor_debug_mode == "FULL_TENSOR": - conv2d_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Conv2D" - ] - self.assertTrue(conv2d_values) - for conv2d_value in conv2d_values: - self.assertGreater(len(conv2d_value.shape), 1) - self.assertEqual(conv2d_value.shape[0], batch_size) - relu6_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Relu6" - ] - self.assertTrue(relu6_values) - for relu6_value in relu6_values: - self.assertGreater(len(relu6_value.shape), 1) - self.assertEqual(relu6_value.shape[0], batch_size) - conv2d_bp_filter_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Conv2DBackpropFilter" - ] - self.assertTrue(conv2d_bp_filter_values) - for conv2d_bp_filter_value in conv2d_bp_filter_values: - self.assertGreater(len(conv2d_bp_filter_value.shape), 1) - relu6_grad_values = [ - tensor_values[i] - for i, op_type in enumerate(executed_op_types) - if op_type == "Relu6Grad" - ] - self.assertTrue(relu6_grad_values) - for relu6_grad_value in relu6_grad_values: - self.assertGreater(len(relu6_grad_value.shape), 1) + if tensor_debug_mode == "NO_TENSOR": + # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought + # to be an empty float32 tensor. + tensor_values = [ + reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests] + for tensor_value in tensor_values: + self.assertAllEqual(tensor_value, []) + elif tensor_debug_mode == "FULL_TENSOR": + conv2d_values = [ + reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests if digest.op_type == "Conv2D"] + self.assertTrue(conv2d_values) + for conv2d_value in conv2d_values: + self.assertGreater(len(conv2d_value.shape), 1) + self.assertEqual(conv2d_value.shape[0], batch_size) + relu6_values = [ + reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests if digest.op_type == "Relu6"] + self.assertTrue(relu6_values) + for relu6_value in relu6_values: + self.assertGreater(len(relu6_value.shape), 1) + self.assertEqual(relu6_value.shape[0], batch_size) + conv2d_bp_filter_values = [ + reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests + if digest.op_type == "Conv2DBackpropFilter"] + self.assertTrue(conv2d_bp_filter_values) + for conv2d_bp_filter_value in conv2d_bp_filter_values: + self.assertGreater(len(conv2d_bp_filter_value.shape), 1) + relu6_grad_values = [ + reader.graph_execution_trace_to_tensor_value(digest) + for digest in graph_exec_digests if digest.op_type == "Relu6Grad"] + self.assertTrue(relu6_grad_values) + for relu6_grad_value in relu6_grad_values: + self.assertGreater(len(relu6_grad_value.shape), 1) if __name__ == "__main__": diff --git a/tensorflow/python/debug/lib/dumping_callback_test_lib.py b/tensorflow/python/debug/lib/dumping_callback_test_lib.py index 6144f2ba9cc..1d449f61e0b 100644 --- a/tensorflow/python/debug/lib/dumping_callback_test_lib.py +++ b/tensorflow/python/debug/lib/dumping_callback_test_lib.py @@ -52,7 +52,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): """Read and check the .metadata debug-events file.""" with debug_events_reader.DebugEventsReader(self.dump_root) as reader: metadata_iter = reader.metadata_iterator() - metadata = next(metadata_iter).debug_metadata + metadata = 
next(metadata_iter).debug_event.debug_metadata self.assertEqual(metadata.tensorflow_version, versions.__version__) self.assertTrue(metadata.file_version.startswith("debug.Event")) @@ -67,7 +67,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): source_files_iter = reader.source_files_iterator() source_file_paths = [] prev_wall_time = 1 - for debug_event in source_files_iter: + for debug_event, _ in source_files_iter: self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) prev_wall_time = debug_event.wall_time source_file = debug_event.source_file @@ -84,7 +84,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): stack_frame_by_id = collections.OrderedDict() stack_frames_iter = reader.stack_frames_iterator() prev_wall_time = 0 - for debug_event in stack_frames_iter: + for debug_event, _ in stack_frames_iter: self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) prev_wall_time = debug_event.wall_time stack_frame_with_id = debug_event.stack_frame_with_id @@ -133,7 +133,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): # outermost contexts). context_id_to_outer_id = dict() - for debug_event in graphs_iter: + for debug_event, _ in graphs_iter: self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) prev_wall_time = debug_event.wall_time # A DebugEvent in the .graphs file contains either of the two fields: @@ -219,7 +219,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): output_tensor_ids = [] tensor_debug_modes = [] tensor_values = [] - for debug_event in execution_iter: + for debug_event, _ in execution_iter: self.assertGreaterEqual(debug_event.wall_time, prev_wall_time) prev_wall_time = debug_event.wall_time execution = debug_event.execution @@ -260,7 +260,7 @@ class DumpingCallbackTestBase(test_util.TensorFlowTestCase): device_names = [] output_slots = [] tensor_values = [] - for debug_event in graph_execution_traces_iter: + for debug_event, _ in graph_execution_traces_iter: self.assertGreaterEqual(debug_event.wall_time, 0) graph_execution_trace = debug_event.graph_execution_trace op_names.append(graph_execution_trace.op_name)
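These call-site updates follow from the new iterator contract: each `next()` now yields a `DebugEventWithOffset` named tuple rather than a bare `DebugEvent` proto, so callers can keep the offset for later random-access reads. Below is a minimal usage sketch against the reader APIs added in this patch; the dump-root path is a placeholder and the snippet assumes debug events were already written there.

```python
from tensorflow.python.debug.lib import debug_events_reader

# Placeholder path; assumes a tfdbg v2 dump directory already exists here.
dump_root = "/tmp/tfdbg2_dump"

with debug_events_reader.DebugEventsReader(dump_root) as reader:
  # Incremental pass: the iterator remembers its last position and yields
  # (debug_event, offset) named tuples.
  offsets = []
  for debug_event, offset in reader.graphs_iterator():
    offsets.append(offset)

  # Random-access pass: re-read any record later via its saved offset.
  if offsets:
    debug_event = reader.read_graphs_event(offsets[0])
    # Events in the .graphs file carry either a graph_op_creation or a
    # debugged_graph submessage.
    print(debug_event.wall_time)
```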