Automated rollback of change 154220704
Change: 154225030
parent c19dd491e5
commit 0a583aae32
tensorflow/python/debug/BUILD
@@ -147,22 +147,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "profile_analyzer_cli",
-    srcs = ["cli/profile_analyzer_cli.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":cli_shared",
-        ":command_parser",
-        ":debug_data",
-        ":debugger_cli_common",
-        ":source_utils",
-        ":ui_factory",
-        "//third_party/py/numpy",
-        "@six_archive//:six",
-    ],
-)
-
 py_library(
     name = "stepper_cli",
     srcs = ["cli/stepper_cli.py"],
@@ -257,7 +241,6 @@ py_library(
         ":debug_data",
         ":debugger_cli_common",
         ":framework",
-        ":profile_analyzer_cli",
         ":stepper_cli",
         ":ui_factory",
     ],
@@ -623,24 +606,6 @@ cuda_py_test(
     ],
 )
 
-py_test(
-    name = "profile_analyzer_cli_test",
-    size = "small",
-    srcs = [
-        "cli/profile_analyzer_cli_test.py",
-    ],
-    deps = [
-        ":command_parser",
-        ":profile_analyzer_cli",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
 cuda_py_test(
     name = "stepper_cli_test",
     size = "small",
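After this rollback the profile_analyzer_cli build targets are gone, so any leftover import of the module fails at load time. A quick way to check (hypothetical snippet, run against a post-rollback build):

try:
  from tensorflow.python.debug.cli import profile_analyzer_cli  # deleted below
except ImportError as err:
  print("profile_analyzer_cli was removed by the rollback:", err)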
tensorflow/python/debug/cli/cli_shared.py
@@ -17,8 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 import six
 
@@ -75,14 +73,6 @@ def bytes_to_readable_str(num_bytes, include_b=False):
   return result
 
-
-def time_to_readable_str(value):
-  if not value:
-    return "0"
-  suffixes = ["us", "ms", "s"]
-  order = min(len(suffixes) - 1, int(math.log(value, 10) / 3))
-  return "{:.3g}{}".format(value / math.pow(10.0, 3*order), suffixes[order])
-
 
 def parse_ranges_highlight(ranges_string):
   """Process ranges highlight string.
 
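For context, the deleted time_to_readable_str helper picks a us/ms/s suffix from the order of magnitude of a microsecond count, three decimal orders per unit step. A minimal standalone sketch of the same logic (plain Python, names as in the deleted source):

import math

def time_to_readable_str(value):
  # 0 or None renders as "0" with no unit suffix.
  if not value:
    return "0"
  suffixes = ["us", "ms", "s"]
  # Each unit step covers three orders of magnitude; clamp at seconds.
  order = min(len(suffixes) - 1, int(math.log(value, 10) / 3))
  return "{:.3g}{}".format(value / math.pow(10.0, 3 * order), suffixes[order])

print(time_to_readable_str(40))    # -> 40us
print(time_to_readable_str(40e3))  # -> 40ms
print(time_to_readable_str(40e6))  # -> 40s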
tensorflow/python/debug/cli/cli_shared_test.py
@@ -70,21 +70,6 @@ class BytesToReadableStrTest(test_util.TensorFlowTestCase):
                                                      1024**3, include_b=True))
 
-
-class TimeToReadableStrTest(test_util.TensorFlowTestCase):
-
-  def testNoneTimeWorks(self):
-    self.assertEqual("0", cli_shared.time_to_readable_str(None))
-
-  def testMicrosecondsTime(self):
-    self.assertEqual("40us", cli_shared.time_to_readable_str(40))
-
-  def testMillisecondTime(self):
-    self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3))
-
-  def testSecondTime(self):
-    self.assertEqual("40s", cli_shared.time_to_readable_str(40e6))
-
 
 class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
 
   def setUp(self):
tensorflow/python/debug/cli/command_parser.py
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import ast
+from collections import namedtuple
 import re
 import sys
 
@@ -28,28 +29,8 @@ _WHITESPACE_PATTERN = re.compile(r"\s+")
 
 _NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
 
-
-class Interval(object):
-  """Represents an interval between a start and end value."""
-
-  def __init__(self, start, start_included, end, end_included):
-    self.start = start
-    self.start_included = start_included
-    self.end = end
-    self.end_included = end_included
-
-  def contains(self, value):
-    if value < self.start or value == self.start and not self.start_included:
-      return False
-    if value > self.end or value == self.end and not self.end_included:
-      return False
-    return True
-
-  def __eq__(self, other):
-    return (self.start == other.start and
-            self.start_included == other.start_included and
-            self.end == other.end and
-            self.end_included == other.end_included)
+Interval = namedtuple("Interval",
+                      ["start", "start_included", "end", "end_included"])
 
 
 def parse_command(command):
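With the rollback, Interval reverts to a plain namedtuple: construction and field access are unchanged, but instances are immutable and the contains() method of the rolled-back class no longer exists. A short sketch of the difference (plain Python):

from collections import namedtuple

Interval = namedtuple("Interval",
                      ["start", "start_included", "end", "end_included"])

ivl = Interval(start=1, start_included=True, end=10, end_included=True)
print(ivl.start, ivl.end)  # -> 1 10

# namedtuple fields are read-only; "mutation" means building a new tuple.
ivl = ivl._replace(start_included=False)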
tensorflow/python/debug/cli/command_parser_test.py
@@ -490,25 +490,6 @@ class ParseInterval(test_util.TensorFlowTestCase):
                                  "equal to end of interval."):
       command_parser.parse_memory_interval("[5k, 3k]")
-
-  def testIntervalContains(self):
-    interval = command_parser.Interval(
-        start=1, start_included=True, end=10, end_included=True)
-    self.assertTrue(interval.contains(1))
-    self.assertTrue(interval.contains(10))
-    self.assertTrue(interval.contains(5))
-
-    interval.start_included = False
-    self.assertFalse(interval.contains(1))
-    self.assertTrue(interval.contains(10))
-
-    interval.end_included = False
-    self.assertFalse(interval.contains(1))
-    self.assertFalse(interval.contains(10))
-
-    interval.start_included = True
-    self.assertTrue(interval.contains(1))
-    self.assertFalse(interval.contains(10))
 
 
 if __name__ == "__main__":
   googletest.main()
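The deleted testIntervalContains exercised Interval.contains() and mutated the flags in place, neither of which the restored namedtuple supports. A hypothetical free function with the same semantics as the removed method (the name interval_contains is illustrative, not part of the codebase):

def interval_contains(interval, value):
  # An endpoint counts only when its *_included flag is set.
  if value < interval.start or (value == interval.start and
                                not interval.start_included):
    return False
  if value > interval.end or (value == interval.end and
                              not interval.end_included):
    return False
  return True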
tensorflow/python/debug/cli/profile_analyzer_cli.py (deleted)
@@ -1,459 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Formats and displays profiling information."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import os
-import re
-
-from tensorflow.python.debug.cli import cli_shared
-from tensorflow.python.debug.cli import command_parser
-from tensorflow.python.debug.cli import debugger_cli_common
-from tensorflow.python.debug.cli import ui_factory
-from tensorflow.python.debug.lib import source_utils
-
-
-SORT_OPS_BY_OP_NAME = "node"
-SORT_OPS_BY_OP_TIME = "op_time"
-SORT_OPS_BY_EXEC_TIME = "exec_time"
-SORT_OPS_BY_START_TIME = "start_time"
-SORT_OPS_BY_LINE = "line"
-
-
-class ProfileDatum(object):
-  """Profile data point."""
-
-  def __init__(self, node_exec_stats, file_line, op_type):
-    """Constructor.
-
-    Args:
-      node_exec_stats: `NodeExecStats` proto.
-      file_line: A `string` formatted as <file_name>:<line_number>.
-      op_type: (string) Operation type.
-    """
-    self.node_exec_stats = node_exec_stats
-    self.file_line = file_line
-    self.op_type = op_type
-    self.op_time = (self.node_exec_stats.op_end_rel_micros -
-                    self.node_exec_stats.op_start_rel_micros)
-
-  @property
-  def exec_time(self):
"""Measures compute function exection time plus pre- and post-processing."""
|
||||
-    return self.node_exec_stats.all_end_rel_micros
-
-
-class ProfileDataTableView(object):
-  """Table View of profiling data."""
-
-  def __init__(self, profile_datum_list):
-    """Constructor.
-
-    Args:
-      profile_datum_list: List of `ProfileDatum` objects.
-    """
-    self._profile_datum_list = profile_datum_list
-    self.formatted_op_time = [
-        cli_shared.time_to_readable_str(datum.op_time)
-        for datum in profile_datum_list]
-    self.formatted_exec_time = [
-        cli_shared.time_to_readable_str(
-            datum.node_exec_stats.all_end_rel_micros)
-        for datum in profile_datum_list]
-    self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME,
-                             SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE]
-
-  def value(self, row, col):
-    if col == 0:
-      return self._profile_datum_list[row].node_exec_stats.node_name
-    elif col == 1:
-      return self.formatted_op_time[row]
-    elif col == 2:
-      return self.formatted_exec_time[row]
-    elif col == 3:
-      return self._profile_datum_list[row].file_line
-    else:
-      raise IndexError("Invalid column index %d." % col)
-
-  def row_count(self):
-    return len(self._profile_datum_list)
-
-  def column_count(self):
-    return 4
-
-  def column_names(self):
-    return ["Node", "Op Time", "Exec Time", "Filename:Lineno(function)"]
-
-  def column_sort_id(self, col):
-    return self._column_sort_ids[col]
-
-
-def _list_profile_filter(
-    profile_datum, node_name_regex, file_name_regex, op_type_regex,
-    op_time_interval, exec_time_interval):
-  """Filter function for list_profile command.
-
-  Args:
-    profile_datum: A `ProfileDatum` object.
-    node_name_regex: Regular expression pattern object to filter by name.
-    file_name_regex: Regular expression pattern object to filter by file.
-    op_type_regex: Regular expression pattern object to filter by op type.
-    op_time_interval: `Interval` for filtering op time.
-    exec_time_interval: `Interval` for filtering exec time.
-
-  Returns:
-    True if profile_datum should be included.
-  """
-  if not node_name_regex.match(
-      profile_datum.node_exec_stats.node_name):
-    return False
-  if profile_datum.file_line is not None and not file_name_regex.match(
-      profile_datum.file_line):
-    return False
-  if profile_datum.op_type is not None and not op_type_regex.match(
-      profile_datum.op_type):
-    return False
-  if op_time_interval is not None and not op_time_interval.contains(
-      profile_datum.op_time):
-    return False
-  if exec_time_interval and not exec_time_interval.contains(
-      profile_datum.node_exec_stats.all_end_rel_micros):
-    return False
-  return True
-
-
-def _list_profile_sort_key(profile_datum, sort_by):
-  """Get a profile_datum property to sort by in list_profile command.
-
-  Args:
-    profile_datum: A `ProfileDatum` object.
-    sort_by: (string) indicates a value to sort by.
-      Must be one of SORT_BY* constants.
-
-  Returns:
-    profile_datum property to sort by.
-  """
-  if sort_by == SORT_OPS_BY_OP_NAME:
-    return profile_datum.node_exec_stats.node_name
-  elif sort_by == SORT_OPS_BY_LINE:
-    return profile_datum.file_line
-  elif sort_by == SORT_OPS_BY_OP_TIME:
-    return profile_datum.op_time
-  elif sort_by == SORT_OPS_BY_EXEC_TIME:
-    return profile_datum.node_exec_stats.all_end_rel_micros
-  else:  # sort by start time
-    return profile_datum.node_exec_stats.all_start_micros
-
-
-class ProfileAnalyzer(object):
-  """Analyzer for profiling data."""
-
-  def __init__(self, graph, run_metadata):
-    """ProfileAnalyzer constructor.
-
-    Args:
-      graph: (tf.Graph) Python graph object.
-      run_metadata: A `RunMetadata` protobuf object.
-
-    Raises:
-      ValueError: If run_metadata is None.
-    """
-    self._graph = graph
-    if not run_metadata:
-      raise ValueError("No RunMetadata passed for profile analysis.")
-    self._run_metadata = run_metadata
-    self._arg_parsers = {}
-    ap = argparse.ArgumentParser(
-        description="List nodes profile information.",
-        usage=argparse.SUPPRESS)
-    ap.add_argument(
-        "-d",
-        "--device_name_filter",
-        dest="device_name_filter",
-        type=str,
-        default="",
-        help="filter device name by regex.")
-    ap.add_argument(
-        "-n",
-        "--node_name_filter",
-        dest="node_name_filter",
-        type=str,
-        default="",
-        help="filter node name by regex.")
-    ap.add_argument(
-        "-t",
-        "--op_type_filter",
-        dest="op_type_filter",
-        type=str,
-        default="",
-        help="filter op type by regex.")
-    # TODO(annarev): allow file filtering at non-stack top position.
-    ap.add_argument(
-        "-f",
-        "--file_name_filter",
-        dest="file_name_filter",
-        type=str,
-        default="",
-        help="filter by file name at the top position of node's creation "
-             "stack that does not belong to TensorFlow library.")
-    ap.add_argument(
-        "-e",
-        "--execution_time",
-        dest="execution_time",
-        type=str,
-        default="",
-        help="Filter by execution time interval "
"(includes compute plus pre- and post -processing time). "
|
||||
"Supported units are s, ms and us (default). "
|
||||
"E.g. -e >100s, -e <100, -e [100us,1000ms]")
|
||||
ap.add_argument(
|
||||
"-o",
|
||||
"--op_time",
|
||||
dest="op_time",
|
||||
type=str,
|
||||
default="",
|
||||
help="Filter by op time interval (only includes compute time). "
|
||||
"Supported units are s, ms and us (default). "
|
||||
"E.g. -e >100s, -e <100, -e [100us,1000ms]")
|
||||
-    ap.add_argument(
-        "-s",
-        "--sort_by",
-        dest="sort_by",
-        type=str,
-        default=SORT_OPS_BY_START_TIME,
-        help=("the field to sort the data by: (%s | %s | %s | %s | %s)" %
-              (SORT_OPS_BY_OP_NAME, SORT_OPS_BY_START_TIME,
-               SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE)))
-    ap.add_argument(
-        "-r",
-        "--reverse",
-        dest="reverse",
-        action="store_true",
-        help="sort the data in reverse (descending) order")
-
-    self._arg_parsers["list_profile"] = ap
-
-  def list_profile(self, args, screen_info=None):
-    """Command handler for list_profile.
-
-    List per-operation profile information.
-
-    Args:
-      args: Command-line arguments, excluding the command prefix, as a list of
-        str.
-      screen_info: Optional dict input containing screen information such as
-        cols.
-
-    Returns:
-      Output text lines as a RichTextLines object.
-    """
-    del screen_info
-
-    parsed = self._arg_parsers["list_profile"].parse_args(args)
-    op_time_interval = (command_parser.parse_time_interval(parsed.op_time)
-                        if parsed.op_time else None)
-    exec_time_interval = (
-        command_parser.parse_time_interval(parsed.execution_time)
-        if parsed.execution_time else None)
-    node_name_regex = re.compile(parsed.node_name_filter)
-    file_name_regex = re.compile(parsed.file_name_filter)
-    op_type_regex = re.compile(parsed.op_type_filter)
-
-    output = debugger_cli_common.RichTextLines([""])
-    device_name_regex = re.compile(parsed.device_name_filter)
-    data_generator = self._get_profile_data_generator()
-    device_count = len(self._run_metadata.step_stats.dev_stats)
-    for index in range(device_count):
-      device_stats = self._run_metadata.step_stats.dev_stats[index]
-      if device_name_regex.match(device_stats.device):
-        profile_data = [
-            datum for datum in data_generator(device_stats)
-            if _list_profile_filter(
-                datum, node_name_regex, file_name_regex, op_type_regex,
-                op_time_interval, exec_time_interval)]
-        profile_data = sorted(
-            profile_data,
-            key=lambda datum: _list_profile_sort_key(datum, parsed.sort_by),
-            reverse=parsed.reverse)
-        output.extend(
-            self._get_list_profile_lines(
-                device_stats.device, index, device_count,
-                profile_data, parsed.sort_by, parsed.reverse))
-    return output
-
-  def _get_profile_data_generator(self):
-    """Get function that generates `ProfileDatum` objects.
-
-    Returns:
-      A function that generates `ProfileDatum` objects.
-    """
-    node_to_file_line = {}
-    node_to_op_type = {}
-    for op in self._graph.get_operations():
-      file_line = ""
-      for trace_entry in reversed(op.traceback):
-        filepath = trace_entry[0]
-        file_line = "%s:%d(%s)" % (
-            os.path.basename(filepath), trace_entry[1], trace_entry[2])
-        if not source_utils.guess_is_tensorflow_py_library(filepath):
-          break
-      node_to_file_line[op.name] = file_line
-      node_to_op_type[op.name] = op.type
-
-    def profile_data_generator(device_step_stats):
-      for node_stats in device_step_stats.node_stats:
-        if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK":
-          continue
-        yield ProfileDatum(
-            node_stats,
-            node_to_file_line.get(node_stats.node_name, ""),
-            node_to_op_type.get(node_stats.node_name, ""))
-    return profile_data_generator
-
-  def _get_list_profile_lines(
-      self, device_name, device_index, device_count,
-      profile_datum_list, sort_by, sort_reverse):
-    """Get `RichTextLines` object for list_profile command for a given device.
-
-    Args:
-      device_name: (string) Device name.
-      device_index: (int) Device index.
-      device_count: (int) Number of devices.
-      profile_datum_list: List of `ProfileDatum` objects.
-      sort_by: (string) Identifier of column to sort. Sort identifier
-        must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_EXEC_TIME,
-        SORT_OPS_BY_MEMORY or SORT_OPS_BY_LINE.
-      sort_reverse: (bool) Whether to sort in descending instead of default
-        (ascending) order.
-
-    Returns:
-      `RichTextLines` object containing a table that displays profiling
-      information for each op.
-    """
-    profile_data = ProfileDataTableView(profile_datum_list)
-
-    # Calculate total time early to calculate column widths.
-    total_op_time = sum(datum.op_time for datum in profile_datum_list)
-    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
-                          for datum in profile_datum_list)
-    device_total_row = [
-        "Device Total", cli_shared.time_to_readable_str(total_op_time),
-        cli_shared.time_to_readable_str(total_exec_time)]
-
-    # Calculate column widths.
-    column_widths = [
-        len(column_name) for column_name in profile_data.column_names()]
-    for col in range(len(device_total_row)):
-      column_widths[col] = max(column_widths[col], len(device_total_row[col]))
-    for col in range(len(column_widths)):
-      for row in range(profile_data.row_count()):
-        column_widths[col] = max(
-            column_widths[col], len(str(profile_data.value(row, col))))
-      column_widths[col] += 2  # add margin between columns
-
-    # Add device name.
-    output = debugger_cli_common.RichTextLines(["-"*80])
-    device_row = "Device %d of %d: %s" % (
-        device_index + 1, device_count, device_name)
-    output.extend(debugger_cli_common.RichTextLines([device_row, ""]))
-
-    # Add headers.
-    base_command = "list_profile"
-    attr_segs = {0: []}
-    row = ""
-    for col in range(profile_data.column_count()):
-      column_name = profile_data.column_names()[col]
-      sort_id = profile_data.column_sort_id(col)
-      command = "%s -s %s" % (base_command, sort_id)
-      if sort_by == sort_id and not sort_reverse:
-        command += " -r"
-      curr_row = ("{:<%d}" % column_widths[col]).format(column_name)
-      prev_len = len(row)
-      row += curr_row
-      attr_segs[0].append(
-          (prev_len, prev_len + len(column_name),
-           [debugger_cli_common.MenuItem(None, command), "bold"]))
-
-    output.extend(
-        debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs))
-
-    # Add data rows.
-    for row in range(profile_data.row_count()):
-      row_str = ""
-      for col in range(profile_data.column_count()):
-        row_str += ("{:<%d}" % column_widths[col]).format(
-            profile_data.value(row, col))
-      output.extend(debugger_cli_common.RichTextLines([row_str]))
-
-    # Add stat totals.
-    row_str = ""
-    for col in range(len(device_total_row)):
-      row_str += ("{:<%d}" % column_widths[col]).format(device_total_row[col])
-    output.extend(debugger_cli_common.RichTextLines(""))
-    output.extend(debugger_cli_common.RichTextLines(row_str))
-    return output
-
-  def _measure_list_profile_column_widths(self, profile_data):
-    """Determine the maximum column widths for each data list.
-
-    Args:
-      profile_data: list of ProfileDatum objects.
-
-    Returns:
-      List of column widths in the same order as columns in data.
-    """
-    num_columns = len(profile_data.column_names())
-    widths = [len(column_name) for column_name in profile_data.column_names()]
-    for row in range(profile_data.row_count()):
-      for col in range(num_columns):
-        widths[col] = max(
-            widths[col], len(str(profile_data.row_values(row)[col])) + 2)
-    return widths
-
-  def get_help(self, handler_name):
-    return self._arg_parsers[handler_name].format_help()
-
-
-def create_profiler_ui(graph,
-                       run_metadata,
-                       ui_type="curses",
-                       on_ui_exit=None):
-  """Create an instance of CursesUI based on a `tf.Graph` and `RunMetadata`.
-
-  Args:
-    graph: Python `Graph` object.
-    run_metadata: A `RunMetadata` protobuf object.
-    ui_type: (str) requested UI type, e.g., "curses", "readline".
-    on_ui_exit: (`Callable`) the callback to be called when the UI exits.
-
-  Returns:
-    (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
-    commands and tab-completions registered.
-  """
-
-  analyzer = ProfileAnalyzer(graph, run_metadata)
-
-  cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)
-  cli.register_command_handler(
-      "list_profile",
-      analyzer.list_profile,
-      analyzer.get_help("list_profile"),
-      prefix_aliases=["lp"])
-
-  return cli
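The deleted module's entry point was create_profiler_ui, which wraps a ProfileAnalyzer in a UI with the list_profile / lp command registered. A minimal usage sketch reconstructed from the deleted source and its test below (assumes a pre-rollback build; run_ui() is the standard tfdbg UI loop):

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops

options = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
with session.Session() as sess:
  a = constant_op.constant([1, 2, 3])
  b = constant_op.constant([2, 2, 1])
  # FULL_TRACE populates run_metadata.step_stats with per-op timings.
  sess.run(math_ops.add(a, b), options=options, run_metadata=run_metadata)

  cli = profile_analyzer_cli.create_profiler_ui(sess.graph, run_metadata)
  cli.run_ui()  # interactive; type "list_profile" or "lp" at the prompt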
tensorflow/python/debug/cli/profile_analyzer_cli_test.py (deleted)
@@ -1,264 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for profile_analyzer_cli."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from tensorflow.core.framework import step_stats_pb2
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.debug.cli import profile_analyzer_cli
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import googletest
-from tensorflow.python.platform import test
-
-
-class ProfileAnalyzerTest(test_util.TensorFlowTestCase):
-
-  def testNodeInfoEmpty(self):
-    graph = ops.Graph()
-    run_metadata = config_pb2.RunMetadata()
-
-    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
-    prof_output = prof_analyzer.list_profile([]).lines
-    self.assertEquals([""], prof_output)
-
-  def testSingleDevice(self):
-    node1 = step_stats_pb2.NodeExecStats(
-        node_name="Add/123",
-        op_start_rel_micros=3,
-        op_end_rel_micros=5,
-        all_end_rel_micros=4)
-
-    node2 = step_stats_pb2.NodeExecStats(
-        node_name="Mul/456",
-        op_start_rel_micros=1,
-        op_end_rel_micros=2,
-        all_end_rel_micros=3)
-
-    run_metadata = config_pb2.RunMetadata()
-    device1 = run_metadata.step_stats.dev_stats.add()
-    device1.device = "deviceA"
-    device1.node_stats.extend([node1, node2])
-
-    graph = test.mock.MagicMock()
-    op1 = test.mock.MagicMock()
-    op1.name = "Add/123"
-    op1.traceback = [("a/b/file1", 10, "some_var")]
-    op1.type = "add"
-    op2 = test.mock.MagicMock()
-    op2.name = "Mul/456"
-    op2.traceback = [("a/b/file1", 11, "some_var")]
-    op2.type = "mul"
-    graph.get_operations.return_value = [op1, op2]
-
-    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
-    prof_output = prof_analyzer.list_profile([]).lines
-
-    self._assertAtLeastOneLineMatches(r"Device 1 of 1: deviceA", prof_output)
-    self._assertAtLeastOneLineMatches(r"^Add/123.*2us.*4us", prof_output)
-    self._assertAtLeastOneLineMatches(r"^Mul/456.*1us.*3us", prof_output)
-
-  def testMultipleDevices(self):
-    node1 = step_stats_pb2.NodeExecStats(
-        node_name="Add/123",
-        op_start_rel_micros=3,
-        op_end_rel_micros=5,
-        all_end_rel_micros=3)
-
-    run_metadata = config_pb2.RunMetadata()
-    device1 = run_metadata.step_stats.dev_stats.add()
-    device1.device = "deviceA"
-    device1.node_stats.extend([node1])
-
-    device2 = run_metadata.step_stats.dev_stats.add()
-    device2.device = "deviceB"
-    device2.node_stats.extend([node1])
-
-    graph = test.mock.MagicMock()
-    op = test.mock.MagicMock()
-    op.name = "Add/123"
-    op.traceback = [("a/b/file1", 10, "some_var")]
-    op.type = "abc"
-    graph.get_operations.return_value = [op]
-
-    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
-    prof_output = prof_analyzer.list_profile([]).lines
-
-    self._assertAtLeastOneLineMatches(r"Device 1 of 2: deviceA", prof_output)
-    self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output)
-
-    # Try filtering by device.
-    prof_output = prof_analyzer.list_profile(["-d", "deviceB"]).lines
-    self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output)
-    self._assertNoLinesMatch(r"Device 1 of 2: deviceA", prof_output)
-
-  def testWithSession(self):
-    options = config_pb2.RunOptions()
-    options.trace_level = config_pb2.RunOptions.FULL_TRACE
-    run_metadata = config_pb2.RunMetadata()
-
-    with session.Session() as sess:
-      a = constant_op.constant([1, 2, 3])
-      b = constant_op.constant([2, 2, 1])
-      result = math_ops.add(a, b)
-
-      sess.run(result, options=options, run_metadata=run_metadata)
-
-      prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(
-          sess.graph, run_metadata)
-      prof_output = prof_analyzer.list_profile([]).lines
-
-      self._assertAtLeastOneLineMatches("Device 1 of 1:", prof_output)
-      expected_headers = [
-          "Node", "Op Time", "Exec Time", r"Filename:Lineno\(function\)"]
-      self._assertAtLeastOneLineMatches(
-          ".*".join(expected_headers), prof_output)
-      self._assertAtLeastOneLineMatches(r"^Add/", prof_output)
-      self._assertAtLeastOneLineMatches(r"Device Total", prof_output)
-
-  def testSorting(self):
-    node1 = step_stats_pb2.NodeExecStats(
-        node_name="Add/123",
-        all_start_micros=123,
-        op_start_rel_micros=3,
-        op_end_rel_micros=5,
-        all_end_rel_micros=4)
-
-    node2 = step_stats_pb2.NodeExecStats(
-        node_name="Mul/456",
-        all_start_micros=122,
-        op_start_rel_micros=1,
-        op_end_rel_micros=2,
-        all_end_rel_micros=5)
-
-    run_metadata = config_pb2.RunMetadata()
-    device1 = run_metadata.step_stats.dev_stats.add()
-    device1.device = "deviceA"
-    device1.node_stats.extend([node1, node2])
-
-    graph = test.mock.MagicMock()
-    op1 = test.mock.MagicMock()
-    op1.name = "Add/123"
-    op1.traceback = [("a/b/file2", 10, "some_var")]
-    op1.type = "add"
-    op2 = test.mock.MagicMock()
-    op2.name = "Mul/456"
-    op2.traceback = [("a/b/file1", 11, "some_var")]
-    op2.type = "mul"
-    graph.get_operations.return_value = [op1, op2]
-
-    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
-
-    # Default sort by start time (i.e. all_start_micros).
-    prof_output = prof_analyzer.list_profile([]).lines
-    self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
-    # Default sort in reverse.
-    prof_output = prof_analyzer.list_profile(["-r"]).lines
-    self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
-    # Sort by name.
-    prof_output = prof_analyzer.list_profile(["-s", "node"]).lines
-    self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
-    # Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros).
-    prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines
-    self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
-    # Sort by exec time (i.e. all_end_rel_micros).
-    prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines
-    self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
-    # Sort by line number.
-    prof_output = prof_analyzer.list_profile(["-s", "line"]).lines
-    self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
-
-  def testFiltering(self):
-    node1 = step_stats_pb2.NodeExecStats(
-        node_name="Add/123",
-        all_start_micros=123,
-        op_start_rel_micros=3,
-        op_end_rel_micros=5,
-        all_end_rel_micros=4)
-
-    node2 = step_stats_pb2.NodeExecStats(
-        node_name="Mul/456",
-        all_start_micros=122,
-        op_start_rel_micros=1,
-        op_end_rel_micros=2,
-        all_end_rel_micros=5)
-
-    run_metadata = config_pb2.RunMetadata()
-    device1 = run_metadata.step_stats.dev_stats.add()
-    device1.device = "deviceA"
-    device1.node_stats.extend([node1, node2])
-
-    graph = test.mock.MagicMock()
-    op1 = test.mock.MagicMock()
-    op1.name = "Add/123"
-    op1.traceback = [("a/b/file2", 10, "some_var")]
-    op1.type = "add"
-    op2 = test.mock.MagicMock()
-    op2.name = "Mul/456"
-    op2.traceback = [("a/b/file1", 11, "some_var")]
-    op2.type = "mul"
-    graph.get_operations.return_value = [op1, op2]
-
-    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
-
-    # Filter by name
-    prof_output = prof_analyzer.list_profile(["-n", "Add"]).lines
-    self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
-    self._assertNoLinesMatch(r"Mul/456", prof_output)
-    # Filter by op_type
-    prof_output = prof_analyzer.list_profile(["-t", "mul"]).lines
-    self._assertAtLeastOneLineMatches(r"Mul/456", prof_output)
-    self._assertNoLinesMatch(r"Add/123", prof_output)
-    # Filter by file name.
-    prof_output = prof_analyzer.list_profile(["-f", "file2"]).lines
-    self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
-    self._assertNoLinesMatch(r"Mul/456", prof_output)
-    # Filter by execution time.
-    prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines
-    self._assertAtLeastOneLineMatches(r"Mul/456", prof_output)
-    self._assertNoLinesMatch(r"Add/123", prof_output)
-    # Filter by op time.
-    prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines
-    self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
-    self._assertNoLinesMatch(r"Mul/456", prof_output)
-
-  def _atLeastOneLineMatches(self, pattern, lines):
-    pattern_re = re.compile(pattern)
-    for line in lines:
-      if pattern_re.match(line):
-        return True
-    return False
-
-  def _assertAtLeastOneLineMatches(self, pattern, lines):
-    if not self._atLeastOneLineMatches(pattern, lines):
-      raise AssertionError(
-          "%s does not match any line in %s." % (pattern, str(lines)))
-
-  def _assertNoLinesMatch(self, pattern, lines):
-    if self._atLeastOneLineMatches(pattern, lines):
-      raise AssertionError(
-          "%s matched at least one line in %s." % (pattern, str(lines)))
-
-
-if __name__ == "__main__":
-  googletest.main()
tensorflow/python/debug/lib/source_utils.py
@@ -44,7 +44,7 @@ def _convert_watch_key_to_tensor_name(watch_key):
   return watch_key[:watch_key.rfind(":")]
 
 
-def guess_is_tensorflow_py_library(py_file_path):
+def _guess_is_tensorflow_py_library(py_file_path):
   """Guess whether a Python source file is a part of the tensorflow library.
 
   Special cases:
@@ -231,7 +231,7 @@ def list_source_files_against_dump(dump,
   for file_path in path_to_node_names:
     output.append((
         file_path,
-        guess_is_tensorflow_py_library(file_path),
+        _guess_is_tensorflow_py_library(file_path),
         len(path_to_node_names.get(file_path, {})),
        len(path_to_tensor_names.get(file_path, {})),
        path_to_num_dumps.get(file_path, 0),
tensorflow/python/debug/lib/source_utils_test.py
@@ -57,20 +57,20 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
 
   def testUnitTestFileReturnsFalse(self):
     self.assertFalse(
-        source_utils.guess_is_tensorflow_py_library(self.curr_file_path))
+        source_utils._guess_is_tensorflow_py_library(self.curr_file_path))
 
   def testSourceUtilModuleReturnsTrue(self):
     self.assertTrue(
-        source_utils.guess_is_tensorflow_py_library(source_utils.__file__))
+        source_utils._guess_is_tensorflow_py_library(source_utils.__file__))
 
   def testFileInPythonKernelsPathReturnsTrue(self):
     x = constant_op.constant(42.0, name="x")
     self.assertTrue(
-        source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0]))
+        source_utils._guess_is_tensorflow_py_library(x.op.traceback[-1][0]))
 
   def testNonPythonFileRaisesException(self):
     with self.assertRaisesRegexp(ValueError, r"is not a Python source file"):
-      source_utils.guess_is_tensorflow_py_library(
+      source_utils._guess_is_tensorflow_py_library(
           os.path.join(os.path.dirname(self.curr_file_path), "foo.cc"))
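Both call sites walk an op's Python traceback from the innermost frame outward and keep the first frame that falls outside the TensorFlow library; that is exactly what the renamed helper decides. A small sketch of the pattern, mirroring _get_profile_data_generator in the deleted file above (the user_file_line name is illustrative, and the underscore-prefixed helper is private again after this rollback):

import os
from tensorflow.python.debug.lib import source_utils

def user_file_line(op):
  # Walk frames innermost-first; stop at the first non-TensorFlow file.
  file_line = ""
  for trace_entry in reversed(op.traceback):
    filepath = trace_entry[0]
    file_line = "%s:%d(%s)" % (
        os.path.basename(filepath), trace_entry[1], trace_entry[2])
    if not source_utils._guess_is_tensorflow_py_library(filepath):
      break
  return file_line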