[tfdbg] Add device_name to GraphExecutionTrace proto; Refactor test.
- The device_name information will be important for debugging distributed execution (e.g., MirroredStrategy on multiple devices) - Refactor out a base class for the unit test for the dumping callback. - The base class will be used in the to-be-added tests for tfdbg v2 + DistributionStrategy. PiperOrigin-RevId: 276332541 Change-Id: Ifa47b3c86fb2706624fecdcc4eac027ce6c9172a
This commit is contained in:
parent
b07081a4df
commit
2e84f72d2d
@ -395,7 +395,10 @@ class DebugNumericSummaryOp : public BaseDebugOp {
|
||||
class DebugIdentityV2Op : public OpKernel {
|
||||
public:
|
||||
explicit DebugIdentityV2Op(OpKernelConstruction* context)
|
||||
: OpKernel(context), output_slot_(-1), tensor_debug_mode_(0) {
|
||||
: OpKernel(context),
|
||||
device_name_(context->device()->name()),
|
||||
output_slot_(-1),
|
||||
tensor_debug_mode_(0) {
|
||||
std::vector<string> debug_urls;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls));
|
||||
for (const string& debug_url : debug_urls) {
|
||||
@ -420,9 +423,9 @@ class DebugIdentityV2Op : public OpKernel {
|
||||
for (const string& dump_root : dump_roots_) {
|
||||
tfdbg::DebugEventsWriter* debug_events_writer =
|
||||
tfdbg::DebugEventsWriter::GetDebugEventsWriter(dump_root);
|
||||
debug_events_writer->WriteGraphExecutionTrace(tfdbg_context_id_, op_name_,
|
||||
output_slot_,
|
||||
tensor_debug_mode_, tensor);
|
||||
debug_events_writer->WriteGraphExecutionTrace(
|
||||
tfdbg_context_id_, device_name_, op_name_, output_slot_,
|
||||
tensor_debug_mode_, tensor);
|
||||
}
|
||||
context->set_output(0, tensor);
|
||||
}
|
||||
@ -430,6 +433,7 @@ class DebugIdentityV2Op : public OpKernel {
|
||||
private:
|
||||
std::vector<string> dump_roots_;
|
||||
string tfdbg_context_id_;
|
||||
string device_name_;
|
||||
string op_name_;
|
||||
int32 output_slot_;
|
||||
int32 tensor_debug_mode_;
|
||||
|
@ -252,4 +252,7 @@ message GraphExecutionTrace {
|
||||
// This tensor may summarize the value of a single intermediate op of the
|
||||
// graph, or those of multiple intermediate tensors.
|
||||
TensorProto tensor_proto = 5;
|
||||
|
||||
// Name of the device that the op belongs to.
|
||||
string device_name = 6;
|
||||
}
|
||||
|
@ -261,6 +261,7 @@ void DebugEventsWriter::WriteGraphExecutionTrace(
|
||||
}
|
||||
|
||||
void DebugEventsWriter::WriteGraphExecutionTrace(const string& tfdbg_context_id,
|
||||
const string& device_name,
|
||||
const string& op_name,
|
||||
int32 output_slot,
|
||||
int32 tensor_debug_mode,
|
||||
@ -276,6 +277,7 @@ void DebugEventsWriter::WriteGraphExecutionTrace(const string& tfdbg_context_id,
|
||||
if (tensor_debug_mode > 0) {
|
||||
trace->set_tensor_debug_mode(TensorDebugMode(tensor_debug_mode));
|
||||
}
|
||||
trace->set_device_name(device_name);
|
||||
tensor_value.AsProtoTensorContent(trace->mutable_tensor_proto());
|
||||
WriteGraphExecutionTrace(trace.release());
|
||||
}
|
||||
|
@ -155,6 +155,7 @@ class DebugEventsWriter {
|
||||
// that this trace is concerned with. The sematics of this tensor value
|
||||
// depends on the value of `tensor_debug_mode`.
|
||||
void WriteGraphExecutionTrace(const string& tfdbg_context_id,
|
||||
const string& device_name,
|
||||
const string& op_name, int32 output_slot,
|
||||
int32 tensor_debug_mode,
|
||||
const Tensor& tensor_value);
|
||||
|
@ -32,6 +32,7 @@ py_library(
|
||||
":debug_graphs",
|
||||
":debug_utils",
|
||||
":dumping_callback",
|
||||
":dumping_callback_test_lib",
|
||||
":grpc_debug_server",
|
||||
":grpc_debug_test_server",
|
||||
":hooks",
|
||||
@ -93,6 +94,18 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dumping_callback_test_lib",
|
||||
srcs = ["lib/dumping_callback_test_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":check_numerics_callback",
|
||||
":debug_events_reader",
|
||||
":dumping_callback",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "common",
|
||||
srcs = ["lib/common.py"],
|
||||
@ -697,6 +710,7 @@ cuda_py_test(
|
||||
":debug_events_reader",
|
||||
":debug_events_writer",
|
||||
":dumping_callback",
|
||||
":dumping_callback_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
@ -31,6 +30,7 @@ import numpy as np
|
||||
from tensorflow.core.protobuf import debug_event_pb2
|
||||
from tensorflow.python.debug.lib import debug_events_reader
|
||||
from tensorflow.python.debug.lib import dumping_callback
|
||||
from tensorflow.python.debug.lib import dumping_callback_test_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -38,7 +38,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.keras import models
|
||||
from tensorflow.python.keras.applications import mobilenet_v2
|
||||
from tensorflow.python.keras.layers import core
|
||||
@ -62,7 +61,8 @@ def _create_simple_recurrent_keras_model(input_shape):
|
||||
return model
|
||||
|
||||
|
||||
class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
class TracingCallbackTest(
|
||||
dumping_callback_test_lib.DumpingCallbackTestBase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TracingCallbackTest, self).setUp()
|
||||
@ -74,175 +74,6 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
dumping_callback.disable_dumping()
|
||||
super(TracingCallbackTest, self).tearDown()
|
||||
|
||||
def _readAndCheckMetadataFile(self):
|
||||
"""Read and check the .metadata debug-events file."""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
metadata_iter = reader.metadata_iterator()
|
||||
metadata = next(metadata_iter).debug_metadata
|
||||
self.assertEqual(metadata.tensorflow_version, versions.__version__)
|
||||
self.assertTrue(metadata.file_version.startswith("debug.Event"))
|
||||
|
||||
def _readAndCheckSourceFilesAndStackFrames(self):
|
||||
"""Read and verify the .source_files & .stack_frames debug-event files.
|
||||
|
||||
Returns:
|
||||
A dict mapping stack frame IDs to stack frames (FileLineCol).
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
# Check the content of the .source_files file.
|
||||
source_files_iter = reader.source_files_iterator()
|
||||
source_file_paths = []
|
||||
prev_wall_time = 1
|
||||
for debug_event in source_files_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
source_file = debug_event.source_file
|
||||
self.assertEqual(source_file.host_name, socket.gethostname())
|
||||
self.assertTrue(source_file.file_path)
|
||||
if source_file.lines:
|
||||
self.assertTrue(os.path.isfile(source_file.file_path))
|
||||
source_file_paths.append(source_file.file_path)
|
||||
# Assert the file paths are unique.
|
||||
self.assertEqual(len(source_file_paths), len(set(source_file_paths)))
|
||||
|
||||
# Check the content of the .stack_frames file.
|
||||
stack_frame_by_id = dict() # A map from ID to stack frame.
|
||||
stack_frames_iter = reader.stack_frames_iterator()
|
||||
prev_wall_time = 0
|
||||
for debug_event in stack_frames_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
stack_frame_with_id = debug_event.stack_frame_with_id
|
||||
stack_frame_id = stack_frame_with_id.id
|
||||
file_line_col = stack_frame_with_id.file_line_col
|
||||
self.assertTrue(stack_frame_id)
|
||||
self.assertNotIn(stack_frame_id, stack_frame_by_id,
|
||||
"Duplicate stack frame ID: %s" % id)
|
||||
stack_frame_by_id[stack_frame_id] = (file_line_col.file_index,
|
||||
file_line_col.line,
|
||||
file_line_col.func)
|
||||
self.assertGreaterEqual(file_line_col.file_index, 0)
|
||||
self.assertLess(file_line_col.file_index, len(source_file_paths))
|
||||
self.assertTrue(file_line_col.line) # Line numbers are 1-based.
|
||||
self.assertTrue(file_line_col.func)
|
||||
# Assert the stack frames are unique.
|
||||
self.assertEqual(
|
||||
len(stack_frame_by_id.values()), len(set(stack_frame_by_id.values())))
|
||||
return stack_frame_by_id
|
||||
|
||||
def _readAndCheckGraphsFile(self, stack_frame_by_id):
|
||||
"""Read and verify the content of the .graphs debug-event file.
|
||||
|
||||
Args:
|
||||
stack_frame_by_id: A dict mapping unique string IDs to stack frames.
|
||||
It is used by this method to look up stack frames.
|
||||
|
||||
Returns:
|
||||
context_ids: IDs of op creation contexts (e.g., TensorFlow graphs), as a
|
||||
`list` of `str`s.
|
||||
op_types: Types of the ops that are created, as a `list` of `str`s with
|
||||
the same length as `context_ids`.
|
||||
op_name_to_op_type: A `dict` mapping op name to op type.
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
graphs_iter = reader.graphs_iterator()
|
||||
prev_wall_time = 0
|
||||
op_types = []
|
||||
op_name_to_op_type = dict()
|
||||
context_ids = set()
|
||||
for debug_event in graphs_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
graph_op_creation = debug_event.graph_op_creation
|
||||
self.assertTrue(graph_op_creation.op_type)
|
||||
op_types.append(graph_op_creation.op_type)
|
||||
self.assertTrue(graph_op_creation.op_name)
|
||||
op_name_to_op_type[graph_op_creation.op_name] = graph_op_creation.op_type
|
||||
self.assertTrue(graph_op_creation.graph_id)
|
||||
context_ids.add(graph_op_creation.graph_id)
|
||||
self.assertTrue(graph_op_creation.code_location)
|
||||
for stack_frame_id in graph_op_creation.code_location.stack_frame_ids:
|
||||
self.assertIn(stack_frame_id, stack_frame_by_id)
|
||||
return context_ids, op_types, op_name_to_op_type
|
||||
|
||||
def _readAndCheckExecutionFile(self):
|
||||
"""Read and verify the content of the .execution debug-event file.
|
||||
|
||||
Returns:
|
||||
executed_op_types: Types of ops that are created, as a `list` of `str`.
|
||||
input_tensor_ids: Input tensor IDs for each of the ops executed, as a
|
||||
`list` of `list` of `int`s, with the same length as `executed_op_types`.
|
||||
output_tensor_ids: Output tensor IDs for each of the ops executed, as a
|
||||
`list` of `list` of `int`s, with the same length as `executed_op_types`.
|
||||
tensor_debug_modes: Tensor debug modes used to instrument each of ops
|
||||
executed.
|
||||
tensor_values: A `list` of `list` of `np.ndarray`s, representing the
|
||||
tensor values. Each item of the outer `list` corresponds to one
|
||||
execution event. Each item of the inner `list` corresponds to one
|
||||
output tensor slot of the executed op or Function.
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
execution_iter = reader.execution_iterator()
|
||||
prev_wall_time = 1
|
||||
executed_op_types = []
|
||||
input_tensor_ids = []
|
||||
output_tensor_ids = []
|
||||
tensor_debug_modes = []
|
||||
tensor_values = []
|
||||
for debug_event in execution_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
execution = debug_event.execution
|
||||
executed_op_types.append(execution.op_type)
|
||||
input_tensor_ids.append(execution.input_tensor_ids)
|
||||
output_tensor_ids.append(execution.output_tensor_ids)
|
||||
tensor_debug_modes.append(execution.tensor_debug_mode)
|
||||
tensor_values.append([
|
||||
tensor_util.MakeNdarray(tensor_proto)
|
||||
for tensor_proto in execution.tensor_protos
|
||||
])
|
||||
|
||||
# TODO(cais): When tensor debug modes other than NO_TENSOR is supported,
|
||||
# return tensor_values as well.
|
||||
return (executed_op_types, input_tensor_ids, output_tensor_ids,
|
||||
tensor_debug_modes, tensor_values)
|
||||
|
||||
def _readAndCheckGraphExecutionTracesFile(self, context_ids):
|
||||
"""Read & verify the content of the .graph_execution_trace debug-event file.
|
||||
|
||||
Args:
|
||||
context_ids: Op-creation context IDs from _readAndCheckGraphsFile().
|
||||
|
||||
Returns:
|
||||
op_names: Names of the ops that are executed, as a `list` of `str`s.
|
||||
output_slots: Output slots, as a `list` of `int`s, with the same length as
|
||||
`op_names`. In other words, for an executed op with N output tensors,
|
||||
there will be N entries in this `list` and in `op_names`, at
|
||||
corresponding indices.
|
||||
tensor_values: Tensor values or their concise summaries, depending on
|
||||
TensorDebugMode.
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
graph_execution_traces_iter = reader.graph_execution_traces_iterator()
|
||||
op_names = []
|
||||
output_slots = []
|
||||
tensor_values = []
|
||||
for debug_event in graph_execution_traces_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, 0)
|
||||
graph_execution_trace = debug_event.graph_execution_trace
|
||||
op_names.append(graph_execution_trace.op_name)
|
||||
# All the ops in the graph have only one output.
|
||||
self.assertTrue(graph_execution_trace.tfdbg_context_id)
|
||||
self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids)
|
||||
output_slots.append(graph_execution_trace.output_slot)
|
||||
try:
|
||||
tensor_values.append(
|
||||
tensor_util.MakeNdarray(graph_execution_trace.tensor_proto))
|
||||
except KeyError:
|
||||
# Certain dtypes are not convertible to numpy arrays.
|
||||
tensor_values.append(None)
|
||||
return op_names, output_slots, tensor_values
|
||||
|
||||
def testInvalidTensorDebugModeCausesError(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
@ -398,7 +229,7 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertIn("Log", op_types)
|
||||
self.assertIn("Sin", op_types)
|
||||
|
||||
(op_names, _,
|
||||
(op_names, _, _,
|
||||
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
|
||||
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
|
||||
self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
|
||||
@ -476,7 +307,7 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
if tensor_debug_mode == "FULL_TENSOR":
|
||||
self.assertAllClose(tensor_values, [[8.0]])
|
||||
|
||||
(op_names, output_slots,
|
||||
(op_names, _, output_slots,
|
||||
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
|
||||
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
|
||||
# The Less op should have been executed 5 times.
|
||||
@ -582,7 +413,7 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
context_ids, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id)
|
||||
_, _, _, _, tensor_values = self._readAndCheckExecutionFile()
|
||||
self.assertEqual(tensor_values, [[]])
|
||||
(_, _,
|
||||
(_, _, _,
|
||||
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
|
||||
self.assertLen(tensor_values, 2)
|
||||
for tensor_value in tensor_values:
|
||||
@ -663,7 +494,7 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
(context_ids, _,
|
||||
op_name_to_op_type) = self._readAndCheckGraphsFile(stack_frame_by_id)
|
||||
|
||||
(op_names, output_slots,
|
||||
(op_names, _, output_slots,
|
||||
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
|
||||
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
|
||||
self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
|
||||
@ -720,7 +551,7 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
if tensor_debug_mode == "NO_TENSOR":
|
||||
self.assertFalse(value_list)
|
||||
|
||||
(op_names, _,
|
||||
(op_names, _, _,
|
||||
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
|
||||
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
|
||||
# These are the ops that we can safely assume to have been executed during
|
||||
@ -786,7 +617,7 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
for value_list in tensor_values:
|
||||
self.assertFalse(value_list)
|
||||
|
||||
(op_names, _,
|
||||
(op_names, _, _,
|
||||
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
|
||||
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
|
||||
# These are the ops that we can safely assume to have been executed during
|
||||
@ -855,7 +686,7 @@ class TracingCallbackTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
executed_op_types, _, _, _, _ = self._readAndCheckExecutionFile()
|
||||
self.assertTrue(executed_op_types)
|
||||
|
||||
(op_names, _,
|
||||
(op_names, _, _,
|
||||
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
|
||||
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
|
||||
# These are the ops that we can safely assume to have been executed during
|
||||
|
221
tensorflow/python/debug/lib/dumping_callback_test_lib.py
Normal file
221
tensorflow/python/debug/lib/dumping_callback_test_lib.py
Normal file
@ -0,0 +1,221 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Shared library for testing tfdbg v2 dumping callback."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import tempfile
|
||||
|
||||
from tensorflow.python.debug.lib import check_numerics_callback
|
||||
from tensorflow.python.debug.lib import debug_events_reader
|
||||
from tensorflow.python.debug.lib import dumping_callback
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import versions
|
||||
|
||||
|
||||
class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
|
||||
"""Base test-case class for tfdbg v2 callbacks."""
|
||||
|
||||
def setUp(self):
|
||||
super(DumpingCallbackTestBase, self).setUp()
|
||||
self.dump_root = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.isdir(self.dump_root):
|
||||
shutil.rmtree(self.dump_root, ignore_errors=True)
|
||||
check_numerics_callback.disable_check_numerics()
|
||||
dumping_callback.disable_dumping()
|
||||
super(DumpingCallbackTestBase, self).tearDown()
|
||||
|
||||
def _readAndCheckMetadataFile(self):
|
||||
"""Read and check the .metadata debug-events file."""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
metadata_iter = reader.metadata_iterator()
|
||||
metadata = next(metadata_iter).debug_metadata
|
||||
self.assertEqual(metadata.tensorflow_version, versions.__version__)
|
||||
self.assertTrue(metadata.file_version.startswith("debug.Event"))
|
||||
|
||||
def _readAndCheckSourceFilesAndStackFrames(self):
|
||||
"""Read and verify the .source_files & .stack_frames debug-event files.
|
||||
|
||||
Returns:
|
||||
A dict mapping stack frame IDs to stack frames (FileLineCol).
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
# Check the content of the .source_files file.
|
||||
source_files_iter = reader.source_files_iterator()
|
||||
source_file_paths = []
|
||||
prev_wall_time = 1
|
||||
for debug_event in source_files_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
source_file = debug_event.source_file
|
||||
self.assertEqual(source_file.host_name, socket.gethostname())
|
||||
self.assertTrue(source_file.file_path)
|
||||
if source_file.lines:
|
||||
self.assertTrue(os.path.isfile(source_file.file_path))
|
||||
source_file_paths.append(source_file.file_path)
|
||||
# Assert the file paths are unique.
|
||||
self.assertEqual(len(source_file_paths), len(set(source_file_paths)))
|
||||
|
||||
# Check the content of the .stack_frames file.
|
||||
stack_frame_by_id = dict() # A map from ID to stack frame.
|
||||
stack_frames_iter = reader.stack_frames_iterator()
|
||||
prev_wall_time = 0
|
||||
for debug_event in stack_frames_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
stack_frame_with_id = debug_event.stack_frame_with_id
|
||||
stack_frame_id = stack_frame_with_id.id
|
||||
file_line_col = stack_frame_with_id.file_line_col
|
||||
self.assertTrue(stack_frame_id)
|
||||
self.assertNotIn(stack_frame_id, stack_frame_by_id,
|
||||
"Duplicate stack frame ID: %s" % id)
|
||||
stack_frame_by_id[stack_frame_id] = (file_line_col.file_index,
|
||||
file_line_col.line,
|
||||
file_line_col.func)
|
||||
self.assertGreaterEqual(file_line_col.file_index, 0)
|
||||
self.assertLess(file_line_col.file_index, len(source_file_paths))
|
||||
self.assertTrue(file_line_col.line) # Line numbers are 1-based.
|
||||
self.assertTrue(file_line_col.func)
|
||||
# Assert the stack frames are unique.
|
||||
self.assertEqual(
|
||||
len(stack_frame_by_id.values()), len(set(stack_frame_by_id.values())))
|
||||
return stack_frame_by_id
|
||||
|
||||
def _readAndCheckGraphsFile(self, stack_frame_by_id):
|
||||
"""Read and verify the content of the .graphs debug-event file.
|
||||
|
||||
Args:
|
||||
stack_frame_by_id: A dict mapping unique string IDs to stack frames.
|
||||
It is used by this method to look up stack frames.
|
||||
|
||||
Returns:
|
||||
context_ids: IDs of op creation contexts (e.g., TensorFlow graphs), as a
|
||||
`list` of `str`s.
|
||||
op_types: Types of the ops that are created, as a `list` of `str`s with
|
||||
the same length as `context_ids`.
|
||||
op_name_to_op_type: A `dict` mapping op name to op type.
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
graphs_iter = reader.graphs_iterator()
|
||||
prev_wall_time = 0
|
||||
op_types = []
|
||||
op_name_to_op_type = dict()
|
||||
context_ids = set()
|
||||
for debug_event in graphs_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
graph_op_creation = debug_event.graph_op_creation
|
||||
self.assertTrue(graph_op_creation.op_type)
|
||||
op_types.append(graph_op_creation.op_type)
|
||||
self.assertTrue(graph_op_creation.op_name)
|
||||
op_name_to_op_type[graph_op_creation.op_name] = graph_op_creation.op_type
|
||||
self.assertTrue(graph_op_creation.graph_id)
|
||||
context_ids.add(graph_op_creation.graph_id)
|
||||
self.assertTrue(graph_op_creation.code_location)
|
||||
for stack_frame_id in graph_op_creation.code_location.stack_frame_ids:
|
||||
self.assertIn(stack_frame_id, stack_frame_by_id)
|
||||
return context_ids, op_types, op_name_to_op_type
|
||||
|
||||
def _readAndCheckExecutionFile(self):
|
||||
"""Read and verify the content of the .execution debug-event file.
|
||||
|
||||
Returns:
|
||||
executed_op_types: Types of ops that are created, as a `list` of `str`.
|
||||
input_tensor_ids: Input tensor IDs for each of the ops executed, as a
|
||||
`list` of `list` of `int`s, with the same length as `executed_op_types`.
|
||||
output_tensor_ids: Output tensor IDs for each of the ops executed, as a
|
||||
`list` of `list` of `int`s, with the same length as `executed_op_types`.
|
||||
tensor_debug_modes: Tensor debug modes used to instrument each of ops
|
||||
executed.
|
||||
tensor_values: A `list` of `list` of `np.ndarray`s, representing the
|
||||
tensor values. Each item of the outer `list` corresponds to one
|
||||
execution event. Each item of the inner `list` corresponds to one
|
||||
output tensor slot of the executed op or Function.
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
execution_iter = reader.execution_iterator()
|
||||
prev_wall_time = 1
|
||||
executed_op_types = []
|
||||
input_tensor_ids = []
|
||||
output_tensor_ids = []
|
||||
tensor_debug_modes = []
|
||||
tensor_values = []
|
||||
for debug_event in execution_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
|
||||
prev_wall_time = debug_event.wall_time
|
||||
execution = debug_event.execution
|
||||
executed_op_types.append(execution.op_type)
|
||||
input_tensor_ids.append(execution.input_tensor_ids)
|
||||
output_tensor_ids.append(execution.output_tensor_ids)
|
||||
tensor_debug_modes.append(execution.tensor_debug_mode)
|
||||
tensor_values.append([
|
||||
tensor_util.MakeNdarray(tensor_proto)
|
||||
for tensor_proto in execution.tensor_protos
|
||||
])
|
||||
|
||||
# TODO(cais): When tensor debug modes other than NO_TENSOR is supported,
|
||||
# return tensor_values as well.
|
||||
return (executed_op_types, input_tensor_ids, output_tensor_ids,
|
||||
tensor_debug_modes, tensor_values)
|
||||
|
||||
def _readAndCheckGraphExecutionTracesFile(self, context_ids):
|
||||
"""Read & verify the content of the .graph_execution_trace debug-event file.
|
||||
|
||||
Args:
|
||||
context_ids: Op-creation context IDs from _readAndCheckGraphsFile().
|
||||
|
||||
Returns:
|
||||
op_names: Names of the ops that are executed, as a `list` of `str`s.
|
||||
device_names: Names of the devices that the ops belong to, respectively.
|
||||
A `list` of `str`s of the same length as `op_name`s.
|
||||
output_slots: Output slots, as a `list` of `int`s, with the same length as
|
||||
`op_names`. In other words, for an executed op with N output tensors,
|
||||
there will be N entries in this `list` and in `op_names`, at
|
||||
corresponding indices.
|
||||
tensor_values: Tensor values or their concise summaries, depending on
|
||||
TensorDebugMode.
|
||||
"""
|
||||
reader = debug_events_reader.DebugEventsDir(self.dump_root)
|
||||
graph_execution_traces_iter = reader.graph_execution_traces_iterator()
|
||||
op_names = []
|
||||
device_names = []
|
||||
output_slots = []
|
||||
tensor_values = []
|
||||
for debug_event in graph_execution_traces_iter:
|
||||
self.assertGreaterEqual(debug_event.wall_time, 0)
|
||||
graph_execution_trace = debug_event.graph_execution_trace
|
||||
op_names.append(graph_execution_trace.op_name)
|
||||
self.assertTrue(graph_execution_trace.device_name)
|
||||
device_names.append(graph_execution_trace.device_name)
|
||||
# All the ops in the graph have only one output.
|
||||
self.assertTrue(graph_execution_trace.tfdbg_context_id)
|
||||
self.assertIn(graph_execution_trace.tfdbg_context_id, context_ids)
|
||||
output_slots.append(graph_execution_trace.output_slot)
|
||||
dtype = dtypes.DType(graph_execution_trace.tensor_proto.dtype)
|
||||
if dtype.is_numpy_compatible: # pylint:disable=protected-access
|
||||
tensor_values.append(
|
||||
tensor_util.MakeNdarray(graph_execution_trace.tensor_proto))
|
||||
else:
|
||||
tensor_values.append(None)
|
||||
return op_names, device_names, output_slots, tensor_values
|
Loading…
Reference in New Issue
Block a user