tfdbg CLI: add initial support for command "run -p": profiler mode
Accompanying changes: * In list_profile / lp output, let the time unit be consistent across all ops, so that the results are easier to read. * Add the --time_unit option flag to list_profile / lp. * Add start time to the lp output table; allow sorting by start time. PiperOrigin-RevId: 155690128
This commit is contained in:
parent
e09b0b6ebf
commit
c4c22e7568
@ -44,6 +44,11 @@ COLOR_RED = "red"
|
||||
COLOR_WHITE = "white"
|
||||
COLOR_YELLOW = "yellow"
|
||||
|
||||
TIME_UNIT_US = "us"
|
||||
TIME_UNIT_MS = "ms"
|
||||
TIME_UNIT_S = "s"
|
||||
TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S]
|
||||
|
||||
|
||||
def bytes_to_readable_str(num_bytes, include_b=False):
|
||||
"""Generate a human-readable string representing number of bytes.
|
||||
@ -75,12 +80,32 @@ def bytes_to_readable_str(num_bytes, include_b=False):
|
||||
return result
|
||||
|
||||
|
||||
def time_to_readable_str(value_us, force_time_unit=None):
  """Convert time value to human-readable string.

  Args:
    value_us: time value in microseconds.
    force_time_unit: force the output to use the specified time unit. Must be
      in TIME_UNITS.

  Returns:
    Human-readable string representation of the time value. A zero (or
    otherwise falsy) value is rendered as "0" with no unit suffix, and this
    happens before `force_time_unit` is inspected.

  Raises:
    ValueError: if force_time_unit value is not in TIME_UNITS (checked only
      for non-zero values, see Returns).
  """
  if not value_us:
    return "0"

  if force_time_unit:
    if force_time_unit not in TIME_UNITS:
      raise ValueError("Invalid time unit: %s" % force_time_unit)
    order = TIME_UNITS.index(force_time_unit)
    time_unit = force_time_unit
    # Forced units may be far from the value's magnitude, so keep up to ten
    # significant digits instead of three.
    return "{:.10g}{}".format(value_us / math.pow(10.0, 3*order), time_unit)
  else:
    # Use math.log10 rather than math.log(x, 10): the latter can come out
    # just below the integer at powers of ten (e.g. 2.9999999999999996 for
    # 1000), which would select a unit one step too small.
    order = int(math.log10(value_us) / 3)
    # Clamp to the valid index range. The lower bound keeps sub-microsecond
    # values in "us" instead of indexing TIME_UNITS with a negative order.
    order = max(0, min(len(TIME_UNITS) - 1, order))
    time_unit = TIME_UNITS[order]
    return "{:.3g}{}".format(value_us / math.pow(10.0, 3*order), time_unit)
|
||||
|
||||
|
||||
def parse_ranges_highlight(ranges_string):
|
||||
|
@ -84,6 +84,26 @@ class TimeToReadableStrTest(test_util.TensorFlowTestCase):
|
||||
def testSecondTime(self):
|
||||
self.assertEqual("40s", cli_shared.time_to_readable_str(40e6))
|
||||
|
||||
def testForceTimeUnit(self):
  """Checks forced-unit formatting, the zero case and invalid units."""
  self.assertEqual("40s",
                   cli_shared.time_to_readable_str(
                       40e6, force_time_unit=cli_shared.TIME_UNIT_S))
  self.assertEqual("40000ms",
                   cli_shared.time_to_readable_str(
                       40e6, force_time_unit=cli_shared.TIME_UNIT_MS))
  self.assertEqual("40000000us",
                   cli_shared.time_to_readable_str(
                       40e6, force_time_unit=cli_shared.TIME_UNIT_US))
  # A value much smaller than the forced unit falls back to exponent form.
  self.assertEqual("4e-05s",
                   cli_shared.time_to_readable_str(
                       40, force_time_unit=cli_shared.TIME_UNIT_S))
  # Zero is rendered without any unit suffix.
  self.assertEqual("0",
                   cli_shared.time_to_readable_str(
                       0, force_time_unit=cli_shared.TIME_UNIT_S))

  with self.assertRaisesRegexp(ValueError, r"Invalid time unit: ks"):
    cli_shared.time_to_readable_str(100, force_time_unit="ks")
|
||||
|
||||
|
||||
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -50,6 +50,7 @@ class ProfileDatum(object):
|
||||
self.node_exec_stats = node_exec_stats
|
||||
self.file_line = file_line
|
||||
self.op_type = op_type
|
||||
self.start_time = self.node_exec_stats.all_start_micros
|
||||
self.op_time = (self.node_exec_stats.op_end_rel_micros -
|
||||
self.node_exec_stats.op_start_rel_micros)
|
||||
|
||||
@ -62,31 +63,45 @@ class ProfileDatum(object):
|
||||
class ProfileDataTableView(object):
|
||||
"""Table View of profiling data."""
|
||||
|
||||
def __init__(self, profile_datum_list, time_unit=cli_shared.TIME_UNIT_US):
  """Constructor.

  Args:
    profile_datum_list: List of `ProfileDatum` objects.
    time_unit: must be in cli_shared.TIME_UNITS.
  """
  self._profile_datum_list = profile_datum_list
  # Start times are stored as-is (no unit conversion); the column header
  # below labels them as microseconds.
  self.formatted_start_time = [
      datum.start_time for datum in profile_datum_list]
  # Op and exec times share one forced unit so every row is comparable.
  self.formatted_op_time = [
      cli_shared.time_to_readable_str(datum.op_time,
                                      force_time_unit=time_unit)
      for datum in profile_datum_list]
  self.formatted_exec_time = [
      cli_shared.time_to_readable_str(
          datum.node_exec_stats.all_end_rel_micros,
          force_time_unit=time_unit)
      for datum in profile_datum_list]

  # _column_names and _column_sort_ids are parallel lists: index i of one
  # describes the same table column as index i of the other.
  self._column_names = ["Node",
                        "Start Time (us)",
                        "Op Time (%s)" % time_unit,
                        "Exec Time (%s)" % time_unit,
                        "Filename:Lineno(function)"]
  self._column_sort_ids = [SORT_OPS_BY_OP_NAME, SORT_OPS_BY_START_TIME,
                           SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME,
                           SORT_OPS_BY_LINE]
|
||||
|
||||
def value(self, row, col):
  """Return the table cell at the given row and column.

  Column layout: 0 = node name, 1 = start time, 2 = op time, 3 = exec time,
  4 = file/line string.

  Args:
    row: index into the profile datum list.
    col: column index, 0 through 4.

  Returns:
    The cell content for (row, col).

  Raises:
    IndexError: if `col` is not one of the five known columns.
  """
  if col == 0:
    return self._profile_datum_list[row].node_exec_stats.node_name
  elif col == 1:
    return self.formatted_start_time[row]
  elif col == 2:
    return self.formatted_op_time[row]
  elif col == 3:
    return self.formatted_exec_time[row]
  elif col == 4:
    return self._profile_datum_list[row].file_line
  else:
    raise IndexError("Invalid column index %d." % col)
|
||||
@ -95,10 +110,10 @@ class ProfileDataTableView(object):
|
||||
return len(self._profile_datum_list)
|
||||
|
||||
def column_count(self):
  """Return the number of table columns, derived from the column names."""
  return len(self._column_names)
|
||||
|
||||
def column_names(self):
  """Return the list of column header strings built by the constructor."""
  return self._column_names
|
||||
|
||||
def column_sort_id(self, col):
  """Return the sort identifier stored for column index `col`."""
  return self._column_sort_ids[col]
|
||||
@ -246,6 +261,12 @@ class ProfileAnalyzer(object):
|
||||
dest="reverse",
|
||||
action="store_true",
|
||||
help="sort the data in reverse (descending) order")
|
||||
ap.add_argument(
|
||||
"--time_unit",
|
||||
dest="time_unit",
|
||||
type=str,
|
||||
default=cli_shared.TIME_UNIT_US,
|
||||
help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")
|
||||
|
||||
self._arg_parsers["list_profile"] = ap
|
||||
|
||||
@ -294,7 +315,7 @@ class ProfileAnalyzer(object):
|
||||
output.extend(
|
||||
self._get_list_profile_lines(
|
||||
device_stats.device, index, device_count,
|
||||
profile_data, parsed.sort_by, parsed.reverse))
|
||||
profile_data, parsed.sort_by, parsed.reverse, parsed.time_unit))
|
||||
return output
|
||||
|
||||
def _get_profile_data_generator(self):
|
||||
@ -328,7 +349,7 @@ class ProfileAnalyzer(object):
|
||||
|
||||
def _get_list_profile_lines(
|
||||
self, device_name, device_index, device_count,
|
||||
profile_datum_list, sort_by, sort_reverse):
|
||||
profile_datum_list, sort_by, sort_reverse, time_unit):
|
||||
"""Get `RichTextLines` object for list_profile command for a given device.
|
||||
|
||||
Args:
|
||||
@ -341,20 +362,24 @@ class ProfileAnalyzer(object):
|
||||
SORT_OPS_BY_MEMORY or SORT_OPS_BY_LINE.
|
||||
sort_reverse: (bool) Whether to sort in descending instead of default
|
||||
(ascending) order.
|
||||
time_unit: time unit, must be in cli_shared.TIME_UNITS.
|
||||
|
||||
Returns:
|
||||
`RichTextLines` object containing a table that displays profiling
|
||||
information for each op.
|
||||
"""
|
||||
profile_data = ProfileDataTableView(profile_datum_list)
|
||||
profile_data = ProfileDataTableView(profile_datum_list, time_unit=time_unit)
|
||||
|
||||
# Calculate total time early to calculate column widths.
|
||||
total_op_time = sum(datum.op_time for datum in profile_datum_list)
|
||||
total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
|
||||
for datum in profile_datum_list)
|
||||
device_total_row = [
|
||||
"Device Total", cli_shared.time_to_readable_str(total_op_time),
|
||||
cli_shared.time_to_readable_str(total_exec_time)]
|
||||
"Device Total", "",
|
||||
cli_shared.time_to_readable_str(total_op_time,
|
||||
force_time_unit=time_unit),
|
||||
cli_shared.time_to_readable_str(total_exec_time,
|
||||
force_time_unit=time_unit)]
|
||||
|
||||
# Calculate column widths.
|
||||
column_widths = [
|
||||
|
@ -130,7 +130,8 @@ class ProfileAnalyzerTest(test_util.TensorFlowTestCase):
|
||||
|
||||
self._assertAtLeastOneLineMatches("Device 1 of", prof_output)
|
||||
expected_headers = [
|
||||
"Node", "Op Time", "Exec Time", r"Filename:Lineno\(function\)"]
|
||||
"Node", r"Start Time \(us\)", r"Op Time \(.*\)", r"Exec Time \(.*\)",
|
||||
r"Filename:Lineno\(function\)"]
|
||||
self._assertAtLeastOneLineMatches(
|
||||
".*".join(expected_headers), prof_output)
|
||||
self._assertAtLeastOneLineMatches(r"^Add/", prof_output)
|
||||
@ -242,10 +243,49 @@ class ProfileAnalyzerTest(test_util.TensorFlowTestCase):
|
||||
self._assertAtLeastOneLineMatches(r"Add/123", prof_output)
|
||||
self._assertNoLinesMatch(r"Mul/456", prof_output)
|
||||
|
||||
def testSpecifyingTimeUnit(self):
  """Checks that list_profile renders times in the unit given by --time_unit."""
  node1 = step_stats_pb2.NodeExecStats(
      node_name="Add/123",
      all_start_micros=123,
      op_start_rel_micros=3,
      op_end_rel_micros=5,
      all_end_rel_micros=4)

  node2 = step_stats_pb2.NodeExecStats(
      node_name="Mul/456",
      all_start_micros=122,
      op_start_rel_micros=1,
      op_end_rel_micros=2,
      all_end_rel_micros=5)

  run_metadata = config_pb2.RunMetadata()
  device1 = run_metadata.step_stats.dev_stats.add()
  device1.device = "deviceA"
  device1.node_stats.extend([node1, node2])

  graph = test.mock.MagicMock()
  op1 = test.mock.MagicMock()
  op1.name = "Add/123"
  op1.traceback = [("a/b/file2", 10, "some_var")]
  op1.type = "add"
  op2 = test.mock.MagicMock()
  op2.name = "Mul/456"
  op2.traceback = [("a/b/file1", 11, "some_var")]
  op2.type = "mul"
  graph.get_operations.return_value = [op1, op2]

  prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)

  # Force time unit.
  prof_output = prof_analyzer.list_profile(["--time_unit", "ms"]).lines
  # Add/123 op time: 5 - 3 = 2 us.
  self._assertAtLeastOneLineMatches(r"Add/123.*0\.002ms", prof_output)
  # Mul/456 exec time: all_end_rel_micros = 5 us.
  self._assertAtLeastOneLineMatches(r"Mul/456.*0\.005ms", prof_output)
  # Device total exec time: 4 + 5 = 9 us.
  self._assertAtLeastOneLineMatches(r"Device Total.*0\.009ms", prof_output)
|
||||
|
||||
def _atLeastOneLineMatches(self, pattern, lines):
|
||||
pattern_re = re.compile(pattern)
|
||||
for line in lines:
|
||||
if pattern_re.match(line):
|
||||
if pattern_re.search(line):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -76,7 +76,8 @@ EOF
|
||||
CUSTOM_DUMP_ROOT=$(mktemp -d)
|
||||
mkdir -p ${CUSTOM_DUMP_ROOT}
|
||||
|
||||
cat << EOF | ${DEBUG_TFLEARN_IRIS_BIN} --debug --fake_data --train_steps=1 --dump_root="${CUSTOM_DUMP_ROOT}" --ui_type=readline
|
||||
cat << EOF | ${DEBUG_TFLEARN_IRIS_BIN} --debug --fake_data --train_steps=2 --dump_root="${CUSTOM_DUMP_ROOT}" --ui_type=readline
|
||||
run -p
|
||||
run -f has_inf_or_nan
|
||||
EOF
|
||||
|
||||
|
@ -219,6 +219,9 @@ class OnRunStartAction(object):
|
||||
# Run once with debug tensor-watching.
|
||||
DEBUG_RUN = "debug_run"
|
||||
|
||||
# Run once with profiler.
|
||||
PROFILE_RUN = "profile_run"
|
||||
|
||||
# Run without debug tensor-watching.
|
||||
NON_DEBUG_RUN = "non_debug_run"
|
||||
|
||||
@ -425,7 +428,7 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
decorated_run_options = options or config_pb2.RunOptions()
|
||||
run_metadata = run_metadata or config_pb2.RunMetadata()
|
||||
|
||||
self._decorate_run_options(
|
||||
self._decorate_run_options_for_debug(
|
||||
decorated_run_options,
|
||||
run_start_resp.debug_urls,
|
||||
debug_ops=run_start_resp.debug_ops,
|
||||
@ -454,6 +457,19 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
client_graph_def=self._sess.graph.as_graph_def(),
|
||||
tf_error=tf_error)
|
||||
|
||||
elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
|
||||
decorated_run_options = options or config_pb2.RunOptions()
|
||||
run_metadata = run_metadata or config_pb2.RunMetadata()
|
||||
self._decorate_run_options_for_profile(decorated_run_options)
|
||||
retvals = self._sess.run(fetches,
|
||||
feed_dict=feed_dict,
|
||||
options=decorated_run_options,
|
||||
run_metadata=run_metadata)
|
||||
run_end_req = OnRunEndRequest(
|
||||
run_start_resp.action,
|
||||
run_metadata=run_metadata,
|
||||
client_graph_def=self._sess.graph.as_graph_def())
|
||||
|
||||
elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
|
||||
run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
|
||||
if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
|
||||
@ -496,14 +512,15 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
raise NotImplementedError(
|
||||
"partial_run is not implemented for debug-wrapper sessions.")
|
||||
|
||||
def _decorate_run_options(self,
|
||||
run_options,
|
||||
debug_urls,
|
||||
debug_ops="DebugIdentity",
|
||||
node_name_regex_whitelist=None,
|
||||
op_type_regex_whitelist=None,
|
||||
tensor_dtype_regex_whitelist=None,
|
||||
tolerate_debug_op_creation_failures=False):
|
||||
def _decorate_run_options_for_debug(
|
||||
self,
|
||||
run_options,
|
||||
debug_urls,
|
||||
debug_ops="DebugIdentity",
|
||||
node_name_regex_whitelist=None,
|
||||
op_type_regex_whitelist=None,
|
||||
tensor_dtype_regex_whitelist=None,
|
||||
tolerate_debug_op_creation_failures=False):
|
||||
"""Modify a RunOptions object for debug tensor watching.
|
||||
|
||||
Specifies request for outputting partition graphs. Adds
|
||||
@ -534,6 +551,15 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
|
||||
tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures)
|
||||
|
||||
def _decorate_run_options_for_profile(self, run_options):
  """Modify a RunOptions object for profiling TensorFlow graph execution.

  Args:
    run_options: (RunOptions) the RunOptions object to modify in place. Its
      `trace_level` is set to `RunOptions.FULL_TRACE`.
  """

  run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_session_init(self, request):
|
||||
"""Callback invoked during construction of the debug-wrapper session.
|
||||
|
@ -130,6 +130,8 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
|
||||
on_run_start_response.tensor_dtype_regex_whitelist),
|
||||
tolerate_debug_op_creation_failures=(
|
||||
on_run_start_response.tolerate_debug_op_creation_failures)))
|
||||
elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
|
||||
self._decorate_run_options_for_profile(run_args.options)
|
||||
elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
|
||||
# The _finalized property must be set to False so that the NodeStepper
|
||||
# can insert ops for retrieving TensorHandles.
|
||||
|
@ -27,6 +27,7 @@ import tempfile
|
||||
from tensorflow.python.debug.cli import analyzer_cli
|
||||
from tensorflow.python.debug.cli import cli_shared
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.debug.cli import profile_analyzer_cli
|
||||
from tensorflow.python.debug.cli import stepper_cli
|
||||
from tensorflow.python.debug.cli import ui_factory
|
||||
from tensorflow.python.debug.lib import debug_data
|
||||
@ -162,6 +163,12 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
default="",
|
||||
help="Regular-expression filter for tensor dtype to be watched in the "
|
||||
"run, e.g., (float32|float64), int.*")
|
||||
ap.add_argument(
|
||||
"-p",
|
||||
"--profile",
|
||||
dest="profile",
|
||||
action="store_true",
|
||||
help="Run and profile TensorFlow graph execution.")
|
||||
self._argparsers["run"] = ap
|
||||
|
||||
ap = argparse.ArgumentParser(
|
||||
@ -318,12 +325,16 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
passed_filter = self._active_tensor_filter
|
||||
self._active_tensor_filter = None
|
||||
|
||||
self._prep_cli_for_run_end(debug_dump, request.tf_error, passed_filter)
|
||||
self._prep_debug_cli_for_run_end(
|
||||
debug_dump, request.tf_error, passed_filter)
|
||||
|
||||
self._run_start_response = self._launch_cli()
|
||||
|
||||
# Clean up the dump generated by this run.
|
||||
self._remove_dump_root()
|
||||
elif request.performed_action == framework.OnRunStartAction.PROFILE_RUN:
|
||||
self._prep_profile_cli_for_run_end(self._sess.graph, request.run_metadata)
|
||||
self._run_start_response = self._launch_cli()
|
||||
else:
|
||||
# No debug information to show following a non-debug run() call.
|
||||
self._run_start_response = None
|
||||
@ -336,7 +347,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
if os.path.isdir(self._dump_root):
|
||||
shutil.rmtree(self._dump_root)
|
||||
|
||||
def _prep_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
|
||||
def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
|
||||
"""Prepare (but not launch) CLI for run-end, with debug dump from the run.
|
||||
|
||||
Args:
|
||||
@ -391,6 +402,12 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
if help_intro:
|
||||
self._run_cli.set_help_intro(help_intro)
|
||||
|
||||
def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
  """Prepare (but not launch) the profiler CLI for run-end.

  Sets "lp" as the initial command, builds the profiler UI and sets the
  run-end title for profiler mode.

  Args:
    py_graph: graph object handed to the profiler UI factory.
    run_metadata: run metadata handed to the profiler UI factory.
  """
  self._init_command = "lp"
  self._run_cli = profile_analyzer_cli.create_profiler_ui(
      py_graph, run_metadata, ui_type=self._ui_type)
  self._title = "run-end (profiler mode): " + self._run_description
|
||||
|
||||
def _launch_cli(self):
|
||||
"""Launch the interactive command-line interface.
|
||||
|
||||
@ -425,13 +442,18 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
def _run_handler(self, args, screen_info=None):
|
||||
"""Command handler for "run" command during on-run-start."""
|
||||
|
||||
_ = screen_info # Currently unused.
|
||||
del screen_info # Currently unused.
|
||||
|
||||
parsed = self._argparsers["run"].parse_args(args)
|
||||
parsed.node_name_filter = parsed.node_name_filter or None
|
||||
parsed.op_type_filter = parsed.op_type_filter or None
|
||||
parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None
|
||||
|
||||
if parsed.profile:
|
||||
raise debugger_cli_common.CommandLineExit(
|
||||
exit_token=framework.OnRunStartResponse(
|
||||
framework.OnRunStartAction.PROFILE_RUN, []))
|
||||
|
||||
if parsed.till_filter_pass:
|
||||
# For the run-till-bad-numerical-value-appears mode, use the DEBUG_RUN
|
||||
# option to access the intermediate tensors, and set the corresponding
|
||||
|
@ -71,15 +71,21 @@ class LocalCLIDebuggerWrapperSessionForTest(
|
||||
"tf_errors": [],
|
||||
"run_start_cli_run_numbers": [],
|
||||
"run_end_cli_run_numbers": [],
|
||||
"profiler_py_graphs": [],
|
||||
"profiler_run_metadata": [],
|
||||
}
|
||||
|
||||
def _prep_cli_for_run_start(self):
|
||||
pass
|
||||
|
||||
def _prep_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
|
||||
def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
|
||||
self.observers["debug_dumps"].append(debug_dump)
|
||||
self.observers["tf_errors"].append(tf_error)
|
||||
|
||||
def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
|
||||
self.observers["profiler_py_graphs"].append(py_graph)
|
||||
self.observers["profiler_run_metadata"].append(run_metadata)
|
||||
|
||||
def _launch_cli(self):
|
||||
if self._is_run_start:
|
||||
self.observers["run_start_cli_run_numbers"].append(self._run_call_count)
|
||||
@ -468,6 +474,19 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(1, dumps.size)
|
||||
self.assertEqual("w_int_inner", dumps.dumped_tensor_data[0].node_name)
|
||||
|
||||
def testRunUnderProfilerModeWorks(self):
  """Checks that a run issued with "-p" records one profiler run-end."""
  wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
      [["-p"], []], self.sess)

  wrapped_sess.run(self.w_int)

  # Exactly one profiler run-end was prepared, with non-empty step stats.
  self.assertEqual(1, len(wrapped_sess.observers["profiler_run_metadata"]))
  self.assertTrue(
      wrapped_sess.observers["profiler_run_metadata"][0].step_stats)
  self.assertEqual(1, len(wrapped_sess.observers["profiler_py_graphs"]))
  self.assertIsInstance(
      wrapped_sess.observers["profiler_py_graphs"][0], ops.Graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user