Automated rollback of change 154220704

Change: 154225030
This commit is contained in:
A. Unique TensorFlower 2017-04-25 13:52:13 -08:00 committed by TensorFlower Gardener
parent c19dd491e5
commit 0a583aae32
9 changed files with 9 additions and 830 deletions

View File

@@ -147,22 +147,6 @@ py_library(
],
)
py_library(
name = "profile_analyzer_cli",
srcs = ["cli/profile_analyzer_cli.py"],
srcs_version = "PY2AND3",
deps = [
":cli_shared",
":command_parser",
":debug_data",
":debugger_cli_common",
":source_utils",
":ui_factory",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_library(
name = "stepper_cli",
srcs = ["cli/stepper_cli.py"],
@@ -257,7 +241,6 @@ py_library(
":debug_data",
":debugger_cli_common",
":framework",
":profile_analyzer_cli",
":stepper_cli",
":ui_factory",
],
@@ -623,24 +606,6 @@ cuda_py_test(
],
)
py_test(
name = "profile_analyzer_cli_test",
size = "small",
srcs = [
"cli/profile_analyzer_cli_test.py",
],
deps = [
":command_parser",
":profile_analyzer_cli",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "stepper_cli_test",
size = "small",

View File

@@ -17,8 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import six
@@ -75,14 +73,6 @@ def bytes_to_readable_str(num_bytes, include_b=False):
return result
def time_to_readable_str(value):
if not value:
return "0"
suffixes = ["us", "ms", "s"]
order = min(len(suffixes) - 1, int(math.log(value, 10) / 3))
return "{:.3g}{}".format(value / math.pow(10.0, 3*order), suffixes[order])
def parse_ranges_highlight(ranges_string):
"""Process ranges highlight string.

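As a quick sanity check of the removed helper's suffix math (a hypothetical REPL session; the first two values mirror the deleted tests below, the third shows that order is capped at "s"):

>>> time_to_readable_str(40)     # int(log10(40)/3) == 0 -> "us"
'40us'
>>> time_to_readable_str(40e3)   # int(log10(4e4)/3) == 1 -> "ms"
'40ms'
>>> time_to_readable_str(1e9)    # order capped at len(suffixes) - 1 -> "s"
'1e+03s'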
View File

@@ -70,21 +70,6 @@ class BytesToReadableStrTest(test_util.TensorFlowTestCase):
1024**3, include_b=True))
class TimeToReadableStrTest(test_util.TensorFlowTestCase):
def testNoneTimeWorks(self):
self.assertEqual("0", cli_shared.time_to_readable_str(None))
def testMicrosecondsTime(self):
self.assertEqual("40us", cli_shared.time_to_readable_str(40))
def testMillisecondTime(self):
self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3))
def testSecondTime(self):
self.assertEqual("40s", cli_shared.time_to_readable_str(40e6))
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
def setUp(self):

View File

@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import ast
from collections import namedtuple
import re
import sys
@ -28,28 +29,8 @@ _WHITESPACE_PATTERN = re.compile(r"\s+")
_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
class Interval(object):
"""Represents an interval between a start and end value."""
def __init__(self, start, start_included, end, end_included):
self.start = start
self.start_included = start_included
self.end = end
self.end_included = end_included
def contains(self, value):
if value < self.start or value == self.start and not self.start_included:
return False
if value > self.end or value == self.end and not self.end_included:
return False
return True
def __eq__(self, other):
return (self.start == other.start and
self.start_included == other.start_included and
self.end == other.end and
self.end_included == other.end_included)
Interval = namedtuple("Interval",
["start", "start_included", "end", "end_included"])
def parse_command(command):

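Because the rollback replaces the Interval class with an immutable namedtuple, the containment check above no longer ships with command_parser (which is also why the mutation-based testIntervalContains below is deleted). A minimal standalone sketch that mirrors the removed semantics; the helper name interval_contains is hypothetical:

from collections import namedtuple

Interval = namedtuple("Interval",
                      ["start", "start_included", "end", "end_included"])

def interval_contains(interval, value):
  """Hypothetical free function mirroring the removed Interval.contains."""
  if value < interval.start or (
      value == interval.start and not interval.start_included):
    return False
  if value > interval.end or (
      value == interval.end and not interval.end_included):
    return False
  return True

# E.g. for the half-open interval (1, 10]:
# interval_contains(Interval(1, False, 10, True), 1)   -> False
# interval_contains(Interval(1, False, 10, True), 10)  -> True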
View File

@@ -490,25 +490,6 @@ class ParseInterval(test_util.TensorFlowTestCase):
"equal to end of interval."):
command_parser.parse_memory_interval("[5k, 3k]")
def testIntervalContains(self):
interval = command_parser.Interval(
start=1, start_included=True, end=10, end_included=True)
self.assertTrue(interval.contains(1))
self.assertTrue(interval.contains(10))
self.assertTrue(interval.contains(5))
interval.start_included = False
self.assertFalse(interval.contains(1))
self.assertTrue(interval.contains(10))
interval.end_included = False
self.assertFalse(interval.contains(1))
self.assertFalse(interval.contains(10))
interval.start_included = True
self.assertTrue(interval.contains(1))
self.assertFalse(interval.contains(10))
if __name__ == "__main__":
googletest.main()

View File

@@ -1,459 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Formats and displays profiling information."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import re
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import source_utils
SORT_OPS_BY_OP_NAME = "node"
SORT_OPS_BY_OP_TIME = "op_time"
SORT_OPS_BY_EXEC_TIME = "exec_time"
SORT_OPS_BY_START_TIME = "start_time"
SORT_OPS_BY_LINE = "line"
class ProfileDatum(object):
"""Profile data point."""
def __init__(self, node_exec_stats, file_line, op_type):
"""Constructor.
Args:
node_exec_stats: `NodeExecStats` proto.
file_line: A `string` formatted as <file_name>:<line_number>.
op_type: (string) Operation type.
"""
self.node_exec_stats = node_exec_stats
self.file_line = file_line
self.op_type = op_type
self.op_time = (self.node_exec_stats.op_end_rel_micros -
self.node_exec_stats.op_start_rel_micros)
@property
def exec_time(self):
"""Measures compute function exection time plus pre- and post-processing."""
return self.node_exec_stats.all_end_rel_micros
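# For example (numbers borrowed from the unit test further below): with
# op_start_rel_micros=3, op_end_rel_micros=5 and all_end_rel_micros=4,
# op_time is 2us of pure compute while exec_time is 4us end to end.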
class ProfileDataTableView(object):
"""Table View of profiling data."""
def __init__(self, profile_datum_list):
"""Constructor.
Args:
profile_datum_list: List of `ProfileDatum` objects.
"""
self._profile_datum_list = profile_datum_list
self.formatted_op_time = [
cli_shared.time_to_readable_str(datum.op_time)
for datum in profile_datum_list]
self.formatted_exec_time = [
cli_shared.time_to_readable_str(
datum.node_exec_stats.all_end_rel_micros)
for datum in profile_datum_list]
self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME,
SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE]
def value(self, row, col):
if col == 0:
return self._profile_datum_list[row].node_exec_stats.node_name
elif col == 1:
return self.formatted_op_time[row]
elif col == 2:
return self.formatted_exec_time[row]
elif col == 3:
return self._profile_datum_list[row].file_line
else:
raise IndexError("Invalid column index %d." % col)
def row_count(self):
return len(self._profile_datum_list)
def column_count(self):
return 4
def column_names(self):
return ["Node", "Op Time", "Exec Time", "Filename:Lineno(function)"]
def column_sort_id(self, col):
return self._column_sort_ids[col]
def _list_profile_filter(
profile_datum, node_name_regex, file_name_regex, op_type_regex,
op_time_interval, exec_time_interval):
"""Filter function for list_profile command.
Args:
profile_datum: A `ProfileDatum` object.
node_name_regex: Regular expression pattern object to filter by name.
file_name_regex: Regular expression pattern object to filter by file.
op_type_regex: Regular expression pattern object to filter by op type.
op_time_interval: `Interval` for filtering op time.
exec_time_interval: `Interval` for filtering exec time.
Returns:
True if profile_datum should be included.
"""
if not node_name_regex.match(
profile_datum.node_exec_stats.node_name):
return False
if profile_datum.file_line is not None and not file_name_regex.match(
profile_datum.file_line):
return False
if profile_datum.op_type is not None and not op_type_regex.match(
profile_datum.op_type):
return False
if op_time_interval is not None and not op_time_interval.contains(
profile_datum.op_time):
return False
if exec_time_interval and not exec_time_interval.contains(
profile_datum.node_exec_stats.all_end_rel_micros):
return False
return True
def _list_profile_sort_key(profile_datum, sort_by):
"""Get a profile_datum property to sort by in list_profile command.
Args:
profile_datum: A `ProfileDatum` object.
sort_by: (string) indicates a value to sort by.
Must be one of the SORT_OPS_BY* constants.
Returns:
profile_datum property to sort by.
"""
if sort_by == SORT_OPS_BY_OP_NAME:
return profile_datum.node_exec_stats.node_name
elif sort_by == SORT_OPS_BY_LINE:
return profile_datum.file_line
elif sort_by == SORT_OPS_BY_OP_TIME:
return profile_datum.op_time
elif sort_by == SORT_OPS_BY_EXEC_TIME:
return profile_datum.node_exec_stats.all_end_rel_micros
else: # sort by start time
return profile_datum.node_exec_stats.all_start_micros
class ProfileAnalyzer(object):
"""Analyzer for profiling data."""
def __init__(self, graph, run_metadata):
"""ProfileAnalyzer constructor.
Args:
graph: (tf.Graph) Python graph object.
run_metadata: A `RunMetadata` protobuf object.
Raises:
ValueError: If run_metadata is None.
"""
self._graph = graph
if not run_metadata:
raise ValueError("No RunMetadata passed for profile analysis.")
self._run_metadata = run_metadata
self._arg_parsers = {}
ap = argparse.ArgumentParser(
description="List nodes profile information.",
usage=argparse.SUPPRESS)
ap.add_argument(
"-d",
"--device_name_filter",
dest="device_name_filter",
type=str,
default="",
help="filter device name by regex.")
ap.add_argument(
"-n",
"--node_name_filter",
dest="node_name_filter",
type=str,
default="",
help="filter node name by regex.")
ap.add_argument(
"-t",
"--op_type_filter",
dest="op_type_filter",
type=str,
default="",
help="filter op type by regex.")
# TODO(annarev): allow file filtering at non-stack top position.
ap.add_argument(
"-f",
"--file_name_filter",
dest="file_name_filter",
type=str,
default="",
help="filter by file name at the top position of node's creation "
"stack that does not belong to TensorFlow library.")
ap.add_argument(
"-e",
"--execution_time",
dest="execution_time",
type=str,
default="",
help="Filter by execution time interval "
"(includes compute plus pre- and post -processing time). "
"Supported units are s, ms and us (default). "
"E.g. -e >100s, -e <100, -e [100us,1000ms]")
ap.add_argument(
"-o",
"--op_time",
dest="op_time",
type=str,
default="",
help="Filter by op time interval (only includes compute time). "
"Supported units are s, ms and us (default). "
"E.g. -e >100s, -e <100, -e [100us,1000ms]")
ap.add_argument(
"-s",
"--sort_by",
dest="sort_by",
type=str,
default=SORT_OPS_BY_START_TIME,
help=("the field to sort the data by: (%s | %s | %s | %s | %s)" %
(SORT_OPS_BY_OP_NAME, SORT_OPS_BY_START_TIME,
SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE)))
ap.add_argument(
"-r",
"--reverse",
dest="reverse",
action="store_true",
help="sort the data in reverse (descending) order")
self._arg_parsers["list_profile"] = ap
def list_profile(self, args, screen_info=None):
"""Command handler for list_profile.
List per-operation profile information.
Args:
args: Command-line arguments, excluding the command prefix, as a list of
str.
screen_info: Optional dict input containing screen information such as
cols.
Returns:
Output text lines as a RichTextLines object.
"""
del screen_info
parsed = self._arg_parsers["list_profile"].parse_args(args)
op_time_interval = (command_parser.parse_time_interval(parsed.op_time)
if parsed.op_time else None)
exec_time_interval = (
command_parser.parse_time_interval(parsed.execution_time)
if parsed.execution_time else None)
node_name_regex = re.compile(parsed.node_name_filter)
file_name_regex = re.compile(parsed.file_name_filter)
op_type_regex = re.compile(parsed.op_type_filter)
output = debugger_cli_common.RichTextLines([""])
device_name_regex = re.compile(parsed.device_name_filter)
data_generator = self._get_profile_data_generator()
device_count = len(self._run_metadata.step_stats.dev_stats)
for index in range(device_count):
device_stats = self._run_metadata.step_stats.dev_stats[index]
if device_name_regex.match(device_stats.device):
profile_data = [
datum for datum in data_generator(device_stats)
if _list_profile_filter(
datum, node_name_regex, file_name_regex, op_type_regex,
op_time_interval, exec_time_interval)]
profile_data = sorted(
profile_data,
key=lambda datum: _list_profile_sort_key(datum, parsed.sort_by),
reverse=parsed.reverse)
output.extend(
self._get_list_profile_lines(
device_stats.device, index, device_count,
profile_data, parsed.sort_by, parsed.reverse))
return output
def _get_profile_data_generator(self):
"""Get function that generates `ProfileDatum` objects.
Returns:
A function that generates `ProfileDatum` objects.
"""
node_to_file_line = {}
node_to_op_type = {}
for op in self._graph.get_operations():
file_line = ""
for trace_entry in reversed(op.traceback):
filepath = trace_entry[0]
file_line = "%s:%d(%s)" % (
os.path.basename(filepath), trace_entry[1], trace_entry[2])
if not source_utils.guess_is_tensorflow_py_library(filepath):
break
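# Walking the traceback from the innermost frame outward and stopping
# at the first frame outside the TensorFlow library leaves file_line
# pointing at the user code that created the op.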
node_to_file_line[op.name] = file_line
node_to_op_type[op.name] = op.type
def profile_data_generator(device_step_stats):
for node_stats in device_step_stats.node_stats:
if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK":
continue
yield ProfileDatum(
node_stats,
node_to_file_line.get(node_stats.node_name, ""),
node_to_op_type.get(node_stats.node_name, ""))
return profile_data_generator
def _get_list_profile_lines(
self, device_name, device_index, device_count,
profile_datum_list, sort_by, sort_reverse):
"""Get `RichTextLines` object for list_profile command for a given device.
Args:
device_name: (string) Device name.
device_index: (int) Device index.
device_count: (int) Number of devices.
profile_datum_list: List of `ProfileDatum` objects.
sort_by: (string) Identifier of column to sort. Sort identifier
must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME,
SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_START_TIME or SORT_OPS_BY_LINE.
sort_reverse: (bool) Whether to sort in descending instead of default
(ascending) order.
Returns:
`RichTextLines` object containing a table that displays profiling
information for each op.
"""
profile_data = ProfileDataTableView(profile_datum_list)
# Calculate total time early to calculate column widths.
total_op_time = sum(datum.op_time for datum in profile_datum_list)
total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
for datum in profile_datum_list)
device_total_row = [
"Device Total", cli_shared.time_to_readable_str(total_op_time),
cli_shared.time_to_readable_str(total_exec_time)]
# Calculate column widths.
column_widths = [
len(column_name) for column_name in profile_data.column_names()]
for col in range(len(device_total_row)):
column_widths[col] = max(column_widths[col], len(device_total_row[col]))
for col in range(len(column_widths)):
for row in range(profile_data.row_count()):
column_widths[col] = max(
column_widths[col], len(str(profile_data.value(row, col))))
column_widths[col] += 2 # add margin between columns
# Add device name.
output = debugger_cli_common.RichTextLines(["-"*80])
device_row = "Device %d of %d: %s" % (
device_index + 1, device_count, device_name)
output.extend(debugger_cli_common.RichTextLines([device_row, ""]))
# Add headers.
base_command = "list_profile"
attr_segs = {0: []}
row = ""
for col in range(profile_data.column_count()):
column_name = profile_data.column_names()[col]
sort_id = profile_data.column_sort_id(col)
command = "%s -s %s" % (base_command, sort_id)
if sort_by == sort_id and not sort_reverse:
command += " -r"
curr_row = ("{:<%d}" % column_widths[col]).format(column_name)
prev_len = len(row)
row += curr_row
attr_segs[0].append(
(prev_len, prev_len + len(column_name),
[debugger_cli_common.MenuItem(None, command), "bold"]))
output.extend(
debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs))
# Add data rows.
for row in range(profile_data.row_count()):
row_str = ""
for col in range(profile_data.column_count()):
row_str += ("{:<%d}" % column_widths[col]).format(
profile_data.value(row, col))
output.extend(debugger_cli_common.RichTextLines([row_str]))
# Add stat totals.
row_str = ""
for col in range(len(device_total_row)):
row_str += ("{:<%d}" % column_widths[col]).format(device_total_row[col])
output.extend(debugger_cli_common.RichTextLines(""))
output.extend(debugger_cli_common.RichTextLines(row_str))
return output
def _measure_list_profile_column_widths(self, profile_data):
"""Determine the maximum column widths for each data list.
Args:
profile_data: list of ProfileDatum objects.
Returns:
List of column widths in the same order as columns in data.
"""
num_columns = len(profile_data.column_names())
widths = [len(column_name) for column_name in profile_data.column_names()]
for row in range(profile_data.row_count()):
for col in range(num_columns):
widths[col] = max(
widths[col], len(str(profile_data.row_values(row)[col])) + 2)
return widths
def get_help(self, handler_name):
return self._arg_parsers[handler_name].format_help()
def create_profiler_ui(graph,
run_metadata,
ui_type="curses",
on_ui_exit=None):
"""Create an instance of CursesUI based on a `tf.Graph` and `RunMetadata`.
Args:
graph: Python `Graph` object.
run_metadata: A `RunMetadata` protobuf object.
ui_type: (str) requested UI type, e.g., "curses", "readline".
on_ui_exit: (`Callable`) the callback to be called when the UI exits.
Returns:
(base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
commands and tab-completions registered.
"""
analyzer = ProfileAnalyzer(graph, run_metadata)
cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)
cli.register_command_handler(
"list_profile",
analyzer.list_profile,
analyzer.get_help("list_profile"),
prefix_aliases=["lp"])
return cli
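
For context, a minimal sketch of how the removed module was driven (mirroring the testWithSession case below, but printing the rendered lines instead of launching the interactive UI):

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops

options = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
with session.Session() as sess:
  result = math_ops.add(constant_op.constant([1, 2, 3]),
                        constant_op.constant([2, 2, 1]))
  sess.run(result, options=options, run_metadata=run_metadata)
  analyzer = profile_analyzer_cli.ProfileAnalyzer(sess.graph, run_metadata)
  # Slowest ops first: sort by op time, reversed.
  print("\n".join(analyzer.list_profile(["-s", "op_time", "-r"]).lines))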

View File

@@ -1,264 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for profile_analyzer_cli."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
class ProfileAnalyzerTest(test_util.TensorFlowTestCase):
def testNodeInfoEmpty(self):
graph = ops.Graph()
run_metadata = config_pb2.RunMetadata()
prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
prof_output = prof_analyzer.list_profile([]).lines
self.assertEquals([""], prof_output)
def testSingleDevice(self):
node1 = step_stats_pb2.NodeExecStats(
node_name="Add/123",
op_start_rel_micros=3,
op_end_rel_micros=5,
all_end_rel_micros=4)
node2 = step_stats_pb2.NodeExecStats(
node_name="Mul/456",
op_start_rel_micros=1,
op_end_rel_micros=2,
all_end_rel_micros=3)
run_metadata = config_pb2.RunMetadata()
device1 = run_metadata.step_stats.dev_stats.add()
device1.device = "deviceA"
device1.node_stats.extend([node1, node2])
graph = test.mock.MagicMock()
op1 = test.mock.MagicMock()
op1.name = "Add/123"
op1.traceback = [("a/b/file1", 10, "some_var")]
op1.type = "add"
op2 = test.mock.MagicMock()
op2.name = "Mul/456"
op2.traceback = [("a/b/file1", 11, "some_var")]
op2.type = "mul"
graph.get_operations.return_value = [op1, op2]
prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
prof_output = prof_analyzer.list_profile([]).lines
self._assertAtLeastOneLineMatches(r"Device 1 of 1: deviceA", prof_output)
self._assertAtLeastOneLineMatches(r"^Add/123.*2us.*4us", prof_output)
self._assertAtLeastOneLineMatches(r"^Mul/456.*1us.*3us", prof_output)
def testMultipleDevices(self):
node1 = step_stats_pb2.NodeExecStats(
node_name="Add/123",
op_start_rel_micros=3,
op_end_rel_micros=5,
all_end_rel_micros=3)
run_metadata = config_pb2.RunMetadata()
device1 = run_metadata.step_stats.dev_stats.add()
device1.device = "deviceA"
device1.node_stats.extend([node1])
device2 = run_metadata.step_stats.dev_stats.add()
device2.device = "deviceB"
device2.node_stats.extend([node1])
graph = test.mock.MagicMock()
op = test.mock.MagicMock()
op.name = "Add/123"
op.traceback = [("a/b/file1", 10, "some_var")]
op.type = "abc"
graph.get_operations.return_value = [op]
prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
prof_output = prof_analyzer.list_profile([]).lines
self._assertAtLeastOneLineMatches(r"Device 1 of 2: deviceA", prof_output)
self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output)
# Try filtering by device.
prof_output = prof_analyzer.list_profile(["-d", "deviceB"]).lines
self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output)
self._assertNoLinesMatch(r"Device 1 of 2: deviceA", prof_output)
def testWithSession(self):
options = config_pb2.RunOptions()
options.trace_level = config_pb2.RunOptions.FULL_TRACE
run_metadata = config_pb2.RunMetadata()
with session.Session() as sess:
a = constant_op.constant([1, 2, 3])
b = constant_op.constant([2, 2, 1])
result = math_ops.add(a, b)
sess.run(result, options=options, run_metadata=run_metadata)
prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(
sess.graph, run_metadata)
prof_output = prof_analyzer.list_profile([]).lines
self._assertAtLeastOneLineMatches("Device 1 of 1:", prof_output)
expected_headers = [
"Node", "Op Time", "Exec Time", r"Filename:Lineno\(function\)"]
self._assertAtLeastOneLineMatches(
".*".join(expected_headers), prof_output)
self._assertAtLeastOneLineMatches(r"^Add/", prof_output)
self._assertAtLeastOneLineMatches(r"Device Total", prof_output)
def testSorting(self):
node1 = step_stats_pb2.NodeExecStats(
node_name="Add/123",
all_start_micros=123,
op_start_rel_micros=3,
op_end_rel_micros=5,
all_end_rel_micros=4)
node2 = step_stats_pb2.NodeExecStats(
node_name="Mul/456",
all_start_micros=122,
op_start_rel_micros=1,
op_end_rel_micros=2,
all_end_rel_micros=5)
run_metadata = config_pb2.RunMetadata()
device1 = run_metadata.step_stats.dev_stats.add()
device1.device = "deviceA"
device1.node_stats.extend([node1, node2])
graph = test.mock.MagicMock()
op1 = test.mock.MagicMock()
op1.name = "Add/123"
op1.traceback = [("a/b/file2", 10, "some_var")]
op1.type = "add"
op2 = test.mock.MagicMock()
op2.name = "Mul/456"
op2.traceback = [("a/b/file1", 11, "some_var")]
op2.type = "mul"
graph.get_operations.return_value = [op1, op2]
prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
# Default sort by start time (i.e. all_start_micros).
prof_output = prof_analyzer.list_profile([]).lines
self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
# Default sort in reverse.
prof_output = prof_analyzer.list_profile(["-r"]).lines
self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
# Sort by name.
prof_output = prof_analyzer.list_profile(["-s", "node"]).lines
self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
# Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros).
prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines
self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
# Sort by exec time (i.e. all_end_rel_micros).
prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines
self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456")
# Sort by line number.
prof_output = prof_analyzer.list_profile(["-s", "line"]).lines
self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123")
def testFiltering(self):
node1 = step_stats_pb2.NodeExecStats(
node_name="Add/123",
all_start_micros=123,
op_start_rel_micros=3,
op_end_rel_micros=5,
all_end_rel_micros=4)
node2 = step_stats_pb2.NodeExecStats(
node_name="Mul/456",
all_start_micros=122,
op_start_rel_micros=1,
op_end_rel_micros=2,
all_end_rel_micros=5)
run_metadata = config_pb2.RunMetadata()
device1 = run_metadata.step_stats.dev_stats.add()
device1.device = "deviceA"
device1.node_stats.extend([node1, node2])
graph = test.mock.MagicMock()
op1 = test.mock.MagicMock()
op1.name = "Add/123"
op1.traceback = [("a/b/file2", 10, "some_var")]
op1.type = "add"
op2 = test.mock.MagicMock()
op2.name = "Mul/456"
op2.traceback = [("a/b/file1", 11, "some_var")]
op2.type = "mul"
graph.get_operations.return_value = [op1, op2]
prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
# Filter by name
prof_output = prof_analyzer.list_profile(["-n", "Add"]).lines
self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
self._assertNoLinesMatch(r"Mul/456", prof_output)
# Filter by op_type
prof_output = prof_analyzer.list_profile(["-t", "mul"]).lines
self._assertAtLeastOneLineMatches(r"Mul/456", prof_output)
self._assertNoLinesMatch(r"Add/123", prof_output)
# Filter by file name.
prof_output = prof_analyzer.list_profile(["-f", "file2"]).lines
self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
self._assertNoLinesMatch(r"Mul/456", prof_output)
# Filter by execution time.
prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines
self._assertAtLeastOneLineMatches(r"Mul/456", prof_output)
self._assertNoLinesMatch(r"Add/123", prof_output)
# Filter by op time.
prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines
self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
self._assertNoLinesMatch(r"Mul/456", prof_output)
def _atLeastOneLineMatches(self, pattern, lines):
pattern_re = re.compile(pattern)
for line in lines:
if pattern_re.match(line):
return True
return False
def _assertAtLeastOneLineMatches(self, pattern, lines):
if not self._atLeastOneLineMatches(pattern, lines):
raise AssertionError(
"%s does not match any line in %s." % (pattern, str(lines)))
def _assertNoLinesMatch(self, pattern, lines):
if self._atLeastOneLineMatches(pattern, lines):
raise AssertionError(
"%s matched at least one line in %s." % (pattern, str(lines)))
if __name__ == "__main__":
googletest.main()

View File

@@ -44,7 +44,7 @@ def _convert_watch_key_to_tensor_name(watch_key):
return watch_key[:watch_key.rfind(":")]
def guess_is_tensorflow_py_library(py_file_path):
def _guess_is_tensorflow_py_library(py_file_path):
"""Guess whether a Python source file is a part of the tensorflow library.
Special cases:
@@ -231,7 +231,7 @@ def list_source_files_against_dump(dump,
for file_path in path_to_node_names:
output.append((
file_path,
guess_is_tensorflow_py_library(file_path),
_guess_is_tensorflow_py_library(file_path),
len(path_to_node_names.get(file_path, {})),
len(path_to_tensor_names.get(file_path, {})),
path_to_num_dumps.get(file_path, 0),

View File

@@ -57,20 +57,20 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
def testUnitTestFileReturnsFalse(self):
self.assertFalse(
source_utils.guess_is_tensorflow_py_library(self.curr_file_path))
source_utils._guess_is_tensorflow_py_library(self.curr_file_path))
def testSourceUtilModuleReturnsTrue(self):
self.assertTrue(
source_utils.guess_is_tensorflow_py_library(source_utils.__file__))
source_utils._guess_is_tensorflow_py_library(source_utils.__file__))
def testFileInPythonKernelsPathReturnsTrue(self):
x = constant_op.constant(42.0, name="x")
self.assertTrue(
source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0]))
source_utils._guess_is_tensorflow_py_library(x.op.traceback[-1][0]))
def testNonPythonFileRaisesException(self):
with self.assertRaisesRegexp(ValueError, r"is not a Python source file"):
source_utils.guess_is_tensorflow_py_library(
source_utils._guess_is_tensorflow_py_library(
os.path.join(os.path.dirname(self.curr_file_path), "foo.cc"))