tfprof multi-step profiling.

This allows users to fill in RunMetadata across different steps.
1. It is useful for RL model which runs a subset of graph each step.
2. It also gets averages of multi-step stats.

PiperOrigin-RevId: 157552388
This commit is contained in:
A. Unique TensorFlower 2017-05-30 22:21:48 -07:00 committed by TensorFlower Gardener
parent 7aac2395ce
commit a7fff05e05
35 changed files with 1046 additions and 400 deletions

View File

@ -33,6 +33,22 @@ py_test(
], ],
) )
py_test(
name = "profiler_test",
srcs = ["profiler_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":model_analyzer",
":model_analyzer_testlib",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:variables",
],
)
py_library( py_library(
name = "model_analyzer_testlib", name = "model_analyzer_testlib",
srcs = ["model_analyzer_testlib.py"], srcs = ["model_analyzer_testlib.py"],

View File

@ -112,6 +112,181 @@ PRINT_ALL_TIMING_MEMORY = {
# pylint: enable=bad-continuation # pylint: enable=bad-continuation
def _build_options(tfprof_options):
"""Build tfprof.OptionsProto.
Args:
tfprof_options: A dictionary of options.
Returns:
tfprof.OptionsProto.
"""
opts = tfprof_options_pb2.OptionsProto()
opts.max_depth = tfprof_options.get('max_depth', 10)
opts.min_bytes = tfprof_options.get('min_bytes', 0)
opts.min_micros = tfprof_options.get('min_micros', 0)
opts.min_params = tfprof_options.get('min_params', 0)
opts.min_float_ops = tfprof_options.get('min_float_ops', 0)
opts.min_occurrence = tfprof_options.get('min_occurrence', 0)
opts.step = tfprof_options.get('step', -1)
opts.order_by = tfprof_options.get('order_by', 'name')
for p in tfprof_options.get('account_type_regexes', []):
opts.account_type_regexes.append(p)
for p in tfprof_options.get('start_name_regexes', []):
opts.start_name_regexes.append(p)
for p in tfprof_options.get('trim_name_regexes', []):
opts.trim_name_regexes.append(p)
for p in tfprof_options.get('show_name_regexes', []):
opts.show_name_regexes.append(p)
for p in tfprof_options.get('hide_name_regexes', []):
opts.hide_name_regexes.append(p)
opts.account_displayed_op_only = tfprof_options.get(
'account_displayed_op_only', False)
for p in tfprof_options.get('select', []):
opts.select.append(p)
opts.output = tfprof_options.get('output', 'stdout')
opts.dump_to_file = tfprof_options.get('dump_to_file', '')
return opts
class Profiler(object):
"""TensorFlow multi-step profiler.
See go/tfprof or README for details.
Typical use case:
# Currently we are only allowed to create 1 profiler per process.
profiler = Profile(sess.graph)
for i in xrange(total_steps):
if i % 10000 == 0:
run_meta = tf.RunMetadata()
_ = sess.run(...,
options=tf.RunOptions(
trace_level=tf.RunOptions.FULL_TRACE),
run_metadata=run_meta)
profiler.add_step(i, run_meta)
# Profile the parameters of your model.
profiler.profile_name_scope(options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
# Or profile the timing of your model operations.
opts = PRINT_ALL_TIMING_MEMORY.copy()
opts['order_by'] = 'micros'
opts['select'] = ['micros', 'occurrence']
opts['max_depth'] = 20
profiler.profile_operations(options=opts)
# Or you can generate a timeline:
opts = PRINT_ALL_TIMING_MEMORY.copy()
opts['output'] = 'timeline:outfile=' + filename
opts['step'] = i
profiler.profile_graph(options=opts)
else:
_ = sess.run(...)
"""
def __init__(self, graph, op_log=None):
"""Constructor.
Args:
graph: tf.Graph.
op_log: optional. tensorflow::tfprof::OpLog proto. Used to define
extra op types.
"""
self._graph = graph
# pylint: disable=protected-access
op_log = tfprof_logger._merge_default_with_oplog(
self._graph, op_log=op_log)
# pylint: enable=protected-access
print_mdl.NewProfiler(
self._graph.as_graph_def().SerializeToString(),
op_log.SerializeToString())
def __del__(self):
print_mdl.DeleteProfiler()
def add_step(self, step, run_meta):
"""Add statistics of a step.
Args:
step: A step uint64 used to identify the RunMetadata. Must be different
across different AddStep() calls.
run_meta: RunMetadata proto that contains statistics of a session run.
"""
# pylint: disable=protected-access
op_log = tfprof_logger._merge_default_with_oplog(
self._graph, run_meta=run_meta, add_trace=False,
add_trainable_var=False)
# pylint: enable=protected-access
print_mdl.AddStep(
step, run_meta.SerializeToString(), op_log.SerializeToString())
def profile_python_codes(self, options):
"""Profile the statistics of the Python codes.
Hint: set options['show_name_regexes'] = ['.*my_code.py.*']
Args:
options: A dict of profiler options.
Returns:
a TFMultiGraphNodeProto that records the results.
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
tfprof_node.ParseFromString(
print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString()))
return tfprof_node
def profile_operations(self, options):
"""Profile the statistics of the Operation types (e.g. MatMul, Conv2D).
Args:
options: A dict of profiler options.
Returns:
a TFMultiGraphNodeProto that records the results.
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
tfprof_node.ParseFromString(
print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString()))
return tfprof_node
def profile_name_scope(self, options):
"""Profile the statistics of graph nodes, organized by name scope.
Args:
options: A dict of profiler options.
Returns:
a TFGraphNodeProto that records the results.
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
tfprof_node.ParseFromString(
print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString()))
return tfprof_node
def profile_graph(self, options):
"""Profile the statistics of graph nodes, organized by dataflow graph.
Args:
options: A dict of profiler options.
Returns:
a TFGraphNodeProto that records the results.
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
tfprof_node.ParseFromString(
print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
return tfprof_node
def print_model_analysis(graph, def print_model_analysis(graph,
run_meta=None, run_meta=None,
op_log=None, op_log=None,
@ -145,33 +320,8 @@ def print_model_analysis(graph,
op_log = tfprof_logger._merge_default_with_oplog( op_log = tfprof_logger._merge_default_with_oplog(
graph, op_log, run_meta, add_trace=tfprof_cmd == 'code') graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
# pylint: enable=protected-access # pylint: enable=protected-access
opts = tfprof_options_pb2.OptionsProto()
opts.max_depth = tfprof_options['max_depth']
opts.min_bytes = tfprof_options['min_bytes']
opts.min_micros = tfprof_options['min_micros']
opts.min_params = tfprof_options['min_params']
opts.min_float_ops = tfprof_options['min_float_ops']
if 'min_occurrence' in tfprof_options:
opts.min_occurrence = tfprof_options['min_occurrence']
else:
opts.min_occurrence = 0
opts.order_by = tfprof_options['order_by'] opts = _build_options(tfprof_options)
for p in tfprof_options['account_type_regexes']:
opts.account_type_regexes.append(p)
for p in tfprof_options['start_name_regexes']:
opts.start_name_regexes.append(p)
for p in tfprof_options['trim_name_regexes']:
opts.trim_name_regexes.append(p)
for p in tfprof_options['show_name_regexes']:
opts.show_name_regexes.append(p)
for p in tfprof_options['hide_name_regexes']:
opts.hide_name_regexes.append(p)
opts.account_displayed_op_only = tfprof_options['account_displayed_op_only']
for p in tfprof_options['select']:
opts.select.append(p)
opts.output = tfprof_options['output']
opts.dump_to_file = tfprof_options['dump_to_file']
run_meta_str = run_meta.SerializeToString() if run_meta else b'' run_meta_str = run_meta.SerializeToString() if run_meta else b''

View File

@ -199,6 +199,7 @@ class PrintModelAnalysisTest(test.TestCase):
opts['output'] = 'timeline:outfile=' + outfile opts['output'] = 'timeline:outfile=' + outfile
opts['account_type_regexes'] = ['.*'] opts['account_type_regexes'] = ['.*']
opts['max_depth'] = 100000 opts['max_depth'] = 100000
opts['step'] = 0
with session.Session() as sess, ops.device('/cpu:0'): with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildFullModel() x = lib.BuildFullModel()

View File

@ -65,3 +65,33 @@ def BuildFullModel():
loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out)) loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
sgd_op = gradient_descent.GradientDescentOptimizer(1e-2) sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
return sgd_op.minimize(loss) return sgd_op.minimize(loss)
def BuildSplitableModel():
"""Build a small model that can be run partially in each step."""
image = array_ops.zeros([2, 6, 6, 3])
kernel1 = variable_scope.get_variable(
'DW', [3, 3, 3, 6],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
r1 = nn_ops.conv2d(image, kernel1, [1, 2, 2, 1], padding='SAME')
kernel2 = variable_scope.get_variable(
'DW2', [2, 3, 3, 6],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
r2 = nn_ops.conv2d(image, kernel2, [1, 2, 2, 1], padding='SAME')
r3 = r1 + r2
return r1, r2, r3
def SearchTFProfNode(node, name):
"""Search a node in the tree."""
if node.name == name:
return node
for c in node.children:
r = SearchTFProfNode(c, name)
if r: return r
return None

View File

@ -0,0 +1,184 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
# pylint: disable=g-bad-import-order
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer_testlib as lib
class ProfilerTest(test.TestCase):
def testProfileBasic(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
opts['account_type_regexes'] = ['.*']
opts['select'] = ['params', 'float_ops', 'micros', 'bytes',
'device', 'op_types', 'occurrence']
outfile = os.path.join(test.get_temp_dir(), 'dump')
opts['output'] = 'file:outfile=' + outfile
# Test the output without run_meta.
sess = session.Session()
r = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
profiler = model_analyzer.Profiler(sess.graph)
profiler.profile_name_scope(opts)
with gfile.Open(outfile, 'r') as f:
profiler_str = f.read()
model_analyzer.print_model_analysis(
sess.graph, tfprof_cmd='scope', tfprof_options=opts)
with gfile.Open(outfile, 'r') as f:
pma_str = f.read()
self.assertEqual(pma_str, profiler_str)
# Test the output with run_meta.
run_meta = config_pb2.RunMetadata()
_ = sess.run(r,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta)
profiler.add_step(1, run_meta)
profiler.profile_graph(opts)
with gfile.Open(outfile, 'r') as f:
profiler_str = f.read()
model_analyzer.print_model_analysis(
sess.graph, tfprof_cmd='graph', run_meta=run_meta, tfprof_options=opts)
with gfile.Open(outfile, 'r') as f:
pma_str = f.read()
self.assertEqual(pma_str, profiler_str)
profiler.profile_python_codes(opts)
with gfile.Open(outfile, 'r') as f:
profiler_str = f.read()
model_analyzer.print_model_analysis(
sess.graph, tfprof_cmd='code', run_meta=run_meta, tfprof_options=opts)
with gfile.Open(outfile, 'r') as f:
pma_str = f.read()
self.assertEqual(pma_str, profiler_str)
profiler.profile_operations(opts)
with gfile.Open(outfile, 'r') as f:
profiler_str = f.read()
model_analyzer.print_model_analysis(
sess.graph, tfprof_cmd='op', run_meta=run_meta, tfprof_options=opts)
with gfile.Open(outfile, 'r') as f:
pma_str = f.read()
self.assertEqual(pma_str, profiler_str)
# Test the output difference between multi-step profile and 1-step profile.
_ = sess.run(r,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta)
profiler.add_step(2, run_meta)
profiler.profile_name_scope(opts)
with gfile.Open(outfile, 'r') as f:
profiler_str = f.read()
model_analyzer.print_model_analysis(
sess.graph, tfprof_cmd='scope', run_meta=run_meta, tfprof_options=opts)
with gfile.Open(outfile, 'r') as f:
pma_str = f.read()
self.assertNotEqual(pma_str, profiler_str)
opts2 = opts.copy()
opts2['select'] = ['params', 'float_ops']
profiler.profile_name_scope(opts2)
with gfile.Open(outfile, 'r') as f:
profiler_str = f.read()
model_analyzer.print_model_analysis(
sess.graph, tfprof_cmd='scope', run_meta=run_meta, tfprof_options=opts2)
with gfile.Open(outfile, 'r') as f:
pma_str = f.read()
self.assertEqual(pma_str, profiler_str)
def testMultiStepProfile(self):
ops.reset_default_graph()
opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
opts['account_type_regexes'] = ['.*']
with session.Session() as sess, ops.device('/cpu:0'):
r1, r2, r3 = lib.BuildSplitableModel()
sess.run(variables.global_variables_initializer())
profiler = model_analyzer.Profiler(sess.graph)
pb0 = profiler.profile_name_scope(opts)
run_meta = config_pb2.RunMetadata()
_ = sess.run(r1,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta)
profiler.add_step(1, run_meta)
pb1 = profiler.profile_name_scope(opts)
self.assertNotEqual(lib.SearchTFProfNode(pb1, 'DW'), None)
self.assertEqual(lib.SearchTFProfNode(pb1, 'DW2'), None)
self.assertEqual(lib.SearchTFProfNode(pb1, 'add'), None)
run_meta2 = config_pb2.RunMetadata()
_ = sess.run(r2,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta2)
profiler.add_step(2, run_meta2)
pb2 = profiler.profile_name_scope(opts)
self.assertNotEqual(lib.SearchTFProfNode(pb2, 'DW'), None)
self.assertNotEqual(lib.SearchTFProfNode(pb2, 'DW2'), None)
self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
run_meta3 = config_pb2.RunMetadata()
_ = sess.run(r3,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta3)
profiler.add_step(3, run_meta3)
pb3 = profiler.profile_name_scope(opts)
self.assertNotEqual(lib.SearchTFProfNode(pb3, 'DW'), None)
self.assertNotEqual(lib.SearchTFProfNode(pb3, 'DW2'), None)
self.assertNotEqual(lib.SearchTFProfNode(pb3, 'add'), None)
self.assertEqual(lib.SearchTFProfNode(pb0, 'Conv2D'), None)
self.assertGreater(lib.SearchTFProfNode(pb1, 'Conv2D').exec_micros, 0)
self.assertEqual(lib.SearchTFProfNode(pb1, 'Conv2D_1'), None)
self.assertGreater(lib.SearchTFProfNode(pb2, 'Conv2D_1').exec_micros, 0)
self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
self.assertGreater(lib.SearchTFProfNode(pb3, 'add').exec_micros, 0)
if __name__ == '__main__':
test.main()

View File

@ -19,6 +19,8 @@ limitations under the License.
%{ %{
#include "tensorflow/tools/tfprof/internal/print_model_analysis.h" #include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
using tensorflow::int64;
%} %}
%typemap(typecheck) const string & = char *; %typemap(typecheck) const string & = char *;
@ -37,6 +39,10 @@ limitations under the License.
%unignore tensorflow; %unignore tensorflow;
%unignore tensorflow::tfprof; %unignore tensorflow::tfprof;
%unignore tensorflow::tfprof::PrintModelAnalysis; %unignore tensorflow::tfprof::PrintModelAnalysis;
%unignore tensorflow::tfprof::NewProfiler;
%unignore tensorflow::tfprof::DeleteProfiler;
%unignore tensorflow::tfprof::AddStep;
%unignore tensorflow::tfprof::Profile;
%include "tensorflow/tools/tfprof/internal/print_model_analysis.h" %include "tensorflow/tools/tfprof/internal/print_model_analysis.h"

View File

@ -62,13 +62,16 @@ def _fill_missing_graph_shape(graph, run_meta):
return graph return graph
def _get_logged_ops(graph, run_meta=None, add_trace=True): def _get_logged_ops(graph, run_meta=None, add_trace=True,
add_trainable_var=True):
"""Extract trainable model parameters and FLOPs for ops from a Graph. """Extract trainable model parameters and FLOPs for ops from a Graph.
Args: Args:
graph: tf.Graph. graph: tf.Graph.
run_meta: RunMetadata proto used to complete shape information. run_meta: RunMetadata proto used to complete shape information.
add_trace: Whether to add op trace information. add_trace: Whether to add op trace information.
add_trainable_var: Whether to assign tf.trainable_variables() op type
'_trainable_variables'.
Returns: Returns:
logged_ops: dict mapping from op_name to OpLogEntry. logged_ops: dict mapping from op_name to OpLogEntry.
""" """
@ -77,6 +80,7 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True):
op_missing_shape = 0 op_missing_shape = 0
logged_ops = {} logged_ops = {}
# TODO(xpan): Work with Profiler more efficiently.
for op in graph.get_operations(): for op in graph.get_operations():
try: try:
stats = ops.get_stats_for_node_def( stats = ops.get_stats_for_node_def(
@ -105,23 +109,24 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True):
if add_entry: if add_entry:
logged_ops[entry.name] = entry logged_ops[entry.name] = entry
for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES): if add_trainable_var:
if v.op.name not in logged_ops: for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
entry = tfprof_log_pb2.OpLogEntry() if v.op.name not in logged_ops:
entry.name = v.op.name entry = tfprof_log_pb2.OpLogEntry()
entry.types.append(TRAINABLE_VARIABLES) entry.name = v.op.name
logged_ops[entry.name] = entry entry.types.append(TRAINABLE_VARIABLES)
else: logged_ops[entry.name] = entry
logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES) else:
logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
if op_missing_shape > 0 and not run_meta: if op_missing_shape > 0 and not run_meta:
sys.stderr.write('%d ops no flops stats due to incomplete shapes. ' sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
'Consider passing run_meta to use run_time shapes.\n' %
op_missing_shape) op_missing_shape)
return logged_ops return logged_ops
def _merge_default_with_oplog(graph, op_log=None, run_meta=None, def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
add_trace=True): add_trace=True, add_trainable_var=True):
"""Merge the tfprof default extra info with caller's op_log. """Merge the tfprof default extra info with caller's op_log.
Args: Args:
@ -129,11 +134,15 @@ def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
op_log: OpLog proto. op_log: OpLog proto.
run_meta: RunMetadata proto used to complete shape information. run_meta: RunMetadata proto used to complete shape information.
add_trace: Whether to add op trace information. add_trace: Whether to add op trace information.
add_trainable_var: Whether to assign tf.trainable_variables() op type
'_trainable_variables'.
Returns: Returns:
tmp_op_log: Merged OpLog proto. tmp_op_log: Merged OpLog proto.
""" """
tmp_op_log = tfprof_log_pb2.OpLog() tmp_op_log = tfprof_log_pb2.OpLog()
logged_ops = _get_logged_ops(graph, run_meta, add_trace=add_trace) logged_ops = _get_logged_ops(
graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)
if not op_log: if not op_log:
tmp_op_log.log_entries.extend(logged_ops.values()) tmp_op_log.log_entries.extend(logged_ops.values())
else: else:

View File

@ -259,7 +259,8 @@ tfprof>
-min_micros 0 -min_micros 0
-min_params 0 -min_params 0
-min_float_ops 0 -min_float_ops 0
-min_occurrence 0 -min_occurrence 0
-step -1
-order_by name -order_by name
-account_type_regexes Variable,VariableV2 -account_type_regexes Variable,VariableV2
-start_name_regexes .* -start_name_regexes .*
@ -598,6 +599,8 @@ provides checkpointed tensors' values.
`-min_occurrence`: Show ops that appear at least this number of times. Only available in "op" view. `-min_occurrence`: Show ops that appear at least this number of times. Only available in "op" view.
`-step`: Show the stats of the this step when multiple steps of RunMetadata were added. By default, show the average of all steps."
`-order_by`: Order the results by [name|depth|bytes|micros|params|float_ops|occurrence] `-order_by`: Order the results by [name|depth|bytes|micros|params|float_ops|occurrence]
`-account_type_regexes`: Account and display the ops whose types match one of the type regexes specified. tfprof allow user to define extra op types for ops through tensorflow.tfprof.OpLog proto. regexes are comma-sperated. `-account_type_regexes`: Account and display the ops whose types match one of the type regexes specified. tfprof allow user to define extra op types for ops through tensorflow.tfprof.OpLog proto. regexes are comma-sperated.

View File

@ -30,6 +30,89 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tfprof { namespace tfprof {
namespace {
TFStats* tf_stat = nullptr;
string RunProfile(const string& command, const string& options,
TFStats* tf_stats) {
Options opts;
tensorflow::Status s = Options::FromProtoStr(options, &opts);
if (!s.ok()) {
fprintf(stderr, "%s\n", s.ToString().c_str());
return "";
}
if (opts.output_type == kOutput[1]) {
printf("\n=========================Options=============================\n");
printf("%s", opts.ToString().c_str());
printf("\n==================Model Analysis Report======================\n");
string ret = "";
if (command == kCmds[2] || command == kCmds[3]) {
ret = tf_stats->ShowMultiGraphNode(command, opts).SerializeAsString();
} else if (command == kCmds[0] || command == kCmds[1]) {
ret = tf_stats->ShowGraphNode(command, opts).SerializeAsString();
} else {
fprintf(stderr, "Unknown command: %s\n", command.c_str());
}
printf("\n======================End of Report==========================\n");
fflush(stdout);
return ret;
}
if (command == kCmds[2] || command == kCmds[3]) {
return tf_stats->ShowMultiGraphNode(command, opts).SerializeAsString();
} else if (command == kCmds[0] || command == kCmds[1]) {
return tf_stats->ShowGraphNode(command, opts).SerializeAsString();
} else {
fprintf(stderr, "Unknown command: %s\n", command.c_str());
return "";
}
}
} // namespace
bool NewProfiler(const string* graph, const string* op_log) {
CHECK(!tf_stat) << "Currently only 1 living tfprof profiler is allowed";
CHECK(graph) << "graph mustn't be null";
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
graph_ptr->ParseFromString(*graph);
std::unique_ptr<OpLog> op_log_ptr;
if (op_log && !op_log->empty()) {
op_log_ptr.reset(new OpLog());
op_log_ptr->ParseFromString(*op_log);
}
tf_stat = new TFStats(std::move(graph_ptr), nullptr, std::move(op_log_ptr),
nullptr);
return true;
}
void DeleteProfiler() {
delete tf_stat;
tf_stat = nullptr;
}
void AddStep(int64 step, const string* run_meta, const string* op_log) {
CHECK(tf_stat);
CHECK(run_meta && !run_meta->empty());
// TODO(xpan): Better error handling.
std::unique_ptr<RunMetadata> run_meta_ptr(new RunMetadata());
run_meta_ptr->ParseFromString(*run_meta);
tf_stat->ParseRunMeta(step, std::move(run_meta_ptr));
std::unique_ptr<OpLog> op_log_ptr;
if (op_log && !op_log->empty()) {
op_log_ptr.reset(new OpLog());
op_log_ptr->ParseFromString(*op_log);
}
tf_stat->ParseOpLog(std::move(op_log_ptr));
}
string Profile(const string* command, const string* options) {
CHECK(tf_stat);
CHECK(command) << "command mustn't be null";
CHECK(options) << "options mustn't be null";
return RunProfile(*command, *options, tf_stat);
}
string PrintModelAnalysis(const string* graph, const string* run_meta, string PrintModelAnalysis(const string* graph, const string* run_meta,
const string* op_log, const string* command, const string* op_log, const string* command,
const string* options) { const string* options) {
@ -51,42 +134,13 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
op_log_ptr->ParseFromString(*op_log); op_log_ptr->ParseFromString(*op_log);
} }
// TODO(xpan): Maybe need to init the checkpoint reader?
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader; std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
TFStats tf_stats(std::move(graph_ptr), std::move(run_meta_ptr), TFStats tf_stats(std::move(graph_ptr), std::move(run_meta_ptr),
std::move(op_log_ptr), std::move(ckpt_reader)); std::move(op_log_ptr), std::move(ckpt_reader));
Options opts; return RunProfile(*command, *options, &tf_stats);
tensorflow::Status s = Options::FromProtoStr(*options, &opts);
if (!s.ok()) {
fprintf(stderr, "%s\n", s.ToString().c_str());
return "";
}
if (opts.output_type == kOutput[1]) {
printf("\n=========================Options=============================\n");
printf("%s", opts.ToString().c_str());
printf("\n==================Model Analysis Report======================\n");
string ret = "";
if (*command == kCmds[2] || *command == kCmds[3]) {
ret = tf_stats.ShowMultiGraphNode(*command, opts).SerializeAsString();
} else if (*command == kCmds[0] || *command == kCmds[1]) {
ret = tf_stats.ShowGraphNode(*command, opts).SerializeAsString();
} else {
fprintf(stderr, "Unknown command: %s\n", command->c_str());
}
printf("\n======================End of Report==========================\n");
fflush(stdout);
return ret;
}
if (*command == kCmds[2] || *command == kCmds[3]) {
return tf_stats.ShowMultiGraphNode(*command, opts).SerializeAsString();
} else if (*command == kCmds[0] || *command == kCmds[1]) {
return tf_stats.ShowGraphNode(*command, opts).SerializeAsString();
} else {
fprintf(stderr, "Unknown command: %s\n", command->c_str());
return "";
}
} }
} // namespace tfprof } // namespace tfprof
} // namespace tensorflow } // namespace tensorflow

View File

@ -23,8 +23,19 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tfprof { namespace tfprof {
struct Options; struct Options;
// ***This API is only for swig. Don't user it directory!***
// // **********************
// APIs in this file are only for swig.
// Talk to xpan@ if you want to call it directly!
// *********************
bool NewProfiler(const string* graph, const string* op_log);
void DeleteProfiler();
void AddStep(int64 step, const string* run_meta, const string* op_log);
string Profile(const string* command, const string* options);
// Interface defined for Python API swig. Calls the tfprof core API. // Interface defined for Python API swig. Calls the tfprof core API.
// 'graph', 'run_meta', 'op_log' are serialized GraphDef, RunMetadata, // 'graph', 'run_meta', 'op_log' are serialized GraphDef, RunMetadata,
// OpLog strings, respectively. // OpLog strings, respectively.

View File

@ -53,14 +53,16 @@ string GetTraceString(const CodeDef::Trace& trace) {
} // namespace } // namespace
void TFCode::AddNode(TFGraphNode* node) { void TFCode::AddNode(TFGraphNode* node) {
if (!node->code()) { if (node->code().traces_size() == 0) {
return; return;
} }
TFMultiGraphNode* pre_trace_node = nullptr; TFMultiGraphNode* pre_trace_node = nullptr;
for (int i = 0; i < node->code()->traces_size(); ++i) { // TODO(xpan): Consider to release CodeDef after TFCode is built. It
// takes a lot of memory.
for (int i = 0; i < node->code().traces_size(); ++i) {
// Unlike op name, which is globally unique, trace name is only unique // Unlike op name, which is globally unique, trace name is only unique
// w.r.t. it's parent. // w.r.t. it's parent.
const string& trace = GetTraceString(node->code()->traces(i)); const string& trace = GetTraceString(node->code().traces(i));
if (i == 0) { if (i == 0) {
if (!trace_root_) { if (!trace_root_) {
trace_root_.reset(new TFMultiGraphNode(trace)); trace_root_.reset(new TFMultiGraphNode(trace));
@ -72,7 +74,7 @@ void TFCode::AddNode(TFGraphNode* node) {
pre_trace_node->AddChildren(trace); pre_trace_node->AddChildren(trace);
TFMultiGraphNode* trace_node = pre_trace_node->children().at(trace).get(); TFMultiGraphNode* trace_node = pre_trace_node->children().at(trace).get();
if (i == node->code()->traces_size() - 1) { if (i == node->code().traces_size() - 1) {
trace_node->AddGraphNode(node); trace_node->AddGraphNode(node);
} }
pre_trace_node = trace_node; pre_trace_node = trace_node;

View File

@ -70,12 +70,19 @@ void TFGraph::Build() {
} }
const ShowNode* TFGraph::ShowInternal(const Options& opts, Timeline* timeline) { const ShowNode* TFGraph::ShowInternal(const Options& opts, Timeline* timeline) {
root_->ResetTotalStats();
root_->show_children.clear();
if (timeline && timeline->step() < 0) {
// TODO(xpan): Maybe pick a default step for users.
fprintf(stderr,
"Must specify -step option to generate timeline in graph view.\n");
return root_;
}
// 1. Account and aggregate the stats based on the graph structure. // 1. Account and aggregate the stats based on the graph structure.
// Returns a graph consists of accounted nodes. // Returns a graph consists of accounted nodes.
std::set<string> visits; std::set<string> visits;
std::vector<GraphNode*> roots = Account(root_->children, opts, &visits); std::vector<GraphNode*> roots =
root_->ResetTotalStats(); Account(root_->children, opts, timeline, &visits);
root_->show_children.clear();
for (GraphNode* n : roots) { for (GraphNode* n : roots) {
root_->AggregateTotalStats(n); root_->AggregateTotalStats(n);
} }
@ -98,7 +105,7 @@ const ShowNode* TFGraph::ShowInternal(const Options& opts, Timeline* timeline) {
Format(root->show_children, &root->formatted_str, root->mutable_proto()); Format(root->show_children, &root->formatted_str, root->mutable_proto());
if (timeline) { if (timeline) {
timeline->GenerateGraphTimeline(root, memory_tracker_); timeline->GenerateGraphTimeline(root);
} }
return root; return root;
} }
@ -201,26 +208,28 @@ std::vector<GraphNode*> TFGraph::PrintGraph(const std::vector<GraphNode*> roots,
std::vector<GraphNode*> TFGraph::Account(const std::vector<GraphNode*>& roots, std::vector<GraphNode*> TFGraph::Account(const std::vector<GraphNode*>& roots,
const Options& opts, const Options& opts,
Timeline* timeline,
std::set<string>* visits) { std::set<string>* visits) {
std::vector<GraphNode*> act_nodes; std::vector<GraphNode*> act_nodes;
for (GraphNode* node : roots) { for (GraphNode* node : roots) {
if (visits->find(node->name()) != visits->end()) continue; if (visits->find(node->name()) != visits->end()) continue;
visits->insert(node->name()); visits->insert(node->name());
// Depth-first. // Depth-first.
std::vector<GraphNode*> act_cnodes = Account(node->children, opts, visits); std::vector<GraphNode*> act_cnodes =
Account(node->children, opts, timeline, visits);
node->account = ShouldAccount(node, opts); node->account = ReAccount(node, opts);
if (node->account) { if (node->account) {
node->show_children.clear(); node->show_children.clear();
node->ResetTotalStats(); node->ResetTotalStats();
node->AddSelfToTotalStats(); node->AddSelfToTotalStats();
if (node->trackable) { if (timeline) {
memory_tracker_.TrackNode(node); timeline->TrackNode(node);
} }
// Aggregate its accounted children stats. // Aggregate its accounted children stats.
for (GraphNode* c : act_cnodes) { for (GraphNode* c : act_cnodes) {
if (node->trackable && c->trackable) { if (timeline) {
memory_tracker_.TrackNodeConnection(node, c); timeline->TrackNodeConnection(node, c);
} }
node->AggregateTotalStats(c); node->AggregateTotalStats(c);
node->show_children.push_back(c); node->show_children.push_back(c);

View File

@ -70,7 +70,7 @@ class TFGraph : public TFShow {
int last_ident, std::set<string>* visits); int last_ident, std::set<string>* visits);
std::vector<GraphNode*> Account(const std::vector<GraphNode*>& roots, std::vector<GraphNode*> Account(const std::vector<GraphNode*>& roots,
const Options& opts, const Options& opts, Timeline* timeline,
std::set<string>* visits); std::set<string>* visits);
void Format(const std::vector<GraphNode*> roots, string* display_str, void Format(const std::vector<GraphNode*> roots, string* display_str,

View File

@ -27,86 +27,31 @@ namespace tfprof {
// For CPU, op_end_rel is the kernel time, while all_end_rel_micros includes // For CPU, op_end_rel is the kernel time, while all_end_rel_micros includes
// some post-processing. // some post-processing.
// Here, we only consider kernel time for simplicity. // Here, we only consider kernel time for simplicity.
void TFGraphNode::AddStepStat(const string& device, void TFGraphNode::AddStepStat(int64 step, const string& device,
const NodeExecStats* step_stat) { const NodeExecStats& step_stat) {
step_stat_ = step_stat;
CHECK(step_stat_);
string dev = str_util::Lowercase(device); string dev = str_util::Lowercase(device);
// TODO(xpan): Test it. // TODO(xpan): Test it.
if (RE2::FullMatch(dev, "/job:.*/replica:\\d+/task:\\d+/[a-z]+:\\d+")) { if (RE2::FullMatch(dev, "/job:.*/replica:\\d+/task:\\d+/[a-z]+:\\d+")) {
canonical_device_ = dev; if (!canonical_device_.empty()) {
// TODO(xpan): Support things other than gpu? if (canonical_device_ != dev) {
host_device_ = StringReplace(dev, "gpu:\\d+", "cpu:0"); fprintf(stderr, "Unexpected: graph node changed device: %s->%s.\n",
AddOpType(canonical_device_); canonical_device_.c_str(), dev.c_str());
} return;
}
devices_.insert(dev);
if (step_stat_->all_start_micros() > 0) {
if (all_start_micros_ > 0) {
all_start_micros_ =
std::min(all_start_micros_,
static_cast<int64>(step_stat_->all_start_micros()));
} else { } else {
all_start_micros_ = step_stat_->all_start_micros(); canonical_device_ = dev;
} // TODO(xpan): Support things other than gpu?
int64 op_end_rel_micros = step_stat_->op_end_rel_micros(); host_device_ = StringReplace(dev, "gpu:\\d+", "cpu:0");
// Round quick execution to 1 micro to be semantically robust. AddOpType(canonical_device_);
if (op_end_rel_micros == 0) {
++op_end_rel_micros;
}
latest_end_rel_micros_ =
std::max(latest_end_rel_micros_, op_end_rel_micros);
op_execs_[dev].push_back(
std::make_pair(step_stat_->all_start_micros(), op_end_rel_micros));
if (dev.find("stream") != dev.npos && dev.find("stream:all") == dev.npos) {
gpu_kernel_execs_[dev].push_back(
std::make_pair(step_stat_->all_start_micros(), op_end_rel_micros));
} }
} }
ExecStep& exec = execs_[step];
exec.AddTimeStats(dev, step_stat);
if (dev == canonical_device_) { if (dev == canonical_device_) {
for (const auto& mem : step_stat_->memory()) { exec.AddMemoryStats(dev, step_stat);
// TODO(xpan): Fix this hack. Currently the allocator name seems quite
// ad-hoc.
if (mem.allocator_name().find("GPU") == mem.allocator_name().npos) {
continue;
}
if (dev == canonical_device_) {
allocator_bytes_in_use_ =
std::max(allocator_bytes_in_use_,
static_cast<int64>(mem.allocator_bytes_in_use()));
}
}
int64 total_output_bytes = 0;
for (const auto& output : step_stat_->output()) {
if (output.has_tensor_description() &&
output.tensor_description().has_allocation_description()) {
// TODO(xpan): Maybe allocated_bytes.
int64 output_bytes = std::max(output.tensor_description()
.allocation_description()
.allocated_bytes(),
output.tensor_description()
.allocation_description()
.requested_bytes());
uint64 output_ptr =
output.tensor_description().allocation_description().ptr();
total_output_bytes += output_bytes;
output_bytes_[output.slot()] = std::make_pair(output_bytes, output_ptr);
}
}
if (step_stat_->has_memory_stats()) {
host_temp_bytes_ += step_stat_->memory_stats().host_temp_memory_size();
host_persistent_bytes_ +=
step_stat_->memory_stats().host_persistent_memory_size();
accelerator_temp_bytes_ +=
step_stat_->memory_stats().device_temp_memory_size();
accelerator_persistent_bytes_ +=
step_stat_->memory_stats().device_persistent_memory_size();
}
requested_bytes_ = total_output_bytes;
} }
} }
} // namespace tfprof } // namespace tfprof

View File

@ -37,22 +37,162 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tfprof { namespace tfprof {
class TFGraphNode { class ExecStep {
public: public:
TFGraphNode(const NodeDef* node) ExecStep()
: node_(node), : all_start_micros_(0),
code_(nullptr),
step_stat_(nullptr),
all_start_micros_(0),
latest_end_rel_micros_(0), latest_end_rel_micros_(0),
mem_initiated_(false),
requested_bytes_(0), requested_bytes_(0),
host_temp_bytes_(0), host_temp_bytes_(0),
host_persistent_bytes_(0), host_persistent_bytes_(0),
accelerator_temp_bytes_(0), accelerator_temp_bytes_(0),
accelerator_persistent_bytes_(0), accelerator_persistent_bytes_(0),
allocator_bytes_in_use_(0), allocator_bytes_in_use_(0) {}
float_ops_(0),
op_(node->op()) { void AddTimeStats(const string& dev, const NodeExecStats& step_stat) {
devices_.insert(dev);
if (step_stat.all_start_micros() > 0) {
if (all_start_micros_ > 0) {
all_start_micros_ =
std::min(all_start_micros_,
static_cast<int64>(step_stat.all_start_micros()));
} else {
all_start_micros_ = step_stat.all_start_micros();
}
int64 op_end_rel_micros = step_stat.op_end_rel_micros();
// Round quick execution to 1 micro to be semantically robust.
if (op_end_rel_micros == 0) {
++op_end_rel_micros;
}
latest_end_rel_micros_ =
std::max(latest_end_rel_micros_, op_end_rel_micros);
op_execs_[dev].push_back(
std::make_pair(step_stat.all_start_micros(), op_end_rel_micros));
if (dev.find("stream") != dev.npos &&
dev.find("stream:all") == dev.npos) {
gpu_kernel_execs_[dev].push_back(
std::make_pair(step_stat.all_start_micros(), op_end_rel_micros));
}
}
}
void AddMemoryStats(const string& dev, const NodeExecStats& step_stat) {
if (mem_initiated_) {
// fprintf(stderr, "Memory initiated twice on %s", dev.c_str());
return;
}
mem_initiated_ = true;
for (const auto& mem : step_stat.memory()) {
// TODO(xpan): Fix this hack. Currently the allocator name seems quite
// ad-hoc.
if (mem.allocator_name().find("GPU") == mem.allocator_name().npos) {
continue;
}
allocator_bytes_in_use_ =
std::max(allocator_bytes_in_use_,
static_cast<int64>(mem.allocator_bytes_in_use()));
}
int64 total_output_bytes = 0;
for (const auto& output : step_stat.output()) {
if (output.has_tensor_description() &&
output.tensor_description().has_allocation_description()) {
// TODO(xpan): Maybe allocated_bytes.
int64 output_bytes = std::max(output.tensor_description()
.allocation_description()
.allocated_bytes(),
output.tensor_description()
.allocation_description()
.requested_bytes());
uint64 output_ptr =
output.tensor_description().allocation_description().ptr();
total_output_bytes += output_bytes;
output_bytes_[output.slot()] = std::make_pair(output_bytes, output_ptr);
}
}
if (step_stat.has_memory_stats()) {
host_temp_bytes_ += step_stat.memory_stats().host_temp_memory_size();
host_persistent_bytes_ +=
step_stat.memory_stats().host_persistent_memory_size();
accelerator_temp_bytes_ +=
step_stat.memory_stats().device_temp_memory_size();
accelerator_persistent_bytes_ +=
step_stat.memory_stats().device_persistent_memory_size();
}
requested_bytes_ = total_output_bytes;
}
int64 exec_micros() const {
int64 total = 0;
for (const auto& execs : gpu_kernel_execs_) {
for (const auto& exec : execs.second) {
total += exec.second;
}
}
if (total > 0) return total;
// If there is no gpu kernel time, fall back to assume it runs on cpu.
// TODO(xpan): No way to track CPU async op timing accurately?
for (const auto& execs : op_execs_) {
for (const auto& exec : execs.second) {
total += exec.second;
}
}
return total;
}
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs()
const {
return op_execs_;
}
int64 all_start_micros() const { return all_start_micros_; }
int64 latest_end_rel_micros() const { return latest_end_rel_micros_; }
int64 requested_bytes() const { return requested_bytes_; }
int64 accelerator_temp_bytes() const { return accelerator_temp_bytes_; }
int64 host_temp_bytes() const { return host_temp_bytes_; }
int64 accelerator_persistent_bytes() const {
return accelerator_persistent_bytes_;
}
int64 host_persistent_bytes() const { return host_persistent_bytes_; }
const std::map<int64, std::pair<int64, uint64>>& output_bytes() const {
return output_bytes_;
}
int64 allocator_bytes_in_use() const { return allocator_bytes_in_use_; }
private:
// The earliest/latest time including scheduling and kernel execution.
int64 all_start_micros_;
int64 latest_end_rel_micros_;
// device -> vector of {op_start_micros, op_kernel_exec_micros} pairs.
std::map<string, std::vector<std::pair<int64, int64>>> gpu_kernel_execs_;
std::map<string, std::vector<std::pair<int64, int64>>> op_execs_;
// All devices the op is associated with (e.g. gpu:0 (scheduling),
// gpu:0:stream:xx (kernel exec), cpu:0 host)
std::set<string> devices_;
bool mem_initiated_;
// Total output bytes requested by the op.
int64 requested_bytes_;
// Total temporary bytes allocated and released by the op.
int64 host_temp_bytes_;
// Total persistent bytes (e.g. variable) allocated by the op.
int64 host_persistent_bytes_;
int64 accelerator_temp_bytes_;
int64 accelerator_persistent_bytes_;
// The total number of bytes currently allocated by the allocator if >0.
int64 allocator_bytes_in_use_;
// output_idx -> {output_bytes, memory_ptr}
std::map<int64, std::pair<int64, uint64>> output_bytes_;
};
class TFGraphNode {
public:
TFGraphNode(const NodeDef* node)
: node_(node), float_ops_(0), op_(node->op()) {
for (const auto& attr : node->attr()) { for (const auto& attr : node->attr()) {
// TODO(xpan): Also consider _output_shapes. // TODO(xpan): Also consider _output_shapes.
if (attr.first != "shape" || !attr.second.has_shape()) continue; if (attr.first != "shape" || !attr.second.has_shape()) continue;
@ -82,67 +222,123 @@ class TFGraphNode {
void AddOpType(const string& op_type) { op_types_.insert(op_type); } void AddOpType(const string& op_type) { op_types_.insert(op_type); }
void AddStepStat(const string& device, const NodeExecStats* step_stat); void AddStepStat(int64 step, const string& device,
const NodeExecStats& step_stat);
void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; } void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; }
void AddCode(const CodeDef* code) { code_ = code; } // TODO(xpan): This could take a lot of memory.
void AddCode(const CodeDef& code) { code_.MergeFrom(code); }
const string& name() const { return node_->name(); } const string& name() const { return node_->name(); }
const string& op() const { return op_; } const string& op() const { return op_; }
const NodeDef* node_def() { return node_; } const NodeDef* node_def() { return node_; }
const NodeExecStats* step_stats() const { return step_stat_; } bool trackable(int64 step) const {
auto exec = execs_.find(step);
if (exec == execs_.end()) return false;
if (exec->second.all_start_micros() == 0) return false;
if (canonical_device_.empty() || host_device_.empty()) {
return false;
}
return true;
}
const std::map<string, TFGraphNode*>& inputs() const { return inputs_; } const std::map<string, TFGraphNode*>& inputs() const { return inputs_; }
const std::map<string, int64>& output_idx() const { return output_idx_; } const std::map<string, int64>& output_idx() const { return output_idx_; }
// This is time spent in kernel execution. // This is time spent in kernel execution.
int64 kernel_exec_micros() const { int64 kernel_exec_micros(int64 step) const {
if (!step_stat_) return 0; if (execs_.empty()) {
int64 total = 0; return 0;
for (const auto& execs : gpu_kernel_execs_) {
for (const auto& exec : execs.second) {
total += exec.second;
}
} }
if (total > 0) return total; if (step >= 0) {
auto exec = execs_.find(step);
// If there is no gpu kernel time, fall back to assume it runs on cpu. CHECK(exec != execs_.end());
for (const auto& execs : op_execs_) { return exec->second.exec_micros();
for (const auto& exec : execs.second) {
total += exec.second;
}
} }
return total;
int64 total_micros = 0;
for (const auto& exec : execs_) {
total_micros += exec.second.exec_micros();
}
return total_micros / execs_.size();
} }
int64 all_start_micros() const { return all_start_micros_; } int64 requested_bytes(int64 step) const {
int64 latest_end_rel_micros() const { return latest_end_rel_micros_; } if (execs_.empty()) {
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs() return 0;
const { }
return op_execs_; if (step >= 0) {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.requested_bytes();
}
int64 requested_bytes = 0;
for (const auto& exec : execs_) {
requested_bytes += exec.second.requested_bytes();
}
return requested_bytes / execs_.size();
} }
int64 requested_bytes() const { return requested_bytes_; } int64 all_start_micros(int64 step) const {
int64 accelerator_temp_bytes() const { return accelerator_temp_bytes_; } auto exec = execs_.find(step);
int64 host_temp_bytes() const { return host_temp_bytes_; } CHECK(exec != execs_.end()) << "unknown step " << step;
int64 accelerator_persistent_bytes() const { return exec->second.all_start_micros();
return accelerator_persistent_bytes_;
} }
int64 host_persistent_bytes() const { return host_persistent_bytes_; }
const std::map<int64, std::pair<int64, uint64>>& output_bytes() const { int64 latest_end_rel_micros(int64 step) const {
return output_bytes_; auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.latest_end_rel_micros();
}
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs(
int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.op_execs();
}
int64 accelerator_temp_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.accelerator_temp_bytes();
}
int64 host_temp_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.host_temp_bytes();
}
int64 accelerator_persistent_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.accelerator_persistent_bytes();
}
int64 host_persistent_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.host_persistent_bytes();
}
const std::map<int64, std::pair<int64, uint64>>& output_bytes(
int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.output_bytes();
}
int64 allocator_bytes_in_use(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
return exec->second.allocator_bytes_in_use();
} }
int64 allocator_bytes_in_use() const { return allocator_bytes_in_use_; }
int64 float_ops() const { return float_ops_; } int64 float_ops() const { return float_ops_; }
const CodeDef* code() { return code_; } const CodeDef& code() { return code_; }
string canonical_device() const { return canonical_device_; } string canonical_device() const { return canonical_device_; }
string host_device() const { return host_device_; } string host_device() const { return host_device_; }
std::set<string> devices() const { return devices_; }
const std::set<string>& op_types() const { return op_types_; } const std::set<string>& op_types() const { return op_types_; }
const std::vector<int64>& shape() const { return shape_; } const std::vector<int64>& shape() const { return shape_; }
private: private:
@ -152,39 +348,18 @@ class TFGraphNode {
std::map<string, int64> output_idx_; std::map<string, int64> output_idx_;
const NodeDef* node_; const NodeDef* node_;
const CodeDef* code_;
const NodeExecStats* step_stat_; CodeDef code_;
std::vector<int64> shape_; std::vector<int64> shape_;
std::set<string> op_types_; std::set<string> op_types_;
// The earliest/latest time including scheduling and kernel execution. std::map<int64, ExecStep> execs_;
int64 all_start_micros_;
int64 latest_end_rel_micros_;
// device -> vector of {op_start_micros, op_kernel_exec_micros} pairs.
std::map<string, std::vector<std::pair<int64, int64>>> gpu_kernel_execs_;
std::map<string, std::vector<std::pair<int64, int64>>> op_execs_;
// /j:#/t:#/r:#/device:#. A canonical device name without extra suffix. // /j:#/t:#/r:#/device:#. A canonical device name without extra suffix.
string canonical_device_; string canonical_device_;
// The host device name. // The host device name.
string host_device_; string host_device_;
// All devices the op is associated with (e.g. gpu:0 (scheduling),
// gpu:0:stream:xx (kernel exec), cpu:0 host)
std::set<string> devices_;
// Total output bytes requested by the op.
int64 requested_bytes_;
// Total temporary bytes allocated and released by the op.
int64 host_temp_bytes_;
// Total persistent bytes (e.g. variable) allocated by the op.
int64 host_persistent_bytes_;
int64 accelerator_temp_bytes_;
int64 accelerator_persistent_bytes_;
// The total number of bytes currently allocated by the allocator if >0.
int64 allocator_bytes_in_use_;
// output_idx -> {output_bytes, memory_ptr}
std::map<int64, std::pair<int64, uint64>> output_bytes_;
int64 float_ops_; int64 float_ops_;
@ -199,7 +374,7 @@ class TFMultiGraphNode {
requested_bytes_(0), requested_bytes_(0),
float_ops_(0) {} float_ops_(0) {}
bool SnapshotNodes(const std::vector<string>& type_regexes) { bool SnapshotNodes(int64 step, const std::vector<string>& type_regexes) {
kernel_exec_micros_ = 0; kernel_exec_micros_ = 0;
requested_bytes_ = 0; requested_bytes_ = 0;
float_ops_ = 0; float_ops_ = 0;
@ -208,30 +383,23 @@ class TFMultiGraphNode {
devices_.clear(); devices_.clear();
snapshot_nodes_.clear(); snapshot_nodes_.clear();
std::map<string, std::vector<const TFGraphNode*>> nodes = std::vector<const TFGraphNode*> nodes = pick_nodes(type_regexes);
pick_nodes(type_regexes);
if (nodes.empty()) { if (nodes.empty()) {
return (type_regexes.size() == 1 && type_regexes[0] == ".*"); return (type_regexes.size() == 1 && type_regexes[0] == ".*");
} }
std::set<string> visits; for (const TFGraphNode* node : nodes) {
for (const auto& entry : nodes) { op_types_.insert(node->op_types().begin(), node->op_types().end());
op_types_.insert(entry.first);
for (const TFGraphNode* node : entry.second) { kernel_exec_micros_ += node->kernel_exec_micros(step);
if (visits.find(node->name()) != visits.end()) continue; requested_bytes_ += node->requested_bytes(step);
visits.insert(node->name()); float_ops_ += node->float_ops();
if (node->shape().size() > 0) {
kernel_exec_micros_ += node->kernel_exec_micros(); shapes_.push_back(node->shape());
requested_bytes_ += node->requested_bytes();
float_ops_ += node->float_ops();
if (node->shape().size() > 0) {
shapes_.push_back(node->shape());
}
devices_.insert(node->canonical_device());
snapshot_nodes_[node->name()] = node;
} }
devices_.insert(node->canonical_device());
snapshot_nodes_[node->name()] = node;
} }
return true; return true;
} }
@ -241,9 +409,6 @@ class TFMultiGraphNode {
return; return;
} }
nodes_[node->name()] = node; nodes_[node->name()] = node;
for (const string& type : node->op_types()) {
nodes_by_type_[type].push_back(node);
}
} }
const std::map<string, const TFGraphNode*>& graph_nodes() const { const std::map<string, const TFGraphNode*>& graph_nodes() const {
@ -275,19 +440,26 @@ class TFMultiGraphNode {
const std::vector<std::vector<int64>>& shapes() const { return shapes_; } const std::vector<std::vector<int64>>& shapes() const { return shapes_; }
private: private:
std::map<string, std::vector<const TFGraphNode*>> pick_nodes( std::vector<const TFGraphNode*> pick_nodes(
const std::vector<string>& type_regexes) { const std::vector<string>& type_regexes) {
if (type_regexes.empty()) { if (type_regexes.empty()) {
return {}; return {};
} }
std::vector<const TFGraphNode*> ret;
if (type_regexes.size() == 1 && type_regexes[0] == ".*") { if (type_regexes.size() == 1 && type_regexes[0] == ".*") {
return nodes_by_type_; for (const auto& n : nodes_) {
ret.push_back(n.second);
}
return ret;
} }
std::map<string, std::vector<const TFGraphNode*>> ret;
for (const string& regex : type_regexes) { for (const string& regex : type_regexes) {
for (const auto& n : nodes_by_type_) { for (const auto& n : nodes_) {
if (RE2::FullMatch(n.first, regex)) { for (const string& type : n.second->op_types()) {
ret[n.first] = n.second; if (RE2::FullMatch(type, regex)) {
ret.push_back(n.second);
break;
}
} }
} }
} }
@ -295,7 +467,7 @@ class TFMultiGraphNode {
} }
const string name_; const string name_;
// Snapshot micros based on type_regexes // Snapshot based on type_regexes
std::set<string> op_types_; std::set<string> op_types_;
int64 kernel_exec_micros_; int64 kernel_exec_micros_;
int64 requested_bytes_; int64 requested_bytes_;
@ -306,7 +478,6 @@ class TFMultiGraphNode {
// Overall data held by the TFMultiGraphNode. // Overall data held by the TFMultiGraphNode.
std::map<string, const TFGraphNode*> nodes_; std::map<string, const TFGraphNode*> nodes_;
std::map<string, std::vector<const TFGraphNode*>> nodes_by_type_;
std::map<string, std::unique_ptr<TFMultiGraphNode>> children_; std::map<string, std::unique_ptr<TFMultiGraphNode>> children_;
}; };
} // namespace tfprof } // namespace tfprof

View File

@ -22,18 +22,20 @@ namespace tfprof {
namespace {} namespace {}
ShowNode::ShowNode(const TFGraphNode* node) : node(node), account(false) { ShowNode::ShowNode(const TFGraphNode* node) : node(node), account(false) {
ReInit(); ReInit(-1);
} }
void ShowNode::ReInit() { void ShowNode::ReInit(int64 step) {
mutable_proto()->set_name(name()); mutable_proto()->set_name(name());
mutable_proto()->clear_devices();
if (!node->canonical_device().empty()) { if (!node->canonical_device().empty()) {
mutable_proto()->add_devices(node->canonical_device()); mutable_proto()->add_devices(node->canonical_device());
} }
mutable_proto()->set_exec_micros(node->kernel_exec_micros()); mutable_proto()->set_exec_micros(node->kernel_exec_micros(step));
mutable_proto()->set_requested_bytes(node->requested_bytes()); mutable_proto()->set_requested_bytes(node->requested_bytes(step));
mutable_proto()->set_float_ops(node->float_ops()); mutable_proto()->set_float_ops(node->float_ops());
proto_.clear_parameters();
if (!node->shape().empty()) { if (!node->shape().empty()) {
int64 params = 1; int64 params = 1;
bool complete_shape = true; bool complete_shape = true;
@ -90,17 +92,19 @@ void ShowNode::ResetTotalStats() {
ShowMultiNode::ShowMultiNode(TFMultiGraphNode* node) ShowMultiNode::ShowMultiNode(TFMultiGraphNode* node)
: node(node), account(false), show(false) { : node(node), account(false), show(false) {
ReInit({".*"}); ReInit(-1, {".*"});
} }
bool ShowMultiNode::ReInit(const std::vector<string>& type_regexes) { bool ShowMultiNode::ReInit(int64 step,
bool has_matched_type = node->SnapshotNodes(type_regexes); const std::vector<string>& type_regexes) {
bool has_matched_type = node->SnapshotNodes(step, type_regexes);
std::vector<ShowNode> snodes; std::vector<ShowNode> snodes;
mutable_proto()->mutable_graph_nodes()->Clear(); mutable_proto()->mutable_graph_nodes()->Clear();
for (auto it : node->graph_nodes()) { for (auto it : node->graph_nodes()) {
ShowNode snode(it.second); ShowNode snode(it.second);
snodes.push_back(snode); snodes.push_back(snode);
snodes.back().ReInit(step);
snodes.back().AddSelfToTotalStats(); snodes.back().AddSelfToTotalStats();
mutable_proto()->add_graph_nodes()->MergeFrom(snodes.back().proto()); mutable_proto()->add_graph_nodes()->MergeFrom(snodes.back().proto());
} }
@ -110,7 +114,7 @@ bool ShowMultiNode::ReInit(const std::vector<string>& type_regexes) {
mutable_proto()->set_requested_bytes(node->requested_bytes()); mutable_proto()->set_requested_bytes(node->requested_bytes());
mutable_proto()->set_float_ops(node->float_ops()); mutable_proto()->set_float_ops(node->float_ops());
mutable_proto()->set_parameters(0); mutable_proto()->clear_parameters();
if (!node->shapes().empty()) { if (!node->shapes().empty()) {
for (const std::vector<int64>& shape : node->shapes()) { for (const std::vector<int64>& shape : node->shapes()) {
int64 params = 1; int64 params = 1;

View File

@ -48,7 +48,7 @@ class ShowNode {
TFGraphNodeProto* mutable_proto(); TFGraphNodeProto* mutable_proto();
const TFGraphNodeProto& proto() const; const TFGraphNodeProto& proto() const;
void ReInit(); void ReInit(int64 step);
void AggregateTotalStats(ShowNode* node); void AggregateTotalStats(ShowNode* node);
@ -66,24 +66,10 @@ class ShowNode {
class GraphNode : public ShowNode { class GraphNode : public ShowNode {
public: public:
explicit GraphNode(TFGraphNode* node) : ShowNode(node) { explicit GraphNode(TFGraphNode* node) : ShowNode(node) {}
trackable = Trackable();
}
void ReInit() { bool Trackable(int64 step) { return node->trackable(step); }
ShowNode::ReInit();
}
bool Trackable() {
if (!node->step_stats()) return false;
if (node->all_start_micros() == 0) return false;
if (node->canonical_device().empty() || node->host_device().empty()) {
return false;
}
return true;
}
bool trackable;
std::vector<GraphNode*> children; std::vector<GraphNode*> children;
std::vector<GraphNode*> show_children; std::vector<GraphNode*> show_children;
}; };
@ -102,7 +88,7 @@ class ShowMultiNode {
explicit ShowMultiNode(TFMultiGraphNode* node); explicit ShowMultiNode(TFMultiGraphNode* node);
virtual ~ShowMultiNode() {} virtual ~ShowMultiNode() {}
bool ReInit(const std::vector<string>& type_regexes); bool ReInit(int64 step, const std::vector<string>& type_regexes);
const string& name() const { return node->name(); } const string& name() const { return node->name(); }
TFMultiGraphNodeProto* mutable_proto(); TFMultiGraphNodeProto* mutable_proto();

View File

@ -146,7 +146,7 @@ tensorflow::Status Options::FromProtoStr(const string& opts_proto_str,
*opts = Options( *opts = Options(
opts_pb.max_depth(), opts_pb.min_bytes(), opts_pb.min_micros(), opts_pb.max_depth(), opts_pb.min_bytes(), opts_pb.min_micros(),
opts_pb.min_params(), opts_pb.min_float_ops(), opts_pb.min_occurrence(), opts_pb.min_params(), opts_pb.min_float_ops(), opts_pb.min_occurrence(),
opts_pb.order_by(), opts_pb.step(), opts_pb.order_by(),
std::vector<string>(opts_pb.account_type_regexes().begin(), std::vector<string>(opts_pb.account_type_regexes().begin(),
opts_pb.account_type_regexes().end()), opts_pb.account_type_regexes().end()),
std::vector<string>(opts_pb.start_name_regexes().begin(), std::vector<string>(opts_pb.start_name_regexes().begin(),
@ -171,6 +171,7 @@ string Options::ToString() const {
"%-28s%lld\n" "%-28s%lld\n"
"%-28s%lld\n" "%-28s%lld\n"
"%-28s%lld\n" "%-28s%lld\n"
"%-28s%lld\n"
"%-28s%s\n" "%-28s%s\n"
"%-28s%s\n" "%-28s%s\n"
"%-28s%s\n" "%-28s%s\n"
@ -182,15 +183,15 @@ string Options::ToString() const {
"%-28s%s:%s\n", "%-28s%s:%s\n",
kOptions[0], max_depth, kOptions[1], min_bytes, kOptions[2], min_micros, kOptions[0], max_depth, kOptions[1], min_bytes, kOptions[2], min_micros,
kOptions[3], min_params, kOptions[4], min_float_ops, kOptions[5], kOptions[3], min_params, kOptions[4], min_float_ops, kOptions[5],
min_occurrence, kOptions[6], order_by.c_str(), kOptions[7], min_occurrence, kOptions[6], step, kOptions[7], order_by.c_str(),
str_util::Join(account_type_regexes, ",").c_str(), kOptions[8], kOptions[8], str_util::Join(account_type_regexes, ",").c_str(),
str_util::Join(start_name_regexes, ",").c_str(), kOptions[9], kOptions[9], str_util::Join(start_name_regexes, ",").c_str(),
str_util::Join(trim_name_regexes, ",").c_str(), kOptions[10], kOptions[10], str_util::Join(trim_name_regexes, ",").c_str(),
str_util::Join(show_name_regexes, ",").c_str(), kOptions[11], kOptions[11], str_util::Join(show_name_regexes, ",").c_str(),
str_util::Join(hide_name_regexes, ",").c_str(), kOptions[12], kOptions[12], str_util::Join(hide_name_regexes, ",").c_str(),
(account_displayed_op_only ? "true" : "false"), kOptions[13], kOptions[13], (account_displayed_op_only ? "true" : "false"),
str_util::Join(select, ",").c_str(), kOptions[14], output_type.c_str(), kOptions[14], str_util::Join(select, ",").c_str(), kOptions[15],
KeyValueToStr(output_options).c_str()); output_type.c_str(), KeyValueToStr(output_options).c_str());
return s; return s;
} }

View File

@ -33,6 +33,7 @@ static const char* const kOptions[] = {
"-min_params", "-min_params",
"-min_float_ops", "-min_float_ops",
"-min_occurrence", "-min_occurrence",
"-step",
"-order_by", "-order_by",
"-account_type_regexes", "-account_type_regexes",
"-start_name_regexes", "-start_name_regexes",
@ -81,12 +82,13 @@ struct Options {
virtual ~Options() {} virtual ~Options() {}
Options() Options()
: Options(0, 0, 0, 0, 0, 0, "", {}, {}, {}, {}, {}, false, {}, "", {}) {} : Options(0, 0, 0, 0, 0, 0, 0, "", {}, {}, {}, {}, {}, false, {}, "",
{}) {}
Options(int max_depth, tensorflow::int64 min_bytes, Options(int max_depth, tensorflow::int64 min_bytes,
tensorflow::int64 min_micros, tensorflow::int64 min_params, tensorflow::int64 min_micros, tensorflow::int64 min_params,
tensorflow::int64 min_float_ops, tensorflow::int64 min_occurrence, tensorflow::int64 min_float_ops, tensorflow::int64 min_occurrence,
const string& order_by, tensorflow::int64 step, const string& order_by,
const std::vector<string>& account_type_regexes, const std::vector<string>& account_type_regexes,
const std::vector<string>& start_name_regexes, const std::vector<string>& start_name_regexes,
const std::vector<string>& trim_name_regexes, const std::vector<string>& trim_name_regexes,
@ -101,6 +103,7 @@ struct Options {
min_params(min_params), min_params(min_params),
min_float_ops(min_float_ops), min_float_ops(min_float_ops),
min_occurrence(min_occurrence), min_occurrence(min_occurrence),
step(step),
order_by(order_by), order_by(order_by),
account_type_regexes(account_type_regexes), account_type_regexes(account_type_regexes),
start_name_regexes(start_name_regexes), start_name_regexes(start_name_regexes),
@ -120,6 +123,7 @@ struct Options {
tensorflow::int64 min_params; tensorflow::int64 min_params;
tensorflow::int64 min_float_ops; tensorflow::int64 min_float_ops;
tensorflow::int64 min_occurrence; tensorflow::int64 min_occurrence;
tensorflow::int64 step;
string order_by; string order_by;
std::vector<string> account_type_regexes; std::vector<string> account_type_regexes;

View File

@ -196,7 +196,7 @@ std::vector<ScopeNode*> TFScope::Account(const std::vector<ScopeNode*>& roots,
node->ResetTotalStats(); node->ResetTotalStats();
std::vector<ScopeNode*> act_cnodes = Account(node->children, opts); std::vector<ScopeNode*> act_cnodes = Account(node->children, opts);
node->account = ShouldAccount(node, opts); node->account = ReAccount(node, opts);
if (node->account || !act_cnodes.empty()) { if (node->account || !act_cnodes.empty()) {
node->show_children.clear(); node->show_children.clear();
node->ResetTotalStats(); node->ResetTotalStats();

View File

@ -27,7 +27,7 @@ namespace tfprof {
const TFGraphNodeProto& TFShow::Show(const Options& opts) { const TFGraphNodeProto& TFShow::Show(const Options& opts) {
if (opts.output_type == kOutput[0]) { if (opts.output_type == kOutput[0]) {
Timeline timeline(opts.output_options.at(kTimelineOpts[0])); Timeline timeline(opts.step, opts.output_options.at(kTimelineOpts[0]));
return ShowInternal(opts, &timeline)->proto(); return ShowInternal(opts, &timeline)->proto();
} else if (opts.output_type == kOutput[2]) { } else if (opts.output_type == kOutput[2]) {
const ShowNode* root = ShowInternal(opts, nullptr); const ShowNode* root = ShowInternal(opts, nullptr);
@ -105,7 +105,8 @@ bool TFShow::ShouldTrim(ShowNode* node, const std::vector<string>& regexes) {
return false; return false;
} }
bool TFShow::ShouldAccount(ShowNode* node, const Options& opts) { bool TFShow::ReAccount(ShowNode* node, const Options& opts) {
node->ReInit(opts.step);
if (opts.account_type_regexes.size() == 1 && if (opts.account_type_regexes.size() == 1 &&
opts.account_type_regexes[0] == ".*") { opts.account_type_regexes[0] == ".*") {
return true; return true;

View File

@ -63,7 +63,7 @@ class TFShow {
bool ShouldTrim(ShowNode* node, const std::vector<string>& regexes); bool ShouldTrim(ShowNode* node, const std::vector<string>& regexes);
bool ShouldAccount(ShowNode* node, const Options& opts); bool ReAccount(ShowNode* node, const Options& opts);
string FormatNode(ShowNode* node, const Options& opts); string FormatNode(ShowNode* node, const Options& opts);

View File

@ -29,7 +29,7 @@ namespace tfprof {
const TFMultiGraphNodeProto& TFMultiShow::Show(const Options& opts) { const TFMultiGraphNodeProto& TFMultiShow::Show(const Options& opts) {
if (opts.output_type == kOutput[0]) { if (opts.output_type == kOutput[0]) {
Timeline timeline(opts.output_options.at(kTimelineOpts[0])); Timeline timeline(opts.step, opts.output_options.at(kTimelineOpts[0]));
return ShowInternal(opts, &timeline)->proto(); return ShowInternal(opts, &timeline)->proto();
} else if (opts.output_type == kOutput[2]) { } else if (opts.output_type == kOutput[2]) {
const ShowMultiNode* root = ShowInternal(opts, nullptr); const ShowMultiNode* root = ShowInternal(opts, nullptr);
@ -99,7 +99,7 @@ bool TFMultiShow::ShouldTrim(ShowMultiNode* node,
} }
bool TFMultiShow::ReAccount(ShowMultiNode* node, const Options& opts) { bool TFMultiShow::ReAccount(ShowMultiNode* node, const Options& opts) {
return node->ReInit(opts.account_type_regexes); return node->ReInit(opts.step, opts.account_type_regexes);
} }
string TFMultiShow::FormatLegend(const Options& opts) { string TFMultiShow::FormatLegend(const Options& opts) {

View File

@ -71,7 +71,8 @@ class TFProfShowTest : public ::testing::Test {
TEST_F(TFProfShowTest, DumpScopeMode) { TEST_F(TFProfShowTest, DumpScopeMode) {
string dump_file = io::JoinPath(testing::TmpDir(), "dump"); string dump_file = io::JoinPath(testing::TmpDir(), "dump");
Options opts(5, 0, 0, 0, 0, 0, "name", {"VariableV2"}, // accout_type_regexes Options opts(5, 0, 0, 0, 0, 0, -1, "name",
{"VariableV2"}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops"}, "file", {"params", "bytes", "micros", "float_ops"}, "file",
{{"outfile", dump_file}}); {{"outfile", dump_file}});
@ -93,7 +94,7 @@ TEST_F(TFProfShowTest, DumpScopeMode) {
TEST_F(TFProfShowTest, DumpOpMode) { TEST_F(TFProfShowTest, DumpOpMode) {
string dump_file = io::JoinPath(testing::TmpDir(), "dump"); string dump_file = io::JoinPath(testing::TmpDir(), "dump");
Options opts(5, 0, 0, 0, 0, 4, "params", {".*"}, // accout_type_regexes Options opts(5, 0, 0, 0, 0, 4, -1, "params", {".*"}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops", "occurrence"}, "file", {"params", "bytes", "micros", "float_ops", "occurrence"}, "file",
{{"outfile", dump_file}}); {{"outfile", dump_file}});

View File

@ -30,24 +30,17 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
std::unique_ptr<OpLog> op_log, std::unique_ptr<OpLog> op_log,
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader) std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader)
: graph_(std::move(graph)), : graph_(std::move(graph)),
run_meta_(std::move(run_meta)),
op_log_(std::move(op_log)),
ckpt_reader_(std::move(ckpt_reader)) { ckpt_reader_(std::move(ckpt_reader)) {
CHECK(graph_) << "Must at least have GraphDef"; CHECK(graph_) << "Must at least have GraphDef";
printf("Parsing GraphDef...\n"); printf("Parsing Inputs...\n");
ParseGraph(); ParseGraph();
if (run_meta_) { if (run_meta && run_meta->has_step_stats()) {
printf("Parsing RunMetadata...\n"); ParseRunMeta(0, std::move(run_meta));
ParseRunMeta();
}
if (op_log_) {
printf("Parsing OpLog...\n");
ParseOpLog();
} }
ParseOpLog(std::move(op_log));
if (ckpt_reader_) { if (ckpt_reader_) {
printf("Parsing Checkpoint...\n");
for (const auto& v : ckpt_reader_->GetVariableToShapeMap()) { for (const auto& v : ckpt_reader_->GetVariableToShapeMap()) {
auto node = nodes_map_.find(v.first); auto node = nodes_map_.find(v.first);
if (node != nodes_map_.end()) { if (node != nodes_map_.end()) {
@ -76,6 +69,9 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd, const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
const Options& opts) { const Options& opts) {
if (!Validate(opts)) {
return empty_graph_node_;
}
if (cmd == kCmds[0]) { if (cmd == kCmds[0]) {
return scope_view_->Show(opts); return scope_view_->Show(opts);
} else if (cmd == kCmds[1]) { } else if (cmd == kCmds[1]) {
@ -88,6 +84,9 @@ const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
const TFMultiGraphNodeProto& TFStats::ShowMultiGraphNode(const string& cmd, const TFMultiGraphNodeProto& TFStats::ShowMultiGraphNode(const string& cmd,
const Options& opts) { const Options& opts) {
if (!Validate(opts)) {
return empty_multi_graph_node_;
}
if (cmd == kCmds[2]) { if (cmd == kCmds[2]) {
return code_view_->Show(opts); return code_view_->Show(opts);
} else if (cmd == kCmds[3]) { } else if (cmd == kCmds[3]) {
@ -130,8 +129,11 @@ void TFStats::ParseGraph() {
} }
} }
void TFStats::ParseOpLog() { void TFStats::ParseOpLog(std::unique_ptr<OpLog> op_log) {
for (const OpLogEntry& entry : op_log_->log_entries()) { if (!op_log) {
return;
}
for (const OpLogEntry& entry : op_log->log_entries()) {
auto node = nodes_map_.find(entry.name()); auto node = nodes_map_.find(entry.name());
if (node == nodes_map_.end()) continue; if (node == nodes_map_.end()) continue;
for (const string& type : entry.types()) { for (const string& type : entry.types()) {
@ -141,16 +143,24 @@ void TFStats::ParseOpLog() {
node->second->AddFloatOps(entry.float_ops()); node->second->AddFloatOps(entry.float_ops());
} }
if (entry.has_code_def()) { if (entry.has_code_def()) {
node->second->AddCode(&entry.code_def()); node->second->AddCode(entry.code_def());
} }
} }
} }
void TFStats::ParseRunMeta() { void TFStats::ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) {
if (!run_meta_->has_step_stats()) return; if (!run_meta || !run_meta->has_step_stats()) {
fprintf(stderr, "Invalid RunMetadata for step %lld\n", step);
return;
}
if (steps_.find(step) != steps_.end()) {
fprintf(stderr, "The same step %lld has been added before.\n", step);
return;
}
steps_.insert(step);
for (const auto& dev_stat : run_meta_->step_stats().dev_stats()) { for (const auto& dev_stat : run_meta->step_stats().dev_stats()) {
for (const auto& node_stat : dev_stat.node_stats()) { for (const NodeExecStats& node_stat : dev_stat.node_stats()) {
string name = node_stat.node_name(); string name = node_stat.node_name();
// Sometimes the node_name is suffixed with unnecessary information. // Sometimes the node_name is suffixed with unnecessary information.
auto split_pos = node_stat.node_name().find(":"); auto split_pos = node_stat.node_name().find(":");
@ -159,10 +169,18 @@ void TFStats::ParseRunMeta() {
} }
auto node = nodes_map_.find(name); auto node = nodes_map_.find(name);
if (node != nodes_map_.end()) { if (node != nodes_map_.end()) {
node->second->AddStepStat(dev_stat.device(), &node_stat); node->second->AddStepStat(step, dev_stat.device(), node_stat);
} }
} }
} }
} }
bool TFStats::Validate(const Options& opts) {
if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) {
fprintf(stderr, "Options -step=%lld not found\n", opts.step);
return false;
}
return true;
}
} // namespace tfprof } // namespace tfprof
} // namespace tensorflow } // namespace tensorflow

View File

@ -62,20 +62,20 @@ class TFStats {
const TFMultiGraphNodeProto& ShowMultiGraphNode(const string& cmd, const TFMultiGraphNodeProto& ShowMultiGraphNode(const string& cmd,
const Options& opts); const Options& opts);
void ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta);
void ParseOpLog(std::unique_ptr<OpLog> op_log);
private: private:
bool Validate(const Options& opts);
void ParseGraph(); void ParseGraph();
void ParseOpLog(); std::set<int64> steps_;
std::unique_ptr<GraphDef> graph_;
void ParseRunMeta();
std::unique_ptr<TFScope> scope_view_; std::unique_ptr<TFScope> scope_view_;
std::unique_ptr<TFGraph> graph_view_; std::unique_ptr<TFGraph> graph_view_;
std::unique_ptr<TFCode> code_view_; std::unique_ptr<TFCode> code_view_;
std::unique_ptr<TFOp> op_view_; std::unique_ptr<TFOp> op_view_;
std::unique_ptr<GraphDef> graph_;
std::unique_ptr<RunMetadata> run_meta_;
std::unique_ptr<OpLog> op_log_;
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader_; std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader_;
// Store TFGraphNode instead of TFGraphNode* to avoid large number of // Store TFGraphNode instead of TFGraphNode* to avoid large number of
// dynamic alloc. // dynamic alloc.

View File

@ -71,7 +71,7 @@ class TFProfStatsTest : public ::testing::Test {
}; };
TEST_F(TFProfStatsTest, CustomOpType) { TEST_F(TFProfStatsTest, CustomOpType) {
Options opts(3, 0, 0, 0, 0, 0, "name", Options opts(3, 0, 0, 0, 0, 0, -1, "name",
{kTrainableVarType}, // accout_type_regexes {kTrainableVarType}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops"}, "", {}); {"params", "bytes", "micros", "float_ops"}, "", {});
@ -113,7 +113,8 @@ TEST_F(TFProfStatsTest, CustomOpType) {
} }
TEST_F(TFProfStatsTest, CheckPointOpType) { TEST_F(TFProfStatsTest, CheckPointOpType) {
Options opts(3, 0, 0, 0, 0, 0, "name", {kCkptVarType}, // accout_type_regexes Options opts(3, 0, 0, 0, 0, 0, -1, "name",
{kCkptVarType}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops"}, "", {}); {"params", "bytes", "micros", "float_ops"}, "", {});
const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts);
@ -154,7 +155,7 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
} }
TEST_F(TFProfStatsTest, TestGraph) { TEST_F(TFProfStatsTest, TestGraph) {
Options opts(100, 0, 10000, 0, 0, 0, "name", {".*"}, Options opts(100, 0, 10000, 0, 0, 0, -1, "name", {".*"},
{"cost.*"}, // start_name_regexes {"cost.*"}, // start_name_regexes
{""}, {".*"}, {""}, false, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops"}, "", {}); {"params", "bytes", "micros", "float_ops"}, "", {});
@ -171,8 +172,8 @@ TEST_F(TFProfStatsTest, TestGraph) {
} }
TEST_F(TFProfStatsTest, TestFloatOps) { TEST_F(TFProfStatsTest, TestFloatOps) {
Options opts(10, 0, 0, 0, 1, 0, "name", {".*"}, {".*"}, {""}, {".*"}, {""}, Options opts(10, 0, 0, 0, 1, 0, -1, "name", {".*"}, {".*"}, {""}, {".*"},
false, {"float_ops"}, "", {}); {""}, false, {"float_ops"}, "", {});
const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts);
TFGraphNodeProto expected; TFGraphNodeProto expected;
@ -201,7 +202,7 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
} }
TEST_F(TFProfStatsTest, TestAccountShownNameOnly) { TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
Options opts(100, 0, 0, 0, 0, 0, "name", {".*"}, {".*"}, {""}, Options opts(100, 0, 0, 0, 0, 0, -1, "name", {".*"}, {".*"}, {""},
{"unit_2_1.*DW"}, // show_name_regexes. {"unit_2_1.*DW"}, // show_name_regexes.
{""}, true, // account_displayed_op_only. {""}, true, // account_displayed_op_only.
{"params"}, "", {}); {"params"}, "", {});
@ -217,7 +218,7 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
} }
TEST_F(TFProfStatsTest, TestShowTensorValue) { TEST_F(TFProfStatsTest, TestShowTensorValue) {
Options opts(10, 0, 0, 0, 0, 0, "name", {".*"}, {".*"}, {""}, Options opts(10, 0, 0, 0, 0, 0, -1, "name", {".*"}, {".*"}, {""},
{"unit_1_0.*gamma"}, {""}, false, {"unit_1_0.*gamma"}, {""}, false,
{"tensor_value"}, // Show tensor value from checkpoint. {"tensor_value"}, // Show tensor value from checkpoint.
"", {}); "", {});

View File

@ -76,7 +76,8 @@ class TFProfTensor {
CHECK(strings::safe_strto64(sstream.str().c_str(), &int64_val)); CHECK(strings::safe_strto64(sstream.str().c_str(), &int64_val));
dim->add_value_int64(int64_val); dim->add_value_int64(int64_val);
formatted_str_ += strings::Printf( formatted_str_ += strings::Printf(
"%lld ", dim->value_int64(dim->value_int64_size() - 1)); "%lld ", static_cast<int64>(
dim->value_int64(dim->value_int64_size() - 1)));
} else if (typeid(values[nstart]) == typeid(string)) { } else if (typeid(values[nstart]) == typeid(string)) {
dim->add_value_str(sstream.str()); dim->add_value_str(sstream.str());
formatted_str_ = formatted_str_ =

View File

@ -55,8 +55,8 @@ class TFProfTensorTest : public ::testing::Test {
}; };
TEST_F(TFProfTensorTest, Basics) { TEST_F(TFProfTensorTest, Basics) {
Options opts(3, 0, 0, 0, 0, 0, "name", {"VariableV2"}, {".*"}, {""}, {".*"}, Options opts(3, 0, 0, 0, 0, 0, -1, "name", {"VariableV2"}, {".*"}, {""},
{""}, false, {"tensor_value"}, // show the tensor value. {".*"}, {""}, false, {"tensor_value"}, // show the tensor value.
"", {}); "", {});
const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts);

View File

@ -114,34 +114,42 @@ string ChromeTraceFormatter::Format() {
return trace_str; return trace_str;
} }
void MemoryTracker::TrackNode(GraphNode* node) { void MemoryTracker::TrackNode(int64 step, GraphNode* node) {
if (!node->Trackable(step)) {
return;
}
Device& dev = devices_[node->node->canonical_device()]; Device& dev = devices_[node->node->canonical_device()];
int64 end_micros = int64 end_micros = node->node->all_start_micros(step) +
node->node->all_start_micros() + node->node->latest_end_rel_micros(); node->node->latest_end_rel_micros(step);
if (node->node->accelerator_persistent_bytes() != 0) { if (node->node->accelerator_persistent_bytes(step) != 0) {
string tensor_name = strings::StrCat(node->name(), ":", -1); string tensor_name = strings::StrCat(node->name(), ":", -1);
dev.earliest_ref[tensor_name] = node->node->all_start_micros(); dev.earliest_ref[tensor_name] = node->node->all_start_micros(step);
dev.tensor_size[tensor_name] = node->node->accelerator_persistent_bytes(); dev.tensor_size[tensor_name] =
node->node->accelerator_persistent_bytes(step);
// TODO(xpan): Need latest_ref? // TODO(xpan): Need latest_ref?
} }
if (node->node->accelerator_temp_bytes()) { if (node->node->accelerator_temp_bytes(step)) {
string tensor_name = strings::StrCat(node->name(), ":", -2); string tensor_name = strings::StrCat(node->name(), ":", -2);
dev.earliest_ref[tensor_name] = node->node->all_start_micros(); dev.earliest_ref[tensor_name] = node->node->all_start_micros(step);
dev.latest_ref[tensor_name] = end_micros; dev.latest_ref[tensor_name] = end_micros;
dev.tensor_size[tensor_name] = node->node->accelerator_temp_bytes(); dev.tensor_size[tensor_name] = node->node->accelerator_temp_bytes(step);
} }
if (node->node->allocator_bytes_in_use() > 0) { if (node->node->allocator_bytes_in_use(step) > 0) {
dev.allocator_stats[end_micros] = node->node->allocator_bytes_in_use(); dev.allocator_stats[end_micros] = node->node->allocator_bytes_in_use(step);
} }
} }
void MemoryTracker::TrackNodeConnection(GraphNode* node, GraphNode* src) { void MemoryTracker::TrackNodeConnection(int64 step, GraphNode* node,
GraphNode* src) {
if (!node->Trackable(step) || !src->Trackable(step)) {
return;
}
const auto& output_idx = node->node->output_idx().find(src->name()); const auto& output_idx = node->node->output_idx().find(src->name());
if (output_idx == node->node->output_idx().end()) { if (output_idx == node->node->output_idx().end()) {
return; return;
} }
const auto& output = src->node->output_bytes().find(output_idx->second); const auto& output = src->node->output_bytes(step).find(output_idx->second);
if (output == src->node->output_bytes().end()) { if (output == src->node->output_bytes(step).end()) {
return; return;
} }
int64 output_bytes = output->second.first; int64 output_bytes = output->second.first;
@ -155,14 +163,14 @@ void MemoryTracker::TrackNodeConnection(GraphNode* node, GraphNode* src) {
} }
src_dev.tensor_size[tensor_name] = output_bytes; src_dev.tensor_size[tensor_name] = output_bytes;
src_dev.earliest_ref[tensor_name] = src->node->all_start_micros(); src_dev.earliest_ref[tensor_name] = src->node->all_start_micros(step);
int64 src_end_micros = int64 src_end_micros = src->node->all_start_micros(step) +
src->node->all_start_micros() + src->node->latest_end_rel_micros(); src->node->latest_end_rel_micros(step);
if (src->node->canonical_device() != node->node->canonical_device()) { if (src->node->canonical_device() != node->node->canonical_device()) {
int64 transfer_micros = int64 transfer_micros =
(src_end_micros + node->node->all_start_micros()) / 2; (src_end_micros + node->node->all_start_micros(step)) / 2;
src_dev.latest_ref[tensor_name] = src_dev.latest_ref[tensor_name] =
std::max(src_dev.latest_ref[tensor_name], transfer_micros); std::max(src_dev.latest_ref[tensor_name], transfer_micros);
@ -171,18 +179,19 @@ void MemoryTracker::TrackNodeConnection(GraphNode* node, GraphNode* src) {
strings::StrCat(tensor_name, node->node->canonical_device()); strings::StrCat(tensor_name, node->node->canonical_device());
dest_dev.tensor_size[dest_tensor_name] = output_bytes; dest_dev.tensor_size[dest_tensor_name] = output_bytes;
dest_dev.earliest_ref[dest_tensor_name] = transfer_micros; dest_dev.earliest_ref[dest_tensor_name] = transfer_micros;
dest_dev.latest_ref[dest_tensor_name] = std::max( dest_dev.latest_ref[dest_tensor_name] =
dest_dev.latest_ref[dest_tensor_name], std::max(dest_dev.latest_ref[dest_tensor_name],
node->node->all_start_micros() + node->node->latest_end_rel_micros()); node->node->all_start_micros(step) +
node->node->latest_end_rel_micros(step));
} else { } else {
src_dev.latest_ref[tensor_name] = std::max( src_dev.latest_ref[tensor_name] =
src_dev.latest_ref[tensor_name], std::max(src_dev.latest_ref[tensor_name],
node->node->all_start_micros() + node->node->latest_end_rel_micros()); node->node->all_start_micros(step) +
node->node->latest_end_rel_micros(step));
} }
} }
void Timeline::GenerateGraphTimeline(const GraphNode* gnode, void Timeline::GenerateGraphTimeline(const GraphNode* gnode) {
const MemoryTracker& memory_tracker) {
AddGraphNode(gnode); AddGraphNode(gnode);
AllocateLanes(); AllocateLanes();
fprintf(stdout, "generating trace file.\n"); fprintf(stdout, "generating trace file.\n");
@ -215,7 +224,7 @@ void Timeline::GenerateGraphTimeline(const GraphNode* gnode,
} }
} }
} }
for (const auto& dev : memory_tracker.devices()) { for (const auto& dev : mem_tracker_.devices()) {
int64 pid = AllocatePID(); int64 pid = AllocatePID();
chrome_formatter_.EmitPID(GetMemoryLaneName(dev.first), pid); chrome_formatter_.EmitPID(GetMemoryLaneName(dev.first), pid);
const MemoryTracker::Device& device = dev.second; const MemoryTracker::Device& device = dev.second;
@ -268,12 +277,12 @@ std::vector<TimeNode*> Timeline::AddGraphNode(const GraphNode* gnode) {
std::vector<TimeNode*> inputs = AddGraphNode(schild); std::vector<TimeNode*> inputs = AddGraphNode(schild);
shown_cinputs.insert(shown_cinputs.end(), inputs.begin(), inputs.end()); shown_cinputs.insert(shown_cinputs.end(), inputs.begin(), inputs.end());
} }
if (!gnode->node->step_stats()) { if (!gnode->node->trackable(step_)) {
return shown_cinputs; return shown_cinputs;
} }
const TFGraphNode* node = gnode->node; const TFGraphNode* node = gnode->node;
for (const auto& kernel_execs : node->op_execs()) { for (const auto& kernel_execs : node->op_execs(step_)) {
const string& device = kernel_execs.first; const string& device = kernel_execs.first;
const std::vector<std::pair<int64, int64>>& execs = kernel_execs.second; const std::vector<std::pair<int64, int64>>& execs = kernel_execs.second;

View File

@ -101,9 +101,9 @@ class MemoryTracker {
std::map<int64, int64> allocator_stats; std::map<int64, int64> allocator_stats;
}; };
void TrackNode(GraphNode* node); void TrackNode(int64 step, GraphNode* node);
void TrackNodeConnection(GraphNode* node, GraphNode* src); void TrackNodeConnection(int64 step, GraphNode* node, GraphNode* src);
const std::map<string, Device>& devices() const { return devices_; } const std::map<string, Device>& devices() const { return devices_; }
@ -113,16 +113,25 @@ class MemoryTracker {
class Timeline { class Timeline {
public: public:
Timeline(const string& outfile) : outfile_(outfile) {} Timeline(int64 step, const string& outfile)
: step_(step), outfile_(outfile) {}
~Timeline() {} ~Timeline() {}
void GenerateGraphTimeline(const GraphNode* gnode, int64 step() const { return step_; }
const MemoryTracker& memory_tracker); void SetStep(int64 step) { step_ = step; }
void GenerateGraphTimeline(const GraphNode* gnode);
void GenerateScopeTimeline(const ScopeNode* node); void GenerateScopeTimeline(const ScopeNode* node);
void GenerateCodeTimeline(const CodeNode* node); void GenerateCodeTimeline(const CodeNode* node);
void TrackNode(GraphNode* node) { mem_tracker_.TrackNode(step_, node); }
void TrackNodeConnection(GraphNode* node, GraphNode* src) {
mem_tracker_.TrackNodeConnection(step_, node, src);
}
private: private:
void OutputTimeline(); void OutputTimeline();
@ -162,9 +171,11 @@ class Timeline {
int64 AllocatePID(); int64 AllocatePID();
int64 step_;
const string outfile_; const string outfile_;
int64 next_pid_ = 0; int64 next_pid_ = 0;
int64 allocator_pid_ = -1; int64 allocator_pid_ = -1;
MemoryTracker mem_tracker_;
ChromeTraceFormatter chrome_formatter_; ChromeTraceFormatter chrome_formatter_;
std::map<string, int64> device_pids_; std::map<string, int64> device_pids_;

View File

@ -60,7 +60,7 @@ class TFProfTimelineTest : public ::testing::Test {
// manually check it's correct // manually check it's correct
TEST_F(TFProfTimelineTest, GraphView) { TEST_F(TFProfTimelineTest, GraphView) {
string dump_file = io::JoinPath(testing::TmpDir(), "dump"); string dump_file = io::JoinPath(testing::TmpDir(), "dump");
Options opts(10000, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes Options opts(10000, 0, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops"}, "timeline", {"params", "bytes", "micros", "float_ops"}, "timeline",
{{"outfile", dump_file}}); {{"outfile", dump_file}});
@ -73,7 +73,7 @@ TEST_F(TFProfTimelineTest, GraphView) {
TEST_F(TFProfTimelineTest, ScopeView) { TEST_F(TFProfTimelineTest, ScopeView) {
string dump_file = io::JoinPath(testing::TmpDir(), "dump"); string dump_file = io::JoinPath(testing::TmpDir(), "dump");
Options opts(5, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes Options opts(5, 0, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops"}, "timeline", {"params", "bytes", "micros", "float_ops"}, "timeline",
{{"outfile", dump_file}}); {{"outfile", dump_file}});

View File

@ -176,6 +176,12 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
} }
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[6]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[6]) {
if (pieces.size() <= i + 1 ||
!strings::safe_strto64(pieces[i + 1], &opts->step)) {
return ReturnError(pieces, i);
}
++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[7]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
@ -187,42 +193,42 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
} }
opts->order_by = *order_by; opts->order_by = *order_by;
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[7]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[8]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
opts->account_type_regexes = str_util::Split(StripQuote(pieces[i + 1]), opts->account_type_regexes = str_util::Split(StripQuote(pieces[i + 1]),
',', str_util::SkipEmpty()); ',', str_util::SkipEmpty());
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[8]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[9]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
opts->start_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',', opts->start_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
str_util::SkipEmpty()); str_util::SkipEmpty());
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[9]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[10]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
opts->trim_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',', opts->trim_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
str_util::SkipEmpty()); str_util::SkipEmpty());
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[10]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[11]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
opts->show_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',', opts->show_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
str_util::SkipEmpty()); str_util::SkipEmpty());
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[11]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[12]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
opts->hide_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',', opts->hide_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
str_util::SkipEmpty()); str_util::SkipEmpty());
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[12]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[13]) {
if ((pieces.size() > i + 1 && pieces[i + 1].find("-") == 0) || if ((pieces.size() > i + 1 && pieces[i + 1].find("-") == 0) ||
pieces.size() == i + 1) { pieces.size() == i + 1) {
opts->account_displayed_op_only = true; opts->account_displayed_op_only = true;
@ -232,7 +238,7 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
} else { } else {
++i; ++i;
} }
} else if (pieces[i] == tensorflow::tfprof::kOptions[13]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[14]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
@ -249,7 +255,7 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
} }
opts->select = requested_set; opts->select = requested_set;
++i; ++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[14]) { } else if (pieces[i] == tensorflow::tfprof::kOptions[15]) {
if (pieces.size() <= i + 1) { if (pieces.size() <= i + 1) {
return ReturnError(pieces, i); return ReturnError(pieces, i);
} }
@ -291,6 +297,10 @@ void PrintHelp() {
"float operations. Only available if an op has " "float operations. Only available if an op has "
"op.RegisterStatistics() defined and OpLog is " "op.RegisterStatistics() defined and OpLog is "
"provided\n\n" "provided\n\n"
" -min_occurrence: Show the op types that are at least used this number "
"of times. Only available in op view.\n\n"
" -step: Show the stats of a step when multiple steps of "
"RunMetadata were added. By default (-1), show the average of all steps."
" -order_by: Order the results by [name|depth|bytes|micros|params|" " -order_by: Order the results by [name|depth|bytes|micros|params|"
"float_ops]\n\n" "float_ops]\n\n"
" -account_type_regexes: Account and display the ops whose types match " " -account_type_regexes: Account and display the ops whose types match "

View File

@ -75,6 +75,7 @@ int main(int argc, char** argv) {
tensorflow::int64 FLAGS_min_params = 0; tensorflow::int64 FLAGS_min_params = 0;
tensorflow::int64 FLAGS_min_float_ops = 0; tensorflow::int64 FLAGS_min_float_ops = 0;
tensorflow::int64 FLAGS_min_occurrence = 0; tensorflow::int64 FLAGS_min_occurrence = 0;
tensorflow::int64 FLAGS_step = -1;
tensorflow::string FLAGS_order_by = "name"; tensorflow::string FLAGS_order_by = "name";
tensorflow::string FLAGS_account_type_regexes = ".*"; tensorflow::string FLAGS_account_type_regexes = ".*";
tensorflow::string FLAGS_start_name_regexes = ".*"; tensorflow::string FLAGS_start_name_regexes = ".*";
@ -92,7 +93,8 @@ int main(int argc, char** argv) {
tensorflow::Flag("graph_path", &FLAGS_graph_path, tensorflow::Flag("graph_path", &FLAGS_graph_path,
"GraphDef proto text file name"), "GraphDef proto text file name"),
tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path, tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path,
"RunMetadata proto binary file name"), "Comma-separated list of RunMetadata proto binary "
"files. Each file is given step number 0,1,2,etc"),
tensorflow::Flag("op_log_path", &FLAGS_op_log_path, tensorflow::Flag("op_log_path", &FLAGS_op_log_path,
"tensorflow::tfprof::OpLog proto binary file name"), "tensorflow::tfprof::OpLog proto binary file name"),
tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path, tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path,
@ -104,6 +106,8 @@ int main(int argc, char** argv) {
tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"), tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
tensorflow::Flag("min_occurrence", &FLAGS_min_occurrence, tensorflow::Flag("min_occurrence", &FLAGS_min_occurrence,
"min occurrence"), "min occurrence"),
tensorflow::Flag("step", &FLAGS_step,
"The stats of which step to use. By default average"),
tensorflow::Flag("order_by", &FLAGS_order_by, "order by"), tensorflow::Flag("order_by", &FLAGS_order_by, "order by"),
tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes, tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes,
"start name regexes"), "start name regexes"),
@ -181,18 +185,6 @@ int main(int argc, char** argv) {
TF_CHECK_OK(tensorflow::tfprof::ReadGraphDef(tensorflow::Env::Default(), TF_CHECK_OK(tensorflow::tfprof::ReadGraphDef(tensorflow::Env::Default(),
FLAGS_graph_path, graph.get())); FLAGS_graph_path, graph.get()));
std::unique_ptr<tensorflow::RunMetadata> run_meta(
new tensorflow::RunMetadata());
if (!FLAGS_run_meta_path.empty()) {
s = ReadBinaryProto(tensorflow::Env::Default(), FLAGS_run_meta_path,
run_meta.get());
if (!s.ok()) {
fprintf(stderr, "Failed to read run_meta_path: %s\n",
s.ToString().c_str());
return 1;
}
}
std::unique_ptr<tensorflow::tfprof::OpLog> op_log( std::unique_ptr<tensorflow::tfprof::OpLog> op_log(
new tensorflow::tfprof::OpLog()); new tensorflow::tfprof::OpLog());
if (!FLAGS_op_log_path.empty()) { if (!FLAGS_op_log_path.empty()) {
@ -222,12 +214,27 @@ int main(int argc, char** argv) {
TF_DeleteStatus(status); TF_DeleteStatus(status);
} }
tensorflow::tfprof::TFStats tf_stat(std::move(graph), std::move(run_meta), tensorflow::tfprof::TFStats tf_stat(
std::move(op_log), std::move(graph), nullptr, std::move(op_log), std::move(ckpt_reader));
std::move(ckpt_reader));
std::vector<string> run_meta_files =
Split(FLAGS_run_meta_path, ',', tensorflow::str_util::SkipEmpty());
for (int i = 0; i < run_meta_files.size(); ++i) {
std::unique_ptr<tensorflow::RunMetadata> run_meta(
new tensorflow::RunMetadata());
s = ReadBinaryProto(tensorflow::Env::Default(), run_meta_files[i],
run_meta.get());
if (!s.ok()) {
fprintf(stderr, "Failed to read run_meta_path %s. Status: %s\n",
run_meta_files[i].c_str(), s.ToString().c_str());
return 1;
}
tf_stat.ParseRunMeta(i, std::move(run_meta));
}
tensorflow::tfprof::Options opts( tensorflow::tfprof::Options opts(
FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, FLAGS_min_params, FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, FLAGS_min_params,
FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_order_by, FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by,
account_type_regexes, start_name_regexes, trim_name_regexes, account_type_regexes, start_name_regexes, trim_name_regexes,
show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only, show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only,
select, output_type, output_options); select, output_type, output_options);

View File

@ -11,6 +11,7 @@ message OptionsProto {
optional int64 min_params = 4; optional int64 min_params = 4;
optional int64 min_float_ops = 5; optional int64 min_float_ops = 5;
optional int64 min_occurrence = 17; optional int64 min_occurrence = 17;
optional int64 step = 18 [default = -1];
optional string order_by = 7; optional string order_by = 7;
repeated string account_type_regexes = 8; repeated string account_type_regexes = 8;