From 0a583aae32d62ac66be962694dcedc0193e8bffa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Apr 2017 13:52:13 -0800 Subject: [PATCH] Automated rollback of change 154220704 Change: 154225030 --- tensorflow/python/debug/BUILD | 35 -- tensorflow/python/debug/cli/cli_shared.py | 10 - .../python/debug/cli/cli_shared_test.py | 15 - tensorflow/python/debug/cli/command_parser.py | 25 +- .../python/debug/cli/command_parser_test.py | 19 - .../python/debug/cli/profile_analyzer_cli.py | 459 ------------------ .../debug/cli/profile_analyzer_cli_test.py | 264 ---------- tensorflow/python/debug/lib/source_utils.py | 4 +- .../python/debug/lib/source_utils_test.py | 8 +- 9 files changed, 9 insertions(+), 830 deletions(-) delete mode 100644 tensorflow/python/debug/cli/profile_analyzer_cli.py delete mode 100644 tensorflow/python/debug/cli/profile_analyzer_cli_test.py diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index e738c86a1f8..f7e17f1c53d 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -147,22 +147,6 @@ py_library( ], ) -py_library( - name = "profile_analyzer_cli", - srcs = ["cli/profile_analyzer_cli.py"], - srcs_version = "PY2AND3", - deps = [ - ":cli_shared", - ":command_parser", - ":debug_data", - ":debugger_cli_common", - ":source_utils", - ":ui_factory", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - py_library( name = "stepper_cli", srcs = ["cli/stepper_cli.py"], @@ -257,7 +241,6 @@ py_library( ":debug_data", ":debugger_cli_common", ":framework", - ":profile_analyzer_cli", ":stepper_cli", ":ui_factory", ], @@ -623,24 +606,6 @@ cuda_py_test( ], ) -py_test( - name = "profile_analyzer_cli_test", - size = "small", - srcs = [ - "cli/profile_analyzer_cli_test.py", - ], - deps = [ - ":command_parser", - ":profile_analyzer_cli", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - cuda_py_test( name = "stepper_cli_test", size = "small", diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index 3deb6dbad65..8ff09167614 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -17,8 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as np import six @@ -75,14 +73,6 @@ def bytes_to_readable_str(num_bytes, include_b=False): return result -def time_to_readable_str(value): - if not value: - return "0" - suffixes = ["us", "ms", "s"] - order = min(len(suffixes) - 1, int(math.log(value, 10) / 3)) - return "{:.3g}{}".format(value / math.pow(10.0, 3*order), suffixes[order]) - - def parse_ranges_highlight(ranges_string): """Process ranges highlight string. diff --git a/tensorflow/python/debug/cli/cli_shared_test.py b/tensorflow/python/debug/cli/cli_shared_test.py index fde1d66998f..1ef3c342546 100644 --- a/tensorflow/python/debug/cli/cli_shared_test.py +++ b/tensorflow/python/debug/cli/cli_shared_test.py @@ -70,21 +70,6 @@ class BytesToReadableStrTest(test_util.TensorFlowTestCase): 1024**3, include_b=True)) -class TimeToReadableStrTest(test_util.TensorFlowTestCase): - - def testNoneTimeWorks(self): - self.assertEqual("0", cli_shared.time_to_readable_str(None)) - - def testMicrosecondsTime(self): - self.assertEqual("40us", cli_shared.time_to_readable_str(40)) - - def testMillisecondTime(self): - self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3)) - - def testSecondTime(self): - self.assertEqual("40s", cli_shared.time_to_readable_str(40e6)) - - class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase): def setUp(self): diff --git a/tensorflow/python/debug/cli/command_parser.py b/tensorflow/python/debug/cli/command_parser.py index 143c1045199..a71982f86a6 100644 --- a/tensorflow/python/debug/cli/command_parser.py +++ b/tensorflow/python/debug/cli/command_parser.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import ast +from collections import namedtuple import re import sys @@ -28,28 +29,8 @@ _WHITESPACE_PATTERN = re.compile(r"\s+") _NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?") - -class Interval(object): - """Represents an interval between a start and end value.""" - - def __init__(self, start, start_included, end, end_included): - self.start = start - self.start_included = start_included - self.end = end - self.end_included = end_included - - def contains(self, value): - if value < self.start or value == self.start and not self.start_included: - return False - if value > self.end or value == self.end and not self.end_included: - return False - return True - - def __eq__(self, other): - return (self.start == other.start and - self.start_included == other.start_included and - self.end == other.end and - self.end_included == other.end_included) +Interval = namedtuple("Interval", + ["start", "start_included", "end", "end_included"]) def parse_command(command): diff --git a/tensorflow/python/debug/cli/command_parser_test.py b/tensorflow/python/debug/cli/command_parser_test.py index 1ea890be8c9..ab9b3245dc6 100644 --- a/tensorflow/python/debug/cli/command_parser_test.py +++ b/tensorflow/python/debug/cli/command_parser_test.py @@ -490,25 +490,6 @@ class ParseInterval(test_util.TensorFlowTestCase): "equal to end of interval."): command_parser.parse_memory_interval("[5k, 3k]") - def testIntervalContains(self): - interval = command_parser.Interval( - start=1, start_included=True, end=10, end_included=True) - self.assertTrue(interval.contains(1)) - self.assertTrue(interval.contains(10)) - self.assertTrue(interval.contains(5)) - - interval.start_included = False - self.assertFalse(interval.contains(1)) - self.assertTrue(interval.contains(10)) - - interval.end_included = False - self.assertFalse(interval.contains(1)) - self.assertFalse(interval.contains(10)) - - interval.start_included = True - self.assertTrue(interval.contains(1)) - self.assertFalse(interval.contains(10)) - if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/cli/profile_analyzer_cli.py b/tensorflow/python/debug/cli/profile_analyzer_cli.py deleted file mode 100644 index 42440521eba..00000000000 --- a/tensorflow/python/debug/cli/profile_analyzer_cli.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Formats and displays profiling information.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os -import re - -from tensorflow.python.debug.cli import cli_shared -from tensorflow.python.debug.cli import command_parser -from tensorflow.python.debug.cli import debugger_cli_common -from tensorflow.python.debug.cli import ui_factory -from tensorflow.python.debug.lib import source_utils - - -SORT_OPS_BY_OP_NAME = "node" -SORT_OPS_BY_OP_TIME = "op_time" -SORT_OPS_BY_EXEC_TIME = "exec_time" -SORT_OPS_BY_START_TIME = "start_time" -SORT_OPS_BY_LINE = "line" - - -class ProfileDatum(object): - """Profile data point.""" - - def __init__(self, node_exec_stats, file_line, op_type): - """Constructor. - - Args: - node_exec_stats: `NodeExecStats` proto. - file_line: A `string` formatted as :. - op_type: (string) Operation type. - """ - self.node_exec_stats = node_exec_stats - self.file_line = file_line - self.op_type = op_type - self.op_time = (self.node_exec_stats.op_end_rel_micros - - self.node_exec_stats.op_start_rel_micros) - - @property - def exec_time(self): - """Measures compute function exection time plus pre- and post-processing.""" - return self.node_exec_stats.all_end_rel_micros - - -class ProfileDataTableView(object): - """Table View of profiling data.""" - - def __init__(self, profile_datum_list): - """Constructor. - - Args: - profile_datum_list: List of `ProfileDatum` objects. - """ - self._profile_datum_list = profile_datum_list - self.formatted_op_time = [ - cli_shared.time_to_readable_str(datum.op_time) - for datum in profile_datum_list] - self.formatted_exec_time = [ - cli_shared.time_to_readable_str( - datum.node_exec_stats.all_end_rel_micros) - for datum in profile_datum_list] - self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TIME, - SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE] - - def value(self, row, col): - if col == 0: - return self._profile_datum_list[row].node_exec_stats.node_name - elif col == 1: - return self.formatted_op_time[row] - elif col == 2: - return self.formatted_exec_time[row] - elif col == 3: - return self._profile_datum_list[row].file_line - else: - raise IndexError("Invalid column index %d." % col) - - def row_count(self): - return len(self._profile_datum_list) - - def column_count(self): - return 4 - - def column_names(self): - return ["Node", "Op Time", "Exec Time", "Filename:Lineno(function)"] - - def column_sort_id(self, col): - return self._column_sort_ids[col] - - -def _list_profile_filter( - profile_datum, node_name_regex, file_name_regex, op_type_regex, - op_time_interval, exec_time_interval): - """Filter function for list_profile command. - - Args: - profile_datum: A `ProfileDatum` object. - node_name_regex: Regular expression pattern object to filter by name. - file_name_regex: Regular expression pattern object to filter by file. - op_type_regex: Regular expression pattern object to filter by op type. - op_time_interval: `Interval` for filtering op time. - exec_time_interval: `Interval` for filtering exec time. - - Returns: - True if profile_datum should be included. - """ - if not node_name_regex.match( - profile_datum.node_exec_stats.node_name): - return False - if profile_datum.file_line is not None and not file_name_regex.match( - profile_datum.file_line): - return False - if profile_datum.op_type is not None and not op_type_regex.match( - profile_datum.op_type): - return False - if op_time_interval is not None and not op_time_interval.contains( - profile_datum.op_time): - return False - if exec_time_interval and not exec_time_interval.contains( - profile_datum.node_exec_stats.all_end_rel_micros): - return False - return True - - -def _list_profile_sort_key(profile_datum, sort_by): - """Get a profile_datum property to sort by in list_profile command. - - Args: - profile_datum: A `ProfileDatum` object. - sort_by: (string) indicates a value to sort by. - Must be one of SORT_BY* constants. - - Returns: - profile_datum property to sort by. - """ - if sort_by == SORT_OPS_BY_OP_NAME: - return profile_datum.node_exec_stats.node_name - elif sort_by == SORT_OPS_BY_LINE: - return profile_datum.file_line - elif sort_by == SORT_OPS_BY_OP_TIME: - return profile_datum.op_time - elif sort_by == SORT_OPS_BY_EXEC_TIME: - return profile_datum.node_exec_stats.all_end_rel_micros - else: # sort by start time - return profile_datum.node_exec_stats.all_start_micros - - -class ProfileAnalyzer(object): - """Analyzer for profiling data.""" - - def __init__(self, graph, run_metadata): - """ProfileAnalyzer constructor. - - Args: - graph: (tf.Graph) Python graph object. - run_metadata: A `RunMetadata` protobuf object. - - Raises: - ValueError: If run_metadata is None. - """ - self._graph = graph - if not run_metadata: - raise ValueError("No RunMetadata passed for profile analysis.") - self._run_metadata = run_metadata - self._arg_parsers = {} - ap = argparse.ArgumentParser( - description="List nodes profile information.", - usage=argparse.SUPPRESS) - ap.add_argument( - "-d", - "--device_name_filter", - dest="device_name_filter", - type=str, - default="", - help="filter device name by regex.") - ap.add_argument( - "-n", - "--node_name_filter", - dest="node_name_filter", - type=str, - default="", - help="filter node name by regex.") - ap.add_argument( - "-t", - "--op_type_filter", - dest="op_type_filter", - type=str, - default="", - help="filter op type by regex.") - # TODO(annarev): allow file filtering at non-stack top position. - ap.add_argument( - "-f", - "--file_name_filter", - dest="file_name_filter", - type=str, - default="", - help="filter by file name at the top position of node's creation " - "stack that does not belong to TensorFlow library.") - ap.add_argument( - "-e", - "--execution_time", - dest="execution_time", - type=str, - default="", - help="Filter by execution time interval " - "(includes compute plus pre- and post -processing time). " - "Supported units are s, ms and us (default). " - "E.g. -e >100s, -e <100, -e [100us,1000ms]") - ap.add_argument( - "-o", - "--op_time", - dest="op_time", - type=str, - default="", - help="Filter by op time interval (only includes compute time). " - "Supported units are s, ms and us (default). " - "E.g. -e >100s, -e <100, -e [100us,1000ms]") - ap.add_argument( - "-s", - "--sort_by", - dest="sort_by", - type=str, - default=SORT_OPS_BY_START_TIME, - help=("the field to sort the data by: (%s | %s | %s | %s | %s)" % - (SORT_OPS_BY_OP_NAME, SORT_OPS_BY_START_TIME, - SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE))) - ap.add_argument( - "-r", - "--reverse", - dest="reverse", - action="store_true", - help="sort the data in reverse (descending) order") - - self._arg_parsers["list_profile"] = ap - - def list_profile(self, args, screen_info=None): - """Command handler for list_profile. - - List per-operation profile information. - - Args: - args: Command-line arguments, excluding the command prefix, as a list of - str. - screen_info: Optional dict input containing screen information such as - cols. - - Returns: - Output text lines as a RichTextLines object. - """ - del screen_info - - parsed = self._arg_parsers["list_profile"].parse_args(args) - op_time_interval = (command_parser.parse_time_interval(parsed.op_time) - if parsed.op_time else None) - exec_time_interval = ( - command_parser.parse_time_interval(parsed.execution_time) - if parsed.execution_time else None) - node_name_regex = re.compile(parsed.node_name_filter) - file_name_regex = re.compile(parsed.file_name_filter) - op_type_regex = re.compile(parsed.op_type_filter) - - output = debugger_cli_common.RichTextLines([""]) - device_name_regex = re.compile(parsed.device_name_filter) - data_generator = self._get_profile_data_generator() - device_count = len(self._run_metadata.step_stats.dev_stats) - for index in range(device_count): - device_stats = self._run_metadata.step_stats.dev_stats[index] - if device_name_regex.match(device_stats.device): - profile_data = [ - datum for datum in data_generator(device_stats) - if _list_profile_filter( - datum, node_name_regex, file_name_regex, op_type_regex, - op_time_interval, exec_time_interval)] - profile_data = sorted( - profile_data, - key=lambda datum: _list_profile_sort_key(datum, parsed.sort_by), - reverse=parsed.reverse) - output.extend( - self._get_list_profile_lines( - device_stats.device, index, device_count, - profile_data, parsed.sort_by, parsed.reverse)) - return output - - def _get_profile_data_generator(self): - """Get function that generates `ProfileDatum` objects. - - Returns: - A function that generates `ProfileDatum` objects. - """ - node_to_file_line = {} - node_to_op_type = {} - for op in self._graph.get_operations(): - file_line = "" - for trace_entry in reversed(op.traceback): - filepath = trace_entry[0] - file_line = "%s:%d(%s)" % ( - os.path.basename(filepath), trace_entry[1], trace_entry[2]) - if not source_utils.guess_is_tensorflow_py_library(filepath): - break - node_to_file_line[op.name] = file_line - node_to_op_type[op.name] = op.type - - def profile_data_generator(device_step_stats): - for node_stats in device_step_stats.node_stats: - if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK": - continue - yield ProfileDatum( - node_stats, - node_to_file_line.get(node_stats.node_name, ""), - node_to_op_type.get(node_stats.node_name, "")) - return profile_data_generator - - def _get_list_profile_lines( - self, device_name, device_index, device_count, - profile_datum_list, sort_by, sort_reverse): - """Get `RichTextLines` object for list_profile command for a given device. - - Args: - device_name: (string) Device name. - device_index: (int) Device index. - device_count: (int) Number of devices. - profile_datum_list: List of `ProfileDatum` objects. - sort_by: (string) Identifier of column to sort. Sort identifier - must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_EXEC_TIME, - SORT_OPS_BY_MEMORY or SORT_OPS_BY_LINE. - sort_reverse: (bool) Whether to sort in descending instead of default - (ascending) order. - - Returns: - `RichTextLines` object containing a table that displays profiling - information for each op. - """ - profile_data = ProfileDataTableView(profile_datum_list) - - # Calculate total time early to calculate column widths. - total_op_time = sum(datum.op_time for datum in profile_datum_list) - total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros - for datum in profile_datum_list) - device_total_row = [ - "Device Total", cli_shared.time_to_readable_str(total_op_time), - cli_shared.time_to_readable_str(total_exec_time)] - - # Calculate column widths. - column_widths = [ - len(column_name) for column_name in profile_data.column_names()] - for col in range(len(device_total_row)): - column_widths[col] = max(column_widths[col], len(device_total_row[col])) - for col in range(len(column_widths)): - for row in range(profile_data.row_count()): - column_widths[col] = max( - column_widths[col], len(str(profile_data.value(row, col)))) - column_widths[col] += 2 # add margin between columns - - # Add device name. - output = debugger_cli_common.RichTextLines(["-"*80]) - device_row = "Device %d of %d: %s" % ( - device_index + 1, device_count, device_name) - output.extend(debugger_cli_common.RichTextLines([device_row, ""])) - - # Add headers. - base_command = "list_profile" - attr_segs = {0: []} - row = "" - for col in range(profile_data.column_count()): - column_name = profile_data.column_names()[col] - sort_id = profile_data.column_sort_id(col) - command = "%s -s %s" % (base_command, sort_id) - if sort_by == sort_id and not sort_reverse: - command += " -r" - curr_row = ("{:<%d}" % column_widths[col]).format(column_name) - prev_len = len(row) - row += curr_row - attr_segs[0].append( - (prev_len, prev_len + len(column_name), - [debugger_cli_common.MenuItem(None, command), "bold"])) - - output.extend( - debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs)) - - # Add data rows. - for row in range(profile_data.row_count()): - row_str = "" - for col in range(profile_data.column_count()): - row_str += ("{:<%d}" % column_widths[col]).format( - profile_data.value(row, col)) - output.extend(debugger_cli_common.RichTextLines([row_str])) - - # Add stat totals. - row_str = "" - for col in range(len(device_total_row)): - row_str += ("{:<%d}" % column_widths[col]).format(device_total_row[col]) - output.extend(debugger_cli_common.RichTextLines("")) - output.extend(debugger_cli_common.RichTextLines(row_str)) - return output - - def _measure_list_profile_column_widths(self, profile_data): - """Determine the maximum column widths for each data list. - - Args: - profile_data: list of ProfileDatum objects. - - Returns: - List of column widths in the same order as columns in data. - """ - num_columns = len(profile_data.column_names()) - widths = [len(column_name) for column_name in profile_data.column_names()] - for row in range(profile_data.row_count()): - for col in range(num_columns): - widths[col] = max( - widths[col], len(str(profile_data.row_values(row)[col])) + 2) - return widths - - def get_help(self, handler_name): - return self._arg_parsers[handler_name].format_help() - - -def create_profiler_ui(graph, - run_metadata, - ui_type="curses", - on_ui_exit=None): - """Create an instance of CursesUI based on a `tf.Graph` and `RunMetadata`. - - Args: - graph: Python `Graph` object. - run_metadata: A `RunMetadata` protobuf object. - ui_type: (str) requested UI type, e.g., "curses", "readline". - on_ui_exit: (`Callable`) the callback to be called when the UI exits. - - Returns: - (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer - commands and tab-completions registered. - """ - - analyzer = ProfileAnalyzer(graph, run_metadata) - - cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit) - cli.register_command_handler( - "list_profile", - analyzer.list_profile, - analyzer.get_help("list_profile"), - prefix_aliases=["lp"]) - - return cli diff --git a/tensorflow/python/debug/cli/profile_analyzer_cli_test.py b/tensorflow/python/debug/cli/profile_analyzer_cli_test.py deleted file mode 100644 index 7b34d87c991..00000000000 --- a/tensorflow/python/debug/cli/profile_analyzer_cli_test.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for profile_analyzer_cli.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import re - -from tensorflow.core.framework import step_stats_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.debug.cli import profile_analyzer_cli -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import googletest -from tensorflow.python.platform import test - - -class ProfileAnalyzerTest(test_util.TensorFlowTestCase): - - def testNodeInfoEmpty(self): - graph = ops.Graph() - run_metadata = config_pb2.RunMetadata() - - prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) - prof_output = prof_analyzer.list_profile([]).lines - self.assertEquals([""], prof_output) - - def testSingleDevice(self): - node1 = step_stats_pb2.NodeExecStats( - node_name="Add/123", - op_start_rel_micros=3, - op_end_rel_micros=5, - all_end_rel_micros=4) - - node2 = step_stats_pb2.NodeExecStats( - node_name="Mul/456", - op_start_rel_micros=1, - op_end_rel_micros=2, - all_end_rel_micros=3) - - run_metadata = config_pb2.RunMetadata() - device1 = run_metadata.step_stats.dev_stats.add() - device1.device = "deviceA" - device1.node_stats.extend([node1, node2]) - - graph = test.mock.MagicMock() - op1 = test.mock.MagicMock() - op1.name = "Add/123" - op1.traceback = [("a/b/file1", 10, "some_var")] - op1.type = "add" - op2 = test.mock.MagicMock() - op2.name = "Mul/456" - op2.traceback = [("a/b/file1", 11, "some_var")] - op2.type = "mul" - graph.get_operations.return_value = [op1, op2] - - prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) - prof_output = prof_analyzer.list_profile([]).lines - - self._assertAtLeastOneLineMatches(r"Device 1 of 1: deviceA", prof_output) - self._assertAtLeastOneLineMatches(r"^Add/123.*2us.*4us", prof_output) - self._assertAtLeastOneLineMatches(r"^Mul/456.*1us.*3us", prof_output) - - def testMultipleDevices(self): - node1 = step_stats_pb2.NodeExecStats( - node_name="Add/123", - op_start_rel_micros=3, - op_end_rel_micros=5, - all_end_rel_micros=3) - - run_metadata = config_pb2.RunMetadata() - device1 = run_metadata.step_stats.dev_stats.add() - device1.device = "deviceA" - device1.node_stats.extend([node1]) - - device2 = run_metadata.step_stats.dev_stats.add() - device2.device = "deviceB" - device2.node_stats.extend([node1]) - - graph = test.mock.MagicMock() - op = test.mock.MagicMock() - op.name = "Add/123" - op.traceback = [("a/b/file1", 10, "some_var")] - op.type = "abc" - graph.get_operations.return_value = [op] - - prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) - prof_output = prof_analyzer.list_profile([]).lines - - self._assertAtLeastOneLineMatches(r"Device 1 of 2: deviceA", prof_output) - self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output) - - # Try filtering by device. - prof_output = prof_analyzer.list_profile(["-d", "deviceB"]).lines - self._assertAtLeastOneLineMatches(r"Device 2 of 2: deviceB", prof_output) - self._assertNoLinesMatch(r"Device 1 of 2: deviceA", prof_output) - - def testWithSession(self): - options = config_pb2.RunOptions() - options.trace_level = config_pb2.RunOptions.FULL_TRACE - run_metadata = config_pb2.RunMetadata() - - with session.Session() as sess: - a = constant_op.constant([1, 2, 3]) - b = constant_op.constant([2, 2, 1]) - result = math_ops.add(a, b) - - sess.run(result, options=options, run_metadata=run_metadata) - - prof_analyzer = profile_analyzer_cli.ProfileAnalyzer( - sess.graph, run_metadata) - prof_output = prof_analyzer.list_profile([]).lines - - self._assertAtLeastOneLineMatches("Device 1 of 1:", prof_output) - expected_headers = [ - "Node", "Op Time", "Exec Time", r"Filename:Lineno\(function\)"] - self._assertAtLeastOneLineMatches( - ".*".join(expected_headers), prof_output) - self._assertAtLeastOneLineMatches(r"^Add/", prof_output) - self._assertAtLeastOneLineMatches(r"Device Total", prof_output) - - def testSorting(self): - node1 = step_stats_pb2.NodeExecStats( - node_name="Add/123", - all_start_micros=123, - op_start_rel_micros=3, - op_end_rel_micros=5, - all_end_rel_micros=4) - - node2 = step_stats_pb2.NodeExecStats( - node_name="Mul/456", - all_start_micros=122, - op_start_rel_micros=1, - op_end_rel_micros=2, - all_end_rel_micros=5) - - run_metadata = config_pb2.RunMetadata() - device1 = run_metadata.step_stats.dev_stats.add() - device1.device = "deviceA" - device1.node_stats.extend([node1, node2]) - - graph = test.mock.MagicMock() - op1 = test.mock.MagicMock() - op1.name = "Add/123" - op1.traceback = [("a/b/file2", 10, "some_var")] - op1.type = "add" - op2 = test.mock.MagicMock() - op2.name = "Mul/456" - op2.traceback = [("a/b/file1", 11, "some_var")] - op2.type = "mul" - graph.get_operations.return_value = [op1, op2] - - prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) - - # Default sort by start time (i.e. all_start_micros). - prof_output = prof_analyzer.list_profile([]).lines - self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123") - # Default sort in reverse. - prof_output = prof_analyzer.list_profile(["-r"]).lines - self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456") - # Sort by name. - prof_output = prof_analyzer.list_profile(["-s", "node"]).lines - self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456") - # Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros). - prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines - self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123") - # Sort by exec time (i.e. all_end_rel_micros). - prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines - self.assertRegexpMatches("".join(prof_output), r"Add/123.*Mul/456") - # Sort by line number. - prof_output = prof_analyzer.list_profile(["-s", "line"]).lines - self.assertRegexpMatches("".join(prof_output), r"Mul/456.*Add/123") - - def testFiltering(self): - node1 = step_stats_pb2.NodeExecStats( - node_name="Add/123", - all_start_micros=123, - op_start_rel_micros=3, - op_end_rel_micros=5, - all_end_rel_micros=4) - - node2 = step_stats_pb2.NodeExecStats( - node_name="Mul/456", - all_start_micros=122, - op_start_rel_micros=1, - op_end_rel_micros=2, - all_end_rel_micros=5) - - run_metadata = config_pb2.RunMetadata() - device1 = run_metadata.step_stats.dev_stats.add() - device1.device = "deviceA" - device1.node_stats.extend([node1, node2]) - - graph = test.mock.MagicMock() - op1 = test.mock.MagicMock() - op1.name = "Add/123" - op1.traceback = [("a/b/file2", 10, "some_var")] - op1.type = "add" - op2 = test.mock.MagicMock() - op2.name = "Mul/456" - op2.traceback = [("a/b/file1", 11, "some_var")] - op2.type = "mul" - graph.get_operations.return_value = [op1, op2] - - prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata) - - # Filter by name - prof_output = prof_analyzer.list_profile(["-n", "Add"]).lines - self._assertAtLeastOneLineMatches(r"Add/123", prof_output) - self._assertNoLinesMatch(r"Mul/456", prof_output) - # Filter by op_type - prof_output = prof_analyzer.list_profile(["-t", "mul"]).lines - self._assertAtLeastOneLineMatches(r"Mul/456", prof_output) - self._assertNoLinesMatch(r"Add/123", prof_output) - # Filter by file name. - prof_output = prof_analyzer.list_profile(["-f", "file2"]).lines - self._assertAtLeastOneLineMatches(r"Add/123", prof_output) - self._assertNoLinesMatch(r"Mul/456", prof_output) - # Fitler by execution time. - prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines - self._assertAtLeastOneLineMatches(r"Mul/456", prof_output) - self._assertNoLinesMatch(r"Add/123", prof_output) - # Fitler by op time. - prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines - self._assertAtLeastOneLineMatches(r"Add/123", prof_output) - self._assertNoLinesMatch(r"Mul/456", prof_output) - - def _atLeastOneLineMatches(self, pattern, lines): - pattern_re = re.compile(pattern) - for line in lines: - if pattern_re.match(line): - return True - return False - - def _assertAtLeastOneLineMatches(self, pattern, lines): - if not self._atLeastOneLineMatches(pattern, lines): - raise AssertionError( - "%s does not match any line in %s." % (pattern, str(lines))) - - def _assertNoLinesMatch(self, pattern, lines): - if self._atLeastOneLineMatches(pattern, lines): - raise AssertionError( - "%s matched at least one line in %s." % (pattern, str(lines))) - - -if __name__ == "__main__": - googletest.main() diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py index f610d05b83c..580bdc054bd 100644 --- a/tensorflow/python/debug/lib/source_utils.py +++ b/tensorflow/python/debug/lib/source_utils.py @@ -44,7 +44,7 @@ def _convert_watch_key_to_tensor_name(watch_key): return watch_key[:watch_key.rfind(":")] -def guess_is_tensorflow_py_library(py_file_path): +def _guess_is_tensorflow_py_library(py_file_path): """Guess whether a Python source file is a part of the tensorflow library. Special cases: @@ -231,7 +231,7 @@ def list_source_files_against_dump(dump, for file_path in path_to_node_names: output.append(( file_path, - guess_is_tensorflow_py_library(file_path), + _guess_is_tensorflow_py_library(file_path), len(path_to_node_names.get(file_path, {})), len(path_to_tensor_names.get(file_path, {})), path_to_num_dumps.get(file_path, 0), diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index f6195b6a5d1..a4fb0d99109 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -57,20 +57,20 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): def testUnitTestFileReturnsFalse(self): self.assertFalse( - source_utils.guess_is_tensorflow_py_library(self.curr_file_path)) + source_utils._guess_is_tensorflow_py_library(self.curr_file_path)) def testSourceUtilModuleReturnsTrue(self): self.assertTrue( - source_utils.guess_is_tensorflow_py_library(source_utils.__file__)) + source_utils._guess_is_tensorflow_py_library(source_utils.__file__)) def testFileInPythonKernelsPathReturnsTrue(self): x = constant_op.constant(42.0, name="x") self.assertTrue( - source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0])) + source_utils._guess_is_tensorflow_py_library(x.op.traceback[-1][0])) def testNonPythonFileRaisesException(self): with self.assertRaisesRegexp(ValueError, r"is not a Python source file"): - source_utils.guess_is_tensorflow_py_library( + source_utils._guess_is_tensorflow_py_library( os.path.join(os.path.dirname(self.curr_file_path), "foo.cc"))