Extend tfprof to associate op stats with Python codes.

It's backward compatible. Stats of a source code line
are aggregated from all ops created by that line.

A example.
_TFProfRoot (0us/22.44ms)
  model_analyzer_test.py:149:run_filename_as_m...:none (0us/22.44ms)
    model_analyzer_test.py:33:_run_code_in_main:none (0us/22.44ms)
      model_analyzer_test.py:208:<module>:test.main() (0us/22.44ms)
        model_analyzer_test.py:132:testComplexCodeView:x = lib.BuildFull... (0us/22.44ms)
          model_analyzer_testlib.py:63:BuildFullModel:return sgd_op.min... (0us/21.83ms)
          model_analyzer_testlib.py:54:BuildFullModel:seq.append(array_... (0us/254us)
            model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0us/134us)
            ...
          model_analyzer_testlib.py:61:BuildFullModel:loss = nn_ops.l2_... (0us/28us)
        model_analyzer_test.py:134:testComplexCodeView:sess.run(variable... (0us/0us)
Change: 155393864
This commit is contained in:
A. Unique TensorFlower 2017-05-08 09:20:04 -08:00 committed by TensorFlower Gardener
parent ec8ffb9eaf
commit 697f34ca82
31 changed files with 1325 additions and 160 deletions

View File

@ -11,7 +11,12 @@ Consultants: Jon Shlens, Pete Warden
1. Measure model parameters, float operations, tensor shapes. 1. Measure model parameters, float operations, tensor shapes.
2. Measure op execution times, requested memory size and device placement. 2. Measure op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values. 3. Inspect checkpoint tensors' shapes and their values.
4. Explore model based on name scope or graph structure. 4. 3 ways to view and explore TensorFlow model profiles
* Organize by Python code call stack.
* Organize by TensorFlow operation name scope hierarchies.
* Organize by TensorFlow operation inputs/outputs graph.
5. Selectively grouping/filtering/accounting/ordering ops. 5. Selectively grouping/filtering/accounting/ordering ops.
tfprof can be used as Python API, Interactive CLI and One-shot Script. tfprof can be used as Python API, Interactive CLI and One-shot Script.
@ -28,7 +33,8 @@ param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
tfprof_options=tf.contrib.tfprof.model_analyzer. tfprof_options=tf.contrib.tfprof.model_analyzer.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS) TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
# param_stats is tensorflow.tfprof.TFProfNode proto. It organize the statistics # param_stats is tensorflow.tfprof.TFGraphNodeProto proto.
# It organize the statistics
# of each graph node in tree scructure. Let's print the root below. # of each graph node in tree scructure. Let's print the root below.
sys.stdout.write('total_params: %d\n' % param_stats.total_parameters) sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
``` ```

View File

@ -21,16 +21,34 @@ py_test(
name = "model_analyzer_test", name = "model_analyzer_test",
srcs = ["model_analyzer_test.py"], srcs = ["model_analyzer_test.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [ deps = [
":model_analyzer", ":model_analyzer",
"//tensorflow/core:protos_all_py", ":model_analyzer_testlib",
"//tensorflow/python:array_ops",
"//tensorflow/python:client", "//tensorflow/python:client",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:variables",
],
)
py_library(
name = "model_analyzer_testlib",
srcs = ["model_analyzer_testlib.py"],
srcs_version = "PY2AND3",
deps = [
":model_analyzer",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops", "//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python:rnn",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
"//tensorflow/python:variables", "//tensorflow/python:variables",
], ],

View File

@ -123,7 +123,7 @@ def print_model_analysis(graph,
"""Print model statistics. """Print model statistics.
Prints the model statistics to stdout. Also returns the results Prints the model statistics to stdout. Also returns the results
in a TFProfNode proto. See go/tfprof or run tfprof tool: in a TFGraphNodeProto proto. See go/tfprof or run tfprof tool:
'bazel run third_party/tensorflow/tools/tfprof help' 'bazel run third_party/tensorflow/tools/tfprof help'
Examples: Examples:
@ -142,15 +142,19 @@ def print_model_analysis(graph,
'micros' and 'bytes'. 'micros' and 'bytes'.
op_log: tensorflow::tfprof::OpLog proto. users can use this proto to op_log: tensorflow::tfprof::OpLog proto. users can use this proto to
group together ops and use a op_type to select the group. group together ops and use a op_type to select the group.
tfprof_cmd: string. Either 'scope' or 'graph'. 'scope' view organize tfprof_cmd: string. Either 'scope', 'graph', 'code'.
ops using their name scopes. 'graph' view organize ops using 'scope' view organize outputs using ops' name scope.
their graph inputs. 'graph' view organize outputs using op's inputs/outputs.
'code' view organize outputs using Python call stack.
tfprof_options: See 'tfprof help' for details. tfprof_options: See 'tfprof help' for details.
Returns: Returns:
TFProfNode proto. Side effect: a formatted output to stdout. If tfprof_cmd is 'scope' or 'graph', returns TFGraphNodeProto proto.
If tfprof_cmd is 'code', returns TFCodeNodeProto proto.
Side effect: a formatted output to stdout.
""" """
# pylint: disable=protected-access # pylint: disable=protected-access
op_log = tfprof_logger._merge_default_with_oplog(graph, op_log, run_meta) op_log = tfprof_logger._merge_default_with_oplog(
graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
# pylint: enable=protected-access # pylint: enable=protected-access
opts = tfprof_options_pb2.OptionsProto() opts = tfprof_options_pb2.OptionsProto()
opts.max_depth = tfprof_options['max_depth'] opts.max_depth = tfprof_options['max_depth']
@ -178,11 +182,24 @@ def print_model_analysis(graph,
opts.dump_to_file = tfprof_options['dump_to_file'] opts.dump_to_file = tfprof_options['dump_to_file']
run_meta_str = run_meta.SerializeToString() if run_meta else b'' run_meta_str = run_meta.SerializeToString() if run_meta else b''
op_log_str = op_log.SerializeToString() if op_log else b''
tfprof_node = tfprof_output_pb2.TFProfNode() if tfprof_cmd == 'code':
tfprof_node.ParseFromString( tfprof_node = tfprof_output_pb2.TFCodeNodeProto()
print_mdl.PrintModelAnalysis( tfprof_node.ParseFromString(
graph.as_graph_def().SerializeToString(), run_meta_str, op_log_str, print_mdl.PrintModelAnalysis(
tfprof_cmd.encode('utf-8'), opts.SerializeToString())) graph.as_graph_def().SerializeToString(),
run_meta_str,
op_log.SerializeToString(),
tfprof_cmd.encode('utf-8'),
opts.SerializeToString()))
else:
tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
tfprof_node.ParseFromString(
print_mdl.PrintModelAnalysis(
graph.as_graph_def().SerializeToString(),
run_meta_str,
op_log.SerializeToString(),
tfprof_cmd.encode('utf-8'),
opts.SerializeToString()))
return tfprof_node return tfprof_node

View File

@ -18,49 +18,27 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.platform import test from tensorflow.python.platform import test
# XXX: this depends on pywrap_tensorflow and must come later # XXX: this depends on pywrap_tensorflow and must come later
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer_testlib as lib
class PrintModelAnalysisTest(test.TestCase): class PrintModelAnalysisTest(test.TestCase):
def _BuildSmallModel(self):
image = array_ops.zeros([2, 6, 6, 3])
_ = variable_scope.get_variable(
'ScalarW', [],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
kernel = variable_scope.get_variable(
'DW', [3, 3, 3, 6],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
kernel = variable_scope.get_variable(
'DW2', [2, 2, 6, 12],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
return x
def testDumpToFile(self): def testDumpToFile(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump') opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
with session.Session() as sess, ops.device('/cpu:0'): with session.Session() as sess, ops.device('/cpu:0'):
_ = self._BuildSmallModel() _ = lib.BuildSmallModel()
model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts) model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts)
with gfile.Open(opts['dump_to_file'], 'r') as f: with gfile.Open(opts['dump_to_file'], 'r') as f:
@ -71,6 +49,7 @@ class PrintModelAnalysisTest(test.TestCase):
f.read()) f.read())
def testSelectEverything(self): def testSelectEverything(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump') opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
opts['account_type_regexes'] = ['.*'] opts['account_type_regexes'] = ['.*']
@ -78,8 +57,10 @@ class PrintModelAnalysisTest(test.TestCase):
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types' 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types'
] ]
with session.Session() as sess, ops.device('/cpu:0'): config = config_pb2.ConfigProto(
x = self._BuildSmallModel() graph_options=config_pb2.GraphOptions(build_cost_model=1))
with session.Session(config=config) as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata() run_meta = config_pb2.RunMetadata()
@ -98,6 +79,118 @@ class PrintModelAnalysisTest(test.TestCase):
f.read()) f.read())
# pylint: enable=line-too-long # pylint: enable=line-too-long
def testSimpleCodeView(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
opts['account_type_regexes'] = ['.*']
opts['show_name_regexes'] = ['.*model_analyzer_testlib.*']
opts['account_displayed_op_only'] = False
# TODO(xpan): Test 'micros'. Since the execution time changes each run,
# it's a bit difficult to test it now.
opts['select'] = [
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
]
config = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(build_cost_model=1))
with session.Session(config=config) as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
_ = sess.run(x,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta)
model_analyzer.print_model_analysis(
sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
with gfile.Open(opts['dump_to_file'], 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
'_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n',
f.read())
# pylint: enable=line-too-long
def testComplexCodeView(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
opts['account_type_regexes'] = ['.*']
opts['show_name_regexes'] = ['.*model_analyzer_testlib.py.*']
opts['account_displayed_op_only'] = False
opts['select'] = ['params', 'float_ops']
config = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(build_cost_model=1))
with session.Session(config=config) as sess, ops.device('/cpu:0'):
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
_ = sess.run(x,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta)
tfprof_node = model_analyzer.print_model_analysis(
sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
# pylint: disable=line-too-long
with gfile.Open(opts['dump_to_file'], 'r') as f:
self.assertEqual(
'_TFProfRoot (0/2.84k params, 0/54.08k flops)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops)\n',
f.read())
self.assertLess(0, tfprof_node.total_exec_micros)
self.assertEqual(2844, tfprof_node.total_parameters)
self.assertEqual(54080, tfprof_node.total_float_ops)
self.assertEqual(5, len(tfprof_node.children))
self.assertEqual('_TFProfRoot', tfprof_node.name)
self.assertEqual('model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_...',
tfprof_node.children[0].name)
self.assertEqual('model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c...',
tfprof_node.children[1].name)
self.assertEqual('model_analyzer_testlib.py:62:BuildFullModel:target = array_op...',
tfprof_node.children[2].name)
self.assertEqual('model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_...',
tfprof_node.children[3].name)
self.assertEqual('model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min...',
tfprof_node.children[4].name)
# pylint: enable=line-too-long
def testCodeViewLeafGraphNode(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
opts['account_type_regexes'] = ['.*']
opts['account_displayed_op_only'] = False
opts['select'] = [
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device'
]
config = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(build_cost_model=1))
with session.Session(config=config) as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
_ = sess.run(x,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE),
run_metadata=run_meta)
tfprof_node = model_analyzer.print_model_analysis(
sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
leaf = tfprof_node
while leaf.children:
self.assertEqual(0, len(leaf.graph_nodes))
leaf = leaf.children[0]
self.assertEqual(1, len(leaf.graph_nodes))
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -0,0 +1,67 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A test lib that defines some models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import gradient_descent
def BuildSmallModel():
"""Build a small forward conv model."""
image = array_ops.zeros([2, 6, 6, 3])
_ = variable_scope.get_variable(
'ScalarW', [],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
kernel = variable_scope.get_variable(
'DW', [3, 3, 3, 6],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
kernel = variable_scope.get_variable(
'DW2', [2, 2, 6, 12],
dtypes.float32,
initializer=init_ops.random_normal_initializer(stddev=0.001))
x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
return x
def BuildFullModel():
"""Build the full model with conv,rnn,opt."""
seq = []
for i in range(4):
with variable_scope.variable_scope('inp_%d' % i):
seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))
cell = BasicRNNCell(16, 48)
out = rnn.dynamic_rnn(
cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]
target = array_ops.ones_like(out)
loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
return sgd_op.minimize(loss)

View File

@ -96,12 +96,13 @@ class PrintModelAnalysisTest(test.TestCase):
with session.Session() as sess, ops.device('/cpu:0'): with session.Session() as sess, ops.device('/cpu:0'):
_ = self._BuildSmallModel() _ = self._BuildSmallModel()
tfprof_pb = tfprof_output_pb2.TFProfNode() tfprof_pb = tfprof_output_pb2.TFGraphNodeProto()
tfprof_pb.ParseFromString( tfprof_pb.ParseFromString(
print_mdl.PrintModelAnalysis(sess.graph.as_graph_def( print_mdl.PrintModelAnalysis(
).SerializeToString(), b'', b'', b'scope', opts.SerializeToString())) sess.graph.as_graph_def().SerializeToString(),
b'', b'', b'scope', opts.SerializeToString()))
expected_pb = tfprof_output_pb2.TFProfNode() expected_pb = tfprof_output_pb2.TFGraphNodeProto()
text_format.Merge(r"""name: "_TFProfRoot" text_format.Merge(r"""name: "_TFProfRoot"
exec_micros: 0 exec_micros: 0
requested_bytes: 0 requested_bytes: 0

View File

@ -62,12 +62,13 @@ def _fill_missing_graph_shape(graph, run_meta):
return graph return graph
def _get_logged_ops(graph, run_meta=None): def _get_logged_ops(graph, run_meta=None, add_trace=False):
"""Extract trainable model parameters and FLOPs for ops from a Graph. """Extract trainable model parameters and FLOPs for ops from a Graph.
Args: Args:
graph: tf.Graph. graph: tf.Graph.
run_meta: RunMetadata proto used to complete shape information. run_meta: RunMetadata proto used to complete shape information.
add_trace: Whether to add op trace information.
Returns: Returns:
logged_ops: dict mapping from op_name to OpLogEntry. logged_ops: dict mapping from op_name to OpLogEntry.
""" """
@ -76,21 +77,32 @@ def _get_logged_ops(graph, run_meta=None):
op_missing_shape = 0 op_missing_shape = 0
logged_ops = {} logged_ops = {}
graph_def = graph.as_graph_def() for op in graph.get_operations():
for node in graph_def.node:
try: try:
stats = ops.get_stats_for_node_def(graph, node, REGISTERED_FLOP_STATS) stats = ops.get_stats_for_node_def(
graph, op.node_def, REGISTERED_FLOP_STATS)
except ValueError: except ValueError:
# Catch Exception When shape is incomplete. Skip it. # Catch Exception When shape is incomplete. Skip it.
op_missing_shape += 1 op_missing_shape += 1
stats = None stats = None
if not stats or not stats.value: entry = tfprof_log_pb2.OpLogEntry()
continue entry.name = op.name
if node.name not in logged_ops: add_entry = False
entry = tfprof_log_pb2.OpLogEntry() if stats and stats.value:
entry.name = node.name
entry.float_ops = int(stats.value) entry.float_ops = int(stats.value)
add_entry = True
if add_trace:
for tb in op.traceback:
trace = entry.code_def.traces.add()
trace.file = tb[0] if tb[0] else 'none'
trace.lineno = tb[1] if tb[1] else -1
trace.function = tb[2] if tb[2] else 'none'
trace.line = tb[3] if tb[3] else 'none'
add_entry = True
if add_entry:
logged_ops[entry.name] = entry logged_ops[entry.name] = entry
for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES): for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
@ -108,18 +120,21 @@ def _get_logged_ops(graph, run_meta=None):
return logged_ops return logged_ops
def _merge_default_with_oplog(graph, op_log=None, run_meta=None): def _merge_default_with_oplog(graph, op_log=None,
run_meta=None,
add_trace=False):
"""Merge the tfprof default extra info with caller's op_log. """Merge the tfprof default extra info with caller's op_log.
Args: Args:
graph: tf.Graph. graph: tf.Graph.
op_log: OpLog proto. op_log: OpLog proto.
run_meta: RunMetadata proto used to complete shape information. run_meta: RunMetadata proto used to complete shape information.
add_trace: Whether to add op trace information.
Returns: Returns:
tmp_op_log: Merged OpLog proto. tmp_op_log: Merged OpLog proto.
""" """
tmp_op_log = tfprof_log_pb2.OpLog() tmp_op_log = tfprof_log_pb2.OpLog()
logged_ops = _get_logged_ops(graph, run_meta) logged_ops = _get_logged_ops(graph, run_meta, add_trace=add_trace)
if not op_log: if not op_log:
tmp_op_log.log_entries.extend(logged_ops.values()) tmp_op_log.log_entries.extend(logged_ops.values())
else: else:
@ -131,13 +146,16 @@ def _merge_default_with_oplog(graph, op_log=None, run_meta=None):
all_ops[op_name].types.extend(entry.types) all_ops[op_name].types.extend(entry.types)
if entry.float_ops > 0 and all_ops[op_name].float_ops == 0: if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
all_ops[op_name].float_ops = entry.float_ops all_ops[op_name].float_ops = entry.float_ops
if entry.code_def.traces and not all_ops[op_name].code_def.traces:
all_ops[op_name].code_def.MergeFrom(entry.code_def)
else: else:
all_ops[op_name] = entry all_ops[op_name] = entry
tmp_op_log.log_entries.extend(all_ops.values()) tmp_op_log.log_entries.extend(all_ops.values())
return tmp_op_log return tmp_op_log
def write_op_log(graph, log_dir, op_log=None, run_meta=None): def write_op_log(graph, log_dir, op_log=None, run_meta=None,
add_trace=False):
"""Log provided 'op_log', and add additional model information below. """Log provided 'op_log', and add additional model information below.
The API also assigns ops in tf.trainable_variables() an op type called The API also assigns ops in tf.trainable_variables() an op type called
@ -154,8 +172,9 @@ def write_op_log(graph, log_dir, op_log=None, run_meta=None):
one is created. one is created.
run_meta: (Optional) RunMetadata proto that helps flops computation using run_meta: (Optional) RunMetadata proto that helps flops computation using
run time shape information. run time shape information.
add_trace: Whether to add op trace information. Used to support "code" view.
""" """
op_log = _merge_default_with_oplog(graph, op_log, run_meta) op_log = _merge_default_with_oplog(graph, op_log, run_meta, add_trace)
with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log: with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
log.write(op_log.SerializeToString()) log.write(op_log.SerializeToString())

View File

@ -10,12 +10,17 @@ Consultants: Jon Shlens, Pete Warden
1. Measure model parameters, float operations, tensor shapes. 1. Measure model parameters, float operations, tensor shapes.
2. Measure op execution times, requested memory size and device placement. 2. Measure op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values. 3. Inspect checkpoint tensors' shapes and their values.
4. Explore model based on name scope or graph structure. 4. 3 ways to view and explore TensorFlow model profiles
* Organize by Python code call stack.
* Organize by TensorFlow operation name scope hierarchies.
* Organize by TensorFlow operation inputs/outputs graph.
5. Selectively grouping/filtering/accounting/ordering ops. 5. Selectively grouping/filtering/accounting/ordering ops.
[Python API Tutorials](#python-api-tutorials): It can be called directly from [Python API Tutorials](#python-api-tutorials): It can be called directly from
Python codes. Results are either printed Python codes. Results are either printed
to stdout or dumped to file. tensorflow.tfprof.TFProfNode proto is returned from to stdout or dumped to file. tensorflow.tfprof.TFGraphNodeProto proto is returned from
the API to allow users to perform further analysis. the API to allow users to perform further analysis.
[CLI Tutorials](#cli-tutorials): [CLI Tutorials](#cli-tutorials):
@ -33,13 +38,23 @@ tfprof is part of TensorFlow core. Simply ```import tensorflow as tf```.
### Examine the shapes and sizes of all trainable Variables. ### Examine the shapes and sizes of all trainable Variables.
```python ```python
# Print trainable variable parameter statistics to stdout. # Print trainable variable parameter statistics to stdout.
# By default, statistics are associated with each graph node.
param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis( param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
tf.get_default_graph(), tf.get_default_graph(),
tfprof_options=tf.contrib.tfprof.model_analyzer. tfprof_options=tf.contrib.tfprof.model_analyzer.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS) TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
# param_stats is tensorflow.tfprof.TFProfNode proto. It organize the statistics
# of each graph node in tree scructure. Let's print the root below. # Set tfprof_cmd='code' to associate statistics with Python codes.
opts = tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
opts['show_name_regexes'] = ['.*my_code1.py.*', '.*my_code2.py.*']
param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
tf.get_default_graph(),
tfprof_cmd='code'
tfprof_options=opts)
# param_stats is tensorflow.tfprof.TFGraphNodeProto proto.
# Let's print the root below.
sys.stdout.write('total_params: %d\n' % param_stats.total_parameters) sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
``` ```
@ -84,8 +99,20 @@ Finally, you may run `print_model_analysis` to explore the timing and memory
demands of the model. demands of the model.
``` python ``` python
# See model_analyzer_test.py for more examples.
#
# Print to stdout an analysis of the memory usage and the timing information # Print to stdout an analysis of the memory usage and the timing information
# from running the graph broken down by operations. # broken down by python codes.
opts = tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
opts['show_name_regexes'] = ['.*my_code.py.*']
tf.contrib.tfprof.model_analyzer.print_model_analysis(
tf.get_default_graph(),
run_meta=run_metadata,
tfprof_cmd='code',
tfprof_options=opts)
# Print to stdout an analysis of the memory usage and the timing information
# broken down by operations.
tf.contrib.tfprof.model_analyzer.print_model_analysis( tf.contrib.tfprof.model_analyzer.print_model_analysis(
tf.get_default_graph(), tf.get_default_graph(),
run_meta=run_metadata, run_meta=run_metadata,
@ -138,9 +165,9 @@ bazel-bin/tensorflow/tools/tfprof/tfprof \
--run_meta_path=run_meta \ --run_meta_path=run_meta \
--checkpoint_path=model.ckpt --checkpoint_path=model.ckpt
# #
# tfprof_log is used to define customized op types and float ops. # tfprof_log is used to define customized op types, float ops and code traces.
# Use tfprof_logger.write_op_log() to create tfprof_log. # Use tfprof_logger.write_op_log() to create tfprof_log.
# See 11) in Examples section on generating tfprof_log file. # See 12) in Examples section on generating tfprof_log file.
bazel-bin/tensorflow/tools/tfprof/tfprof \ bazel-bin/tensorflow/tools/tfprof/tfprof \
--graph_path=graph.pbtxt \ --graph_path=graph.pbtxt \
--run_meta_path=run_meta \ --run_meta_path=run_meta \
@ -174,7 +201,28 @@ tfprof>
-dump_to_file -dump_to_file
``` ```
3) I want to see the `BatchNorm`'s gamma value in checkpoint. 3) I want to see which line of my python codes costs most time!
```shell
# Requires --graph_path --op_log_path
tfprof> code -max_depth 1000 -show_name_regexes .*model_analyzer.*py.* -select micros -account_type_regexes .* -order_by micros
_TFProfRoot (0us/22.44ms)
model_analyzer_test.py:149:run_filename_as_m...:none (0us/22.44ms)
model_analyzer_test.py:33:_run_code_in_main:none (0us/22.44ms)
model_analyzer_test.py:208:<module>:test.main() (0us/22.44ms)
model_analyzer_test.py:132:testComplexCodeView:x = lib.BuildFull... (0us/22.44ms)
model_analyzer_testlib.py:63:BuildFullModel:return sgd_op.min... (0us/21.83ms)
model_analyzer_testlib.py:58:BuildFullModel:cell, array_ops.c... (0us/333us)
model_analyzer_testlib.py:54:BuildFullModel:seq.append(array_... (0us/254us)
model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0us/134us)
model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0us/40us)
...
model_analyzer_testlib.py:61:BuildFullModel:loss = nn_ops.l2_... (0us/28us)
model_analyzer_testlib.py:60:BuildFullModel:target = array_op... (0us/0us)
model_analyzer_test.py:134:testComplexCodeView:sess.run(variable... (0us/0us)
```
4) I want to see the `BatchNorm`'s gamma value in checkpoint.
```shell ```shell
# Requires --graph_path, --checkpoint_path. # Requires --graph_path, --checkpoint_path.
@ -186,7 +234,7 @@ _TFProfRoot ()
[1.57 1.83 1.30 1.25 1.59 1.14 1.26 0.82 1.19 1.10 1.48 1.01 0.82 1.23 1.21 1.14 ], [1.57 1.83 1.30 1.25 1.59 1.14 1.26 0.82 1.19 1.10 1.48 1.01 0.82 1.23 1.21 1.14 ],
``` ```
4) I want to see my checkpoint tensors shape and number of parameters. 5) I want to see my checkpoint tensors shape and number of parameters.
```shell ```shell
# Requires --graph_path, --checkpoint_path. # Requires --graph_path, --checkpoint_path.
@ -205,7 +253,7 @@ _TFProfRoot (--/930.58k params)
unit_last/final_bn/moving_variance (64, 64/64 params) unit_last/final_bn/moving_variance (64, 64/64 params)
``` ```
5) I defined an op named cost to calculate the loss. I want to know what ops 6) I defined an op named cost to calculate the loss. I want to know what ops
it depends on take a long time to run. Hint: Use the graph command to explore it depends on take a long time to run. Hint: Use the graph command to explore
graph dependencies. graph dependencies.
@ -221,7 +269,7 @@ _TFProfRoot (0us/3.61sec)
unit_3_3/sub2/conv2/Conv2D (10.26ms/3.60sec) unit_3_3/sub2/conv2/Conv2D (10.26ms/3.60sec)
``` ```
6) I want to know the expensive operations during the back propagation. 7) I want to know the expensive operations during the back propagation.
Hint: tensorflow prepend gradient to your defined name scopes. Use the scope Hint: tensorflow prepend gradient to your defined name scopes. Use the scope
command to explore based on name scope hierarchies. command to explore based on name scope hierarchies.
@ -238,7 +286,7 @@ _TFProfRoot (0us/2.29sec)
... ...
``` ```
7) Show the number of float operations in the model. 8) Show the number of float operations in the model.
Note: float operations calculation depends on Note: float operations calculation depends on
1) op.RegisterStatistics. If an op doesnt 1) op.RegisterStatistics. If an op doesnt
have RegisterStatistics defined, its float operations cannot be counted. have RegisterStatistics defined, its float operations cannot be counted.
@ -263,7 +311,7 @@ _TFProfRoot (0/17.63b flops)
... ...
``` ```
8) Show the number of parameters of all `tf.trainable_variables()` in the model. 9) Show the number of parameters of all `tf.trainable_variables()` in the model.
```shell ```shell
# Requires --graph_path --op_log_path. # Requires --graph_path --op_log_path.
@ -283,7 +331,7 @@ generated by write_op_log() Python API. write_op_log() help users create some
common op types implicitly. Users can define their own op types and log it common op types implicitly. Users can define their own op types and log it
through the write_op_log() API. through the write_op_log() API.
9) What if Im lazy and dont want to define op type? I have given my ops 109) What if Im lazy and dont want to define op type? I have given my ops
well-defined names in my models code. And want to use names to select a group well-defined names in my models code. And want to use names to select a group
of ops. Lets try it! of ops. Lets try it!
@ -301,7 +349,7 @@ in terminal. Otherwise, tfprof accounts all ops matched by
`-account_type_regexes` recursively even if they are hidden due to some `-account_type_regexes` recursively even if they are hidden due to some
options such as -max_depth. options such as -max_depth.
10) TensorFlow has built-in op types. For example, built-in op type `Variable` 11) TensorFlow has built-in op types. For example, built-in op type `Variable`
seems to include `Variable's` created by your model. However, be careful when seems to include `Variable's` created by your model. However, be careful when
depending on it because TensorFlow creates extra `Variable` ops implicitly and depending on it because TensorFlow creates extra `Variable` ops implicitly and
the implicitly created ops can have the same prefix as the `Variable's` you the implicitly created ops can have the same prefix as the `Variable's` you
@ -327,7 +375,7 @@ _TFProfRoot (--/930.58k params)
``` ```
11) A example of defining extra op type for ops using `OpLog` 12) A example of defining extra op type for ops using `OpLog`
First, in Python code, create an `OpLog` proto and add op type First, in Python code, create an `OpLog` proto and add op type
information to it: information to it:

View File

@ -15,6 +15,7 @@ cc_library(
srcs = ["tfprof_stats.cc"], srcs = ["tfprof_stats.cc"],
hdrs = ["tfprof_stats.h"], hdrs = ["tfprof_stats.h"],
deps = [ deps = [
":tfprof_code",
":tfprof_graph", ":tfprof_graph",
":tfprof_node", ":tfprof_node",
":tfprof_options", ":tfprof_options",
@ -61,6 +62,27 @@ cc_library(
], ],
) )
cc_library(
name = "tfprof_code",
srcs = ["tfprof_code.cc"],
hdrs = ["tfprof_code.h"],
deps = [
":tfprof_constants",
":tfprof_node",
":tfprof_options",
":tfprof_show_code",
":tfprof_tensor",
":tfprof_utils",
"//tensorflow/c:c_api",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:regexp_internal",
"//tensorflow/tools/tfprof:protos_all_cc",
],
)
cc_library( cc_library(
name = "tfprof_graph", name = "tfprof_graph",
srcs = ["tfprof_graph.cc"], srcs = ["tfprof_graph.cc"],
@ -98,6 +120,26 @@ cc_library(
], ],
) )
cc_library(
name = "tfprof_show_code",
srcs = ["tfprof_show_code.cc"],
hdrs = ["tfprof_show_code.h"],
deps = [
":tfprof_constants",
":tfprof_node",
":tfprof_options",
":tfprof_scope",
":tfprof_show",
":tfprof_tensor",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:regexp_internal",
"//tensorflow/tools/tfprof:protos_all_cc",
],
)
tf_cc_test( tf_cc_test(
name = "tfprof_show_test", name = "tfprof_show_test",
srcs = ["tfprof_show_test.cc"], srcs = ["tfprof_show_test.cc"],

View File

@ -40,13 +40,13 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
graph_ptr->ParseFromString(*graph); graph_ptr->ParseFromString(*graph);
std::unique_ptr<RunMetadata> run_meta_ptr; std::unique_ptr<RunMetadata> run_meta_ptr;
if (run_meta) { if (run_meta && !run_meta->empty()) {
run_meta_ptr.reset(new RunMetadata()); run_meta_ptr.reset(new RunMetadata());
run_meta_ptr->ParseFromString(*run_meta); run_meta_ptr->ParseFromString(*run_meta);
} }
std::unique_ptr<OpLog> op_log_ptr; std::unique_ptr<OpLog> op_log_ptr;
if (op_log) { if (op_log && !op_log->empty()) {
op_log_ptr.reset(new OpLog()); op_log_ptr.reset(new OpLog());
op_log_ptr->ParseFromString(*op_log); op_log_ptr->ParseFromString(*op_log);
} }
@ -58,16 +58,27 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
Options opts = Options::FromProtoStr(*options); Options opts = Options::FromProtoStr(*options);
// TODO(xpan): We should have dump_to_file/print_stdout/etc to control
// side-effects independently instead of one controlling the other.
if (opts.dump_to_file.empty()) { if (opts.dump_to_file.empty()) {
printf("\n=========================Options=============================\n"); printf("\n=========================Options=============================\n");
printf("%s", opts.ToString().c_str()); printf("%s", opts.ToString().c_str());
printf("\n==================Model Analysis Report======================\n"); printf("\n==================Model Analysis Report======================\n");
TFProfNode root(tf_stats.PrintGraph(*command, opts)); string ret = "";
if (*command == kCmds[2]) {
ret = tf_stats.PrintCode(opts).SerializeAsString();
} else {
ret = tf_stats.PrintGraph(*command, opts).SerializeAsString();
}
printf("\n======================End of Report==========================\n"); printf("\n======================End of Report==========================\n");
fflush(stdout); fflush(stdout);
return root.SerializeAsString(); return ret;
}
if (*command == kCmds[2]) {
return tf_stats.PrintCode(opts).SerializeAsString();
} else {
return tf_stats.PrintGraph(*command, opts).SerializeAsString();
} }
return tf_stats.PrintGraph(*command, opts).SerializeAsString();
} }
} // namespace tfprof } // namespace tfprof
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,215 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/tfprof/internal/tfprof_code.h"
#include <stdio.h>
#include <utility>
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/tools/tfprof/internal/tfprof_tensor.h"
namespace tensorflow {
namespace tfprof {
namespace {
// Convert to Trace proto into a short readable string.
string GetTraceString(const CodeDef::Trace& trace) {
string ntrace = "";
if (trace.file().find_last_of('/') != trace.file().npos) {
ntrace += trace.file().substr(trace.file().find_last_of('/') + 1);
} else {
ntrace += trace.file();
}
ntrace += strings::StrCat(":", trace.lineno());
if (trace.function().length() < 20) {
ntrace += ":" + trace.function();
} else {
ntrace += ":" + trace.function().substr(0, 17) + "...";
}
if (trace.line().length() < 20) {
ntrace += ":" + trace.line();
} else {
ntrace += ":" + trace.line().substr(0, 17) + "...";
}
return ntrace;
}
} // namespace
void TFCode::AddNode(TFGraphNode* node) {
if (!node->code()) {
return;
}
TFCodeNode* pre_trace_node = nullptr;
for (int i = 0; i < node->code()->traces_size(); ++i) {
// Unlike op name, which is globally unique, trace name is only unique
// w.r.t. it's parent.
const string& trace = GetTraceString(node->code()->traces(i));
if (i == 0) {
if (!trace_root_) {
trace_root_.reset(new TFCodeNode(trace));
}
CHECK(trace_root_->name() == trace) << "Different trace root";
pre_trace_node = trace_root_.get();
continue;
}
pre_trace_node->AddChildren(trace);
TFCodeNode* trace_node = pre_trace_node->children()[trace].get();
if (i == node->code()->traces_size() - 1) {
trace_node->AddGraphNode(node);
}
pre_trace_node = trace_node;
}
}
void TFCode::Build() {
if (!trace_root_) {
return;
}
code_root_ = BuildCodeNodes(trace_root_.get());
}
CodeNode* TFCode::BuildCodeNodes(TFCodeNode* root) {
auto code_root = std::unique_ptr<CodeNode>(new CodeNode(root));
CodeNode* code_root_ptr = code_root.get();
code_nodes_.insert(std::move(code_root));
for (auto it = root->children().cbegin(); it != root->children().cend();
++it) {
code_root_ptr->children.push_back(BuildCodeNodes(it->second.get()));
}
return code_root_ptr;
}
const ShowCodeNode* TFCode::ShowInternal(const Options& opts) {
// Search from roots recursively to find start node, if start_name_regexes
// is specified.
tfprof_trace_root_.reset(new TFCodeNode(kTFProfRoot));
tfprof_code_root_.reset(new CodeNode(tfprof_trace_root_.get()));
if (!code_root_) {
return tfprof_code_root_.get();
}
std::vector<CodeNode*> roots = {code_root_};
if (opts.start_name_regexes.size() != 1 ||
opts.start_name_regexes[0] != ".*") {
roots = SearchRoot(roots, opts.start_name_regexes);
}
tfprof_code_root_->children.assign(roots.begin(), roots.end());
Account({tfprof_code_root_.get()}, opts);
return PrintScope({tfprof_code_root_.get()}, opts, 1, 0)[0];
}
std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
const std::vector<string>& regexes) {
std::vector<CodeNode*> res;
if (roots.empty()) {
return res;
}
for (CodeNode* root : roots) {
bool match_start_node = false;
for (const string& regex : regexes) {
if (RE2::FullMatch(root->name(), regex)) {
res.push_back(root);
match_start_node = true;
break;
}
}
if (match_start_node) {
// Found a start node at this branch, no need to continue.
continue;
}
std::vector<CodeNode*> nroots = SearchRoot(root->children, regexes);
res.insert(res.end(), nroots.begin(), nroots.end());
}
return res;
}
std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
const Options& opts, int depth,
int last_ident) {
std::vector<CodeNode*> show_nodes;
for (CodeNode* node : roots) {
int nlast_ident = last_ident;
bool show = ShouldShow(node, opts, depth);
if (show) {
node->formatted_str.clear();
if (opts.account_displayed_op_only) {
node->ResetTotalStats();
node->AddSelfToTotalStats();
}
nlast_ident += 2;
}
std::vector<CodeNode*> show_cnodes;
if (!ShouldTrim(node, opts.trim_name_regexes)) {
show_cnodes = PrintScope(node->children, opts, depth + 1, nlast_ident);
}
if (show) {
show_cnodes = SortNodes(show_cnodes, opts);
string children_str;
for (CodeNode* sc : show_cnodes) {
children_str += sc->formatted_str;
node->mutable_proto()->add_children()->MergeFrom(sc->proto());
if (opts.account_displayed_op_only) {
node->AggregateTotalStats(sc);
}
}
node->formatted_str =
strings::Printf("%s%s\n", string(last_ident, ' ').c_str(),
node->Format(opts).c_str());
if (opts.select.find(kShown[5]) != opts.select.end()) {
fprintf(stderr, "code view has no tensor value to show\n");
}
node->formatted_str += children_str;
show_nodes.push_back(node);
} else {
show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
show_cnodes.end());
}
}
return show_nodes;
}
void TFCode::Account(const std::vector<CodeNode*>& roots, const Options& opts) {
if (roots.empty()) return;
for (CodeNode* node : roots) {
node->ResetTotalStats();
Account(node->children, opts);
node->account = ShouldAccount(node, opts);
if (node->account) {
node->AddSelfToTotalStats();
}
for (CodeNode* c : node->children) {
node->AggregateTotalStats(c);
}
}
}
} // namespace tfprof
} // namespace tensorflow

View File

@ -0,0 +1,88 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Build a tree structure based on the TensorFlow model's python code stacks.
// Stats are aggregated from descendants from ancestors.
#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_CODE_H_
#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_CODE_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_show_code.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
namespace tensorflow {
namespace tfprof {
class CodeNode : public ShowCodeNode {
public:
explicit CodeNode(const TFCodeNode* node) : ShowCodeNode(node) {}
~CodeNode() override {}
void AggregateTotalStats(CodeNode* node) {
ShowCodeNode::AggregateTotalStats(node);
}
void AddSelfToTotalStats() { ShowCodeNode::AddSelfToTotalStats(); }
void ResetTotalStats() { ShowCodeNode::ResetTotalStats(); }
std::vector<CodeNode*> children;
};
class TFCode : public TFShowCode {
public:
explicit TFCode() : code_root_(nullptr), trace_root_(nullptr) {}
~TFCode() override {}
void AddNode(TFGraphNode* node) override;
void Build() override;
private:
CodeNode* BuildCodeNodes(TFCodeNode* root);
const ShowCodeNode* ShowInternal(const Options& opts) override;
std::vector<CodeNode*> SearchRoot(std::vector<CodeNode*> roots,
const std::vector<string>& regexes);
std::vector<CodeNode*> PrintScope(const std::vector<CodeNode*> roots,
const Options& opts, int depth,
int last_ident);
void Account(const std::vector<CodeNode*>& roots, const Options& opts);
CodeNode* code_root_;
std::unique_ptr<TFCodeNode> trace_root_;
std::unique_ptr<TFCodeNode> tfprof_trace_root_;
std::unique_ptr<CodeNode> tfprof_code_root_;
std::set<std::unique_ptr<CodeNode>> code_nodes_;
};
} // namespace tfprof
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_CODE_H_

View File

@ -31,14 +31,14 @@ GraphNode* TFGraph::CreateParentNode(const string& name) {
node_defs_.back()->set_name(name); node_defs_.back()->set_name(name);
node_defs_.back()->set_op(kTFGraphParent); node_defs_.back()->set_op(kTFGraphParent);
parent_nodes_[name] = parent_nodes_[name] =
std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get())); std::unique_ptr<TFGraphNode>(new TFGraphNode(node_defs_.back().get()));
nodes_map_[name] = nodes_map_[name] =
std::unique_ptr<GraphNode>(new GraphNode(parent_nodes_[name].get())); std::unique_ptr<GraphNode>(new GraphNode(parent_nodes_[name].get()));
return nodes_map_[name].get(); return nodes_map_[name].get();
} }
void TFGraph::AddNode(TFNode* node) { void TFGraph::AddNode(TFGraphNode* node) {
string name = node->node_def()->name(); string name = node->name();
nodes_map_[name] = std::unique_ptr<GraphNode>(new GraphNode(node)); nodes_map_[name] = std::unique_ptr<GraphNode>(new GraphNode(node));
} }
@ -49,7 +49,7 @@ void TFGraph::Build() {
// Filter out the root nodes (node not input of any other node). // Filter out the root nodes (node not input of any other node).
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
GraphNode* node = it->second.get(); GraphNode* node = it->second.get();
const std::map<string, TFNode*>& inputs = node->node->inputs(); const std::map<string, TFGraphNode*>& inputs = node->node->inputs();
for (auto inputs_it = inputs.cbegin(); inputs_it != inputs.cend(); for (auto inputs_it = inputs.cbegin(); inputs_it != inputs.cend();
inputs_it++) { inputs_it++) {
nonroots.insert(inputs_it->first); nonroots.insert(inputs_it->first);

View File

@ -39,7 +39,7 @@ namespace tensorflow {
namespace tfprof { namespace tfprof {
class GraphNode : public ShowNode { class GraphNode : public ShowNode {
public: public:
explicit GraphNode(TFNode* node) : ShowNode(node) { explicit GraphNode(TFGraphNode* node) : ShowNode(node) {
mutable_proto()->set_inputs(node->inputs().size()); mutable_proto()->set_inputs(node->inputs().size());
mutable_proto()->set_total_inputs(0); mutable_proto()->set_total_inputs(0);
} }
@ -72,7 +72,7 @@ class TFGraph : public TFShow {
: TFShow(ckpt_reader) {} : TFShow(ckpt_reader) {}
~TFGraph() override {} ~TFGraph() override {}
void AddNode(TFNode* node) override; void AddNode(TFGraphNode* node) override;
void Build() override; void Build() override;
@ -99,14 +99,14 @@ class TFGraph : public TFShow {
std::vector<GraphNode*> GenerateGraphDot( std::vector<GraphNode*> GenerateGraphDot(
GraphNode* root, GraphNode* last_shown, const Options& opts, int depth, GraphNode* root, GraphNode* last_shown, const Options& opts, int depth,
int hidden, std::set<string>* declared_nodes, int hidden, std::set<string>* declared_nodes,
std::set<string>* declared_edges, TFProfNode* parent); std::set<string>* declared_edges, TFGraphNodeProto* parent);
void Account(const std::vector<GraphNode*>& roots, const Options& opts, void Account(const std::vector<GraphNode*>& roots, const Options& opts,
std::map<string, int64>* visits); std::map<string, int64>* visits);
std::vector<GraphNode*> roots_; std::vector<GraphNode*> roots_;
std::vector<std::unique_ptr<NodeDef>> node_defs_; std::vector<std::unique_ptr<NodeDef>> node_defs_;
std::map<string, std::unique_ptr<TFNode>> parent_nodes_; std::map<string, std::unique_ptr<TFGraphNode>> parent_nodes_;
std::map<string, std::unique_ptr<GraphNode>> nodes_map_; std::map<string, std::unique_ptr<GraphNode>> nodes_map_;
}; };

View File

@ -20,7 +20,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tfprof { namespace tfprof {
void TFNode::AddStepStat(const string& device, const NodeExecStats* step_stat) { void TFGraphNode::AddStepStat(const string& device,
const NodeExecStats* step_stat) {
if (!device.empty()) { if (!device.empty()) {
// This might override device from GraphDef. // This might override device from GraphDef.
device_ = device; device_ = device;
@ -44,7 +45,7 @@ void TFNode::AddStepStat(const string& device, const NodeExecStats* step_stat) {
} }
} }
void TFNode::AddNodeStat(const CostGraphDef::Node* cost_node) { void TFGraphNode::AddNodeStat(const CostGraphDef::Node* cost_node) {
kernel_compute_micros_ = cost_node->compute_cost(); kernel_compute_micros_ = cost_node->compute_cost();
} }
} // namespace tfprof } // namespace tfprof

View File

@ -30,14 +30,16 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h" #include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
namespace tensorflow { namespace tensorflow {
namespace tfprof { namespace tfprof {
class TFNode { class TFGraphNode {
public: public:
TFNode(const NodeDef* node) TFGraphNode(const NodeDef* node)
: node_(node), : node_(node),
code_(nullptr),
step_stat_(nullptr), step_stat_(nullptr),
op_start_micros_(0), op_start_micros_(0),
op_schedule_micros_(0), op_schedule_micros_(0),
@ -70,9 +72,9 @@ class TFNode {
device_ = node->device(); device_ = node->device();
} }
TFNode() : TFNode(nullptr) {} TFGraphNode() : TFGraphNode(nullptr) {}
void AddInput(TFNode* input) { inputs_[input->node_def()->name()] = input; } void AddInput(TFGraphNode* input) { inputs_[input->name()] = input; }
void AddOpType(const string& op_type) { op_types_.insert(op_type); } void AddOpType(const string& op_type) { op_types_.insert(op_type); }
@ -83,27 +85,32 @@ class TFNode {
void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; } void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; }
void AddCode(const CodeDef* code) { code_ = code; }
const string& name() const { return node_->name(); }
const NodeDef* node_def() { return node_; } const NodeDef* node_def() { return node_; }
const std::map<string, TFNode*>& inputs() { return inputs_; } const std::map<string, TFGraphNode*>& inputs() const { return inputs_; }
int64 op_start_micros() { return op_start_micros_; } int64 op_start_micros() { return op_start_micros_; }
// This is time spent in Op::Compute(), which is GPU kernel schedule time. // This is time spent in Op::Compute(), which is GPU kernel schedule time.
// Currently not used. // Currently not used.
int64 op_schedule_micros() { return op_schedule_micros_; } int64 op_schedule_micros() { return op_schedule_micros_; }
// This is time spent in kernel execution. // This is time spent in kernel execution.
int64 kernel_compute_micros() { return kernel_compute_micros_; } int64 kernel_compute_micros() const { return kernel_compute_micros_; }
int64 all_spent_micros() { return all_spent_micros_; } int64 all_spent_micros() { return all_spent_micros_; }
int64 requested_byptes() { return requested_bytes_; } int64 requested_bytes() const { return requested_bytes_; }
int64 float_ops() { return float_ops_; } int64 float_ops() const { return float_ops_; }
string device() { return device_; } const CodeDef* code() { return code_; }
const std::set<string>& op_types() { return op_types_; } string device() const { return device_; }
const std::set<string>& op_types() const { return op_types_; }
const std::vector<int64>& shape() { return shape_; } const std::vector<int64>& shape() const { return shape_; }
private: private:
void update_shape(const std::vector<int64>& shape) { shape_ = shape; } void update_shape(const std::vector<int64>& shape) { shape_ = shape; }
std::map<string, TFNode*> inputs_; std::map<string, TFGraphNode*> inputs_;
const NodeDef* node_; const NodeDef* node_;
const CodeDef* code_;
const NodeExecStats* step_stat_; const NodeExecStats* step_stat_;
std::vector<int64> shape_; std::vector<int64> shape_;
@ -117,6 +124,71 @@ class TFNode {
int64 float_ops_; int64 float_ops_;
}; };
class TFCodeNode {
public:
TFCodeNode(const string& trace)
: trace_(trace),
kernel_compute_micros_(0),
requested_bytes_(0),
float_ops_(0) {}
void AddGraphNode(const TFGraphNode* node) {
if (nodes_.find(node->name()) != nodes_.end()) {
return;
}
nodes_[node->name()] = node;
kernel_compute_micros_ += node->kernel_compute_micros();
requested_bytes_ += node->requested_bytes();
float_ops_ += node->float_ops();
op_types_.insert(node->op_types().begin(), node->op_types().end());
if (node->shape().size() > 0) {
shapes_.push_back(node->shape());
}
if (!node->device().empty()) {
devices_.insert(node->device());
}
}
const std::map<string, const TFGraphNode*>& graph_nodes() const {
return nodes_;
}
void AddChildren(const string& trace) {
if (children_.find(trace) != children_.end()) {
return;
}
children_[trace].reset(new TFCodeNode(trace));
}
std::map<string, std::unique_ptr<TFCodeNode>>& children() {
return children_;
}
const string& name() const { return trace_; }
int64 kernel_compute_micros() const { return kernel_compute_micros_; }
int64 requested_bytes() const { return requested_bytes_; }
int64 float_ops() const { return float_ops_; }
const std::set<string>& devices() const { return devices_; }
const std::set<string>& op_types() const { return op_types_; }
const std::vector<std::vector<int64>>& shapes() const { return shapes_; }
private:
const string trace_;
std::set<string> op_types_;
int64 kernel_compute_micros_;
int64 requested_bytes_;
int64 float_ops_;
std::set<string> devices_;
std::vector<std::vector<int64>> shapes_;
std::map<string, const TFGraphNode*> nodes_;
std::map<string, std::unique_ptr<TFCodeNode>> children_;
};
} // namespace tfprof } // namespace tfprof
} // namespace tensorflow } // namespace tensorflow

View File

@ -55,7 +55,7 @@ static const char* const kShown[] = {
}; };
static const char* const kCmds[] = { static const char* const kCmds[] = {
"scope", "graph", "set", "help", "scope", "graph", "code", "set", "help",
}; };
struct Options { struct Options {

View File

@ -35,15 +35,15 @@ ScopeNode* TFScope::CreateParentNode(const string& name) {
node_defs_.back()->set_name(name); node_defs_.back()->set_name(name);
node_defs_.back()->set_op(kTFScopeParent); node_defs_.back()->set_op(kTFScopeParent);
parent_nodes_[name] = parent_nodes_[name] =
std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get())); std::unique_ptr<TFGraphNode>(new TFGraphNode(node_defs_.back().get()));
nodes_map_[name] = nodes_map_[name] =
std::unique_ptr<ScopeNode>(new ScopeNode(parent_nodes_[name].get())); std::unique_ptr<ScopeNode>(new ScopeNode(parent_nodes_[name].get()));
return nodes_map_[name].get(); return nodes_map_[name].get();
} }
void TFScope::AddNode(TFNode* node) { void TFScope::AddNode(TFGraphNode* node) {
string name = node->node_def()->name(); string name = node->name();
if (nodes_map_.find(node->node_def()->name()) == nodes_map_.end()) { if (nodes_map_.find(node->name()) == nodes_map_.end()) {
nodes_map_[name] = std::unique_ptr<ScopeNode>(new ScopeNode(node)); nodes_map_[name] = std::unique_ptr<ScopeNode>(new ScopeNode(node));
} }

View File

@ -39,7 +39,7 @@ namespace tfprof {
class ScopeNode : public ShowNode { class ScopeNode : public ShowNode {
public: public:
explicit ScopeNode(TFNode* node) : ShowNode(node) {} explicit ScopeNode(const TFGraphNode* node) : ShowNode(node) {}
~ScopeNode() override {} ~ScopeNode() override {}
void AggregateTotalStats(ScopeNode* node) { void AggregateTotalStats(ScopeNode* node) {
@ -59,7 +59,7 @@ class TFScope : public TFShow {
: TFShow(ckpt_reader) {} : TFShow(ckpt_reader) {}
~TFScope() override {} ~TFScope() override {}
void AddNode(TFNode* node) override; void AddNode(TFGraphNode* node) override;
void Build() override; void Build() override;
@ -79,7 +79,7 @@ class TFScope : public TFShow {
std::vector<ScopeNode*> roots_; std::vector<ScopeNode*> roots_;
std::vector<std::unique_ptr<NodeDef>> node_defs_; std::vector<std::unique_ptr<NodeDef>> node_defs_;
std::map<string, std::unique_ptr<TFNode>> parent_nodes_; std::map<string, std::unique_ptr<TFGraphNode>> parent_nodes_;
std::map<string, std::unique_ptr<ScopeNode>> nodes_map_; std::map<string, std::unique_ptr<ScopeNode>> nodes_map_;
}; };
} // namespace tfprof } // namespace tfprof

View File

@ -25,13 +25,13 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tfprof { namespace tfprof {
ShowNode::ShowNode(TFNode* node) : node(node), account(true) { ShowNode::ShowNode(const TFGraphNode* node) : node(node), account(true) {
mutable_proto()->set_name(name()); mutable_proto()->set_name(name());
if (!node->device().empty()) { if (!node->device().empty()) {
mutable_proto()->set_device(node->device()); mutable_proto()->set_device(node->device());
} }
mutable_proto()->set_exec_micros(node->kernel_compute_micros()); mutable_proto()->set_exec_micros(node->kernel_compute_micros());
mutable_proto()->set_requested_bytes(node->requested_byptes()); mutable_proto()->set_requested_bytes(node->requested_bytes());
mutable_proto()->set_float_ops(node->float_ops()); mutable_proto()->set_float_ops(node->float_ops());
if (!node->shape().empty()) { if (!node->shape().empty()) {
@ -119,12 +119,12 @@ string ShowNode::FormatMeta(const Options& opts) {
return str_util::Join(info, ", "); return str_util::Join(info, ", ");
} }
TFProfNode* ShowNode::mutable_proto() { return &proto_; } TFGraphNodeProto* ShowNode::mutable_proto() { return &proto_; }
const TFProfNode& ShowNode::proto() const { return proto_; } const TFGraphNodeProto& ShowNode::proto() const { return proto_; }
void ShowNode::AggregateTotalStats(ShowNode* node) { void ShowNode::AggregateTotalStats(ShowNode* node) {
TFProfNode* node_pb = node->mutable_proto(); TFGraphNodeProto* node_pb = node->mutable_proto();
mutable_proto()->set_total_exec_micros(proto().total_exec_micros() + mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
node_pb->total_exec_micros()); node_pb->total_exec_micros());
mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() + mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
@ -151,9 +151,10 @@ void ShowNode::ResetTotalStats() {
mutable_proto()->set_total_requested_bytes(0); mutable_proto()->set_total_requested_bytes(0);
mutable_proto()->set_total_parameters(0); mutable_proto()->set_total_parameters(0);
mutable_proto()->set_total_float_ops(0); mutable_proto()->set_total_float_ops(0);
mutable_proto()->mutable_children()->Clear();
} }
const TFProfNode& TFShow::Show(const Options& opts) { const TFGraphNodeProto& TFShow::Show(const Options& opts) {
const ShowNode* root = ShowInternal(opts); const ShowNode* root = ShowInternal(opts);
if (opts.dump_to_file.empty()) { if (opts.dump_to_file.empty()) {
printf("%s", root->formatted_str.c_str()); printf("%s", root->formatted_str.c_str());

View File

@ -37,18 +37,18 @@ namespace tensorflow {
namespace tfprof { namespace tfprof {
class ShowNode { class ShowNode {
public: public:
explicit ShowNode(TFNode* node); explicit ShowNode(const TFGraphNode* node);
virtual ~ShowNode() {} virtual ~ShowNode() {}
const string& name() const { return node->node_def()->name(); } const string& name() const { return node->name(); }
TFProfNode* mutable_proto(); TFGraphNodeProto* mutable_proto();
const TFProfNode& proto() const; const TFGraphNodeProto& proto() const;
string Format(const Options& opts); string Format(const Options& opts);
string FormatMeta(const Options& opts); string FormatMeta(const Options& opts);
TFNode* node; const TFGraphNode* node;
bool account; bool account;
string formatted_str; string formatted_str;
@ -59,7 +59,7 @@ class ShowNode {
void ResetTotalStats(); void ResetTotalStats();
TFProfNode proto_; TFGraphNodeProto proto_;
}; };
class TFShow { class TFShow {
@ -67,9 +67,9 @@ class TFShow {
explicit TFShow(checkpoint::CheckpointReader* ckpt_reader) explicit TFShow(checkpoint::CheckpointReader* ckpt_reader)
: ckpt_reader_(ckpt_reader) {} : ckpt_reader_(ckpt_reader) {}
virtual ~TFShow() {} virtual ~TFShow() {}
virtual void AddNode(TFNode* node) = 0; virtual void AddNode(TFGraphNode* node) = 0;
virtual void Build() = 0; virtual void Build() = 0;
const TFProfNode& Show(const Options& opts); const TFGraphNodeProto& Show(const Options& opts);
protected: protected:
virtual const ShowNode* ShowInternal(const Options& opts) = 0; virtual const ShowNode* ShowInternal(const Options& opts) = 0;

View File

@ -0,0 +1,273 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/tfprof/internal/tfprof_show_code.h"
#include <memory>
#include <set>
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/tools/tfprof/internal/tfprof_scope.h"
namespace tensorflow {
namespace tfprof {
ShowCodeNode::ShowCodeNode(const TFCodeNode* node) : node(node), account(true) {
std::vector<ScopeNode> snodes;
for (auto it : node->graph_nodes()) {
ScopeNode snode(it.second);
snodes.push_back(snode);
snodes[snodes.size() - 1].AddSelfToTotalStats();
*mutable_proto()->mutable_graph_nodes()->Add() =
snodes[snodes.size() - 1].proto();
}
mutable_proto()->set_name(name());
mutable_proto()->set_exec_micros(node->kernel_compute_micros());
mutable_proto()->set_requested_bytes(node->requested_bytes());
mutable_proto()->set_float_ops(node->float_ops());
if (!node->shapes().empty()) {
for (const std::vector<int64>& shape : node->shapes()) {
int64 params = 1;
bool complete_shape = true;
for (int64 d : shape) {
// Sometimes parameters could be <0 when a dim is unknown.
if (d < 0) {
complete_shape = false;
break;
}
params *= d;
}
if (complete_shape) {
mutable_proto()->set_parameters(proto().parameters() + params);
} else {
fprintf(stderr, "Incomplete shape.");
}
}
}
}
string ShowCodeNode::Format(const Options& opts) {
if (opts.select.empty()) {
return name();
}
return strings::Printf("%s (%s)", name().c_str(), FormatMeta(opts).c_str());
}
string ShowCodeNode::FormatMeta(const Options& opts) {
std::vector<string> info;
std::vector<string> shapes;
if (opts.select.find(kShown[2]) != opts.select.end()) {
for (const std::vector<int64>& shape : node->shapes()) {
if (!shape.empty()) {
shapes.push_back(FormatShapes(shape));
}
}
if (!shapes.empty()) {
info.push_back(str_util::Join(shapes, "|"));
}
string params = FormatNumber(proto().total_parameters()) + " params";
if (account) {
params = FormatNumber(proto().parameters()) + "/" + params;
} else {
params = "--/" + params;
}
info.push_back(params);
}
if (opts.select.find(kShown[3]) != opts.select.end()) {
string fops = FormatNumber(proto().total_float_ops()) + " flops";
if (account) {
fops = FormatNumber(proto().float_ops()) + "/" + fops;
} else {
fops = "--/" + fops;
}
info.push_back(fops);
}
if (opts.select.find(kShown[0]) != opts.select.end()) {
string memory = FormatMemory(proto().total_requested_bytes());
if (account) {
memory = FormatMemory(proto().requested_bytes()) + "/" + memory;
} else {
memory = "--/" + memory;
}
info.push_back(memory);
}
if (opts.select.find(kShown[1]) != opts.select.end()) {
string time = FormatTime(proto().total_exec_micros());
if (account) {
time = FormatTime(proto().exec_micros()) + "/" + time;
} else {
time = "--/" + time;
}
info.push_back(time);
}
if (opts.select.find(kShown[6]) != opts.select.end()) {
if (!node->devices().empty()) {
info.push_back(str_util::Join(node->devices(), "|"));
}
}
if (opts.select.find(kShown[7]) != opts.select.end()) {
std::set<string> op_types = node->op_types();
// Device is considered a type.
op_types.insert(node->devices().cbegin(), node->devices().cend());
info.push_back(str_util::Join(op_types, "|"));
}
return str_util::Join(info, ", ");
}
TFCodeNodeProto* ShowCodeNode::mutable_proto() { return &proto_; }
const TFCodeNodeProto& ShowCodeNode::proto() const { return proto_; }
void ShowCodeNode::AggregateTotalStats(ShowCodeNode* node) {
TFCodeNodeProto* node_pb = node->mutable_proto();
mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
node_pb->total_exec_micros());
mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
node_pb->total_requested_bytes());
mutable_proto()->set_total_parameters(proto().total_parameters() +
node_pb->total_parameters());
mutable_proto()->set_total_float_ops(proto().total_float_ops() +
node_pb->total_float_ops());
}
void ShowCodeNode::AddSelfToTotalStats() {
mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
proto().exec_micros());
mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
proto().requested_bytes());
mutable_proto()->set_total_parameters(proto().total_parameters() +
proto().parameters());
mutable_proto()->set_total_float_ops(proto().total_float_ops() +
proto().float_ops());
}
void ShowCodeNode::ResetTotalStats() {
mutable_proto()->set_total_exec_micros(0);
mutable_proto()->set_total_requested_bytes(0);
mutable_proto()->set_total_parameters(0);
mutable_proto()->set_total_float_ops(0);
mutable_proto()->mutable_children()->Clear();
}
const TFCodeNodeProto& TFShowCode::Show(const Options& opts) {
const ShowCodeNode* root = ShowInternal(opts);
if (opts.dump_to_file.empty()) {
printf("%s", root->formatted_str.c_str());
fflush(stdout);
} else {
Status s = WriteStringToFile(Env::Default(), opts.dump_to_file,
root->formatted_str);
if (!s.ok()) {
fprintf(stderr, "%s\n", s.ToString().c_str());
}
}
return root->proto();
}
bool TFShowCode::ShouldShow(ShowCodeNode* node, const Options& opts,
int depth) {
// Always show kTFProfRoot.
if (node->name() == kTFProfRoot) return true;
if (!node->account) return false;
// TODO(xpan): Think more carefully about node filtering in code view.
// Unlike graph/scope view, which users want to see the exact leaf op.
// In code view, users want to see the middle code traces they wrote.
//
// This is a subtle difference from scope/graph view. Usually mostly
// want to see the middle code traces (i.e. their own codes.), instead
// of the TensorFlow internal codes traces.
if (node->proto().total_requested_bytes() < opts.min_bytes ||
node->proto().total_exec_micros() < opts.min_micros ||
node->proto().total_parameters() < opts.min_params ||
node->proto().total_float_ops() < opts.min_float_ops ||
depth > opts.max_depth || !ShouldShowIfExtra(node, opts, depth)) {
return false;
}
bool show = false;
if (opts.device_regexes.size() == 1 && opts.device_regexes[0] == ".*") {
show = true;
} else {
for (const string& regex : opts.device_regexes) {
for (const string& device : node->node->devices()) {
if (RE2::FullMatch(device, regex)) {
show = true;
break;
}
}
if (show) break;
}
}
// Don't show if device_regexes don't cover it.
if (!show) return false;
show = false;
if (opts.show_name_regexes.size() == 1 && opts.show_name_regexes[0] == ".*") {
show = true;
} else {
for (const string& regex : opts.show_name_regexes) {
if (RE2::FullMatch(node->name(), regex)) {
show = true;
break;
}
}
}
// Don't show if show_name_regexes don't cover it.
if (!show) return false;
// Don't show if hide_name_regexes cover it.
for (const string& regex : opts.hide_name_regexes) {
if (RE2::FullMatch(node->name(), regex)) return false;
}
return true;
}
bool TFShowCode::ShouldTrim(ShowCodeNode* node,
const std::vector<string>& regexes) {
for (const string& regex : regexes) {
if (RE2::FullMatch(node->name(), regex)) {
return true;
}
}
return false;
}
bool TFShowCode::ShouldAccount(ShowCodeNode* node, const Options& opts) {
if (opts.account_type_regexes.size() == 1 &&
opts.account_type_regexes[0] == ".*") {
return true;
}
for (const string& regex : opts.account_type_regexes) {
for (const string& type : node->node->op_types()) {
if (RE2::FullMatch(type, regex)) {
return true;
}
}
for (const string& device : node->node->devices()) {
if (RE2::FullMatch(device, regex)) {
return true;
}
}
}
return false;
}
} // namespace tfprof
} // namespace tensorflow

View File

@ -0,0 +1,126 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Parent class and utilities for tfprof_graph and tfprof_scope.
#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_
#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_
#include <algorithm>
#include <string>
#include <vector>
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_tensor.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
namespace tensorflow {
namespace tfprof {
class ShowCodeNode {
public:
explicit ShowCodeNode(const TFCodeNode* node);
virtual ~ShowCodeNode() {}
const string& name() const { return node->name(); }
TFCodeNodeProto* mutable_proto();
const TFCodeNodeProto& proto() const;
string Format(const Options& opts);
string FormatMeta(const Options& opts);
const TFCodeNode* node;
bool account;
string formatted_str;
protected:
void AggregateTotalStats(ShowCodeNode* node);
void AddSelfToTotalStats();
void ResetTotalStats();
TFCodeNodeProto proto_;
};
class TFShowCode {
public:
explicit TFShowCode() {}
virtual ~TFShowCode() {}
virtual void AddNode(TFGraphNode* node) = 0;
virtual void Build() = 0;
const TFCodeNodeProto& Show(const Options& opts);
protected:
virtual const ShowCodeNode* ShowInternal(const Options& opts) = 0;
bool LookUpCheckPoint(const string& name,
std::unique_ptr<TFProfTensor>* tensor);
// Overridden by subclass if extra requirements need to be met.
virtual bool ShouldShowIfExtra(ShowCodeNode* node, const Options& opts,
int depth) {
return true;
}
bool ShouldShow(ShowCodeNode* node, const Options& opts, int depth);
bool ShouldTrim(ShowCodeNode* node, const std::vector<string>& regexes);
bool ShouldAccount(ShowCodeNode* node, const Options& opts);
template <typename T>
std::vector<T*> SortNodes(const std::vector<T*>& nodes, const Options& opts) {
if (opts.order_by.empty() || nodes.empty()) {
return nodes;
}
std::vector<T*> sorted_nodes = nodes;
std::sort(sorted_nodes.begin(), sorted_nodes.end(),
[&opts](const T* n1, const T* n2) {
if (n1->name() == kTFProfRoot) return true;
if (n2->name() == kTFProfRoot) return false;
bool name_cmp = n1->name() < n2->name();
if (opts.order_by == kOrderBy[0]) {
return name_cmp;
} else if (opts.order_by == kOrderBy[1]) {
return n1->proto().total_requested_bytes() >
n2->proto().total_requested_bytes();
} else if (opts.order_by == kOrderBy[2]) {
return n1->proto().total_exec_micros() >
n2->proto().total_exec_micros();
} else if (opts.order_by == kOrderBy[3]) {
return n1->proto().total_parameters() >
n2->proto().total_parameters();
} else if (opts.order_by == kOrderBy[4]) {
return n1->proto().total_float_ops() >
n2->proto().total_float_ops();
}
return name_cmp;
});
return sorted_nodes;
}
};
} // namespace tfprof
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_

View File

@ -56,29 +56,38 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
printf("Preparing Views...\n"); printf("Preparing Views...\n");
scope_view_ = std::unique_ptr<TFScope>(new TFScope(ckpt_reader_.get())); scope_view_ = std::unique_ptr<TFScope>(new TFScope(ckpt_reader_.get()));
graph_view_ = std::unique_ptr<TFGraph>(new TFGraph(ckpt_reader_.get())); graph_view_ = std::unique_ptr<TFGraph>(new TFGraph(ckpt_reader_.get()));
code_view_ = std::unique_ptr<TFCode>(new TFCode());
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
scope_view_->AddNode(&it->second); scope_view_->AddNode(&it->second);
graph_view_->AddNode(&it->second); graph_view_->AddNode(&it->second);
code_view_->AddNode(&it->second);
} }
scope_view_->Build(); scope_view_->Build();
graph_view_->Build(); graph_view_->Build();
code_view_->Build();
} }
const TFProfNode& TFStats::PrintGraph(const string& cmd, const Options& opts) { const TFGraphNodeProto& TFStats::PrintGraph(const string& cmd,
const Options& opts) {
if (cmd == kCmds[0]) { if (cmd == kCmds[0]) {
return scope_view_->Show(opts); return scope_view_->Show(opts);
} else if (cmd == kCmds[1]) { } else if (cmd == kCmds[1]) {
return graph_view_->Show(opts); return graph_view_->Show(opts);
} else { } else {
fprintf(stderr, "Unknown command: %s\n", cmd.c_str()); fprintf(stderr, "Unknown command: %s\n", cmd.c_str());
return empty_node_; return empty_graph_node_;
} }
} }
const TFCodeNodeProto& TFStats::PrintCode(const Options& opts) {
return code_view_->Show(opts);
}
void TFStats::ParseGraph() { void TFStats::ParseGraph() {
for (const NodeDef& node : graph_->node()) { for (const NodeDef& node : graph_->node()) {
CHECK(nodes_map_.find(node.name()) == nodes_map_.end()); CHECK(nodes_map_.find(node.name()) == nodes_map_.end());
nodes_map_[node.name()] = TFNode(&node); nodes_map_[node.name()] = TFGraphNode(&node);
} }
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
const NodeDef* node_def = it->second.node_def(); const NodeDef* node_def = it->second.node_def();
@ -110,6 +119,9 @@ void TFStats::ParseOpLog() {
if (entry.float_ops()) { if (entry.float_ops()) {
node->second.AddFloatOps(entry.float_ops()); node->second.AddFloatOps(entry.float_ops());
} }
if (entry.has_code_def()) {
node->second.AddCode(&entry.code_def());
}
} }
} }
@ -131,13 +143,14 @@ void TFStats::ParseRunMeta() {
"Missing CostGraphDef in RunMetadata.\nMaybe you forget to" "Missing CostGraphDef in RunMetadata.\nMaybe you forget to"
"set tf.ConfigProto(graph_options=tf.GraphOptions(" "set tf.ConfigProto(graph_options=tf.GraphOptions("
"build_cost_model=1)) to Session()\n"); "build_cost_model=1)) to Session()\n");
} } else {
for (const auto& node_pb : run_meta_->cost_graph().node()) { for (const auto& node_pb : run_meta_->cost_graph().node()) {
auto node = nodes_map_.find(node_pb.name()); auto node = nodes_map_.find(node_pb.name());
if (node == nodes_map_.end()) { if (node == nodes_map_.end()) {
continue; continue;
}
node->second.AddNodeStat(&node_pb);
} }
node->second.AddNodeStat(&node_pb);
} }
} }
} // namespace tfprof } // namespace tfprof

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/tools/tfprof/internal/tfprof_code.h"
#include "tensorflow/tools/tfprof/internal/tfprof_graph.h" #include "tensorflow/tools/tfprof/internal/tfprof_graph.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h" #include "tensorflow/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h" #include "tensorflow/tools/tfprof/internal/tfprof_options.h"
@ -56,7 +57,8 @@ class TFStats {
// Prints the results to stdout. Also returns the printed output in // Prints the results to stdout. Also returns the printed output in
// a proto. // a proto.
const TFProfNode& PrintGraph(const string& cmd, const Options& opts); const TFGraphNodeProto& PrintGraph(const string& cmd, const Options& opts);
const TFCodeNodeProto& PrintCode(const Options& opts);
private: private:
void ParseGraph(); void ParseGraph();
@ -67,13 +69,16 @@ class TFStats {
std::unique_ptr<TFScope> scope_view_; std::unique_ptr<TFScope> scope_view_;
std::unique_ptr<TFGraph> graph_view_; std::unique_ptr<TFGraph> graph_view_;
std::unique_ptr<TFCode> code_view_;
std::unique_ptr<GraphDef> graph_; std::unique_ptr<GraphDef> graph_;
std::unique_ptr<RunMetadata> run_meta_; std::unique_ptr<RunMetadata> run_meta_;
std::unique_ptr<OpLog> op_log_; std::unique_ptr<OpLog> op_log_;
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader_; std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader_;
// Store TFNode instead of TFNode* to avoid large number of dynamic alloc. // Store TFGraphNode instead of TFGraphNode* to avoid large number of
std::map<string, TFNode> nodes_map_; // dynamic alloc.
TFProfNode empty_node_; std::map<string, TFGraphNode> nodes_map_;
TFGraphNodeProto empty_graph_node_;
TFCodeNodeProto empty_code_node_;
}; };
} // namespace tfprof } // namespace tfprof

View File

@ -76,9 +76,9 @@ TEST_F(TFProfStatsTest, CustomOpType) {
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
false); false);
const TFProfNode& root = tf_stats_->PrintGraph("scope", opts); const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFProfNode expected; TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString( CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
"0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: " "0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: "
@ -108,9 +108,9 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
3, 0, 0, 0, 0, {".*"}, "name", {kCkptVarType}, // accout_type_regexes 3, 0, 0, 0, 0, {".*"}, "name", {kCkptVarType}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false, {".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, false); {"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, false);
const TFProfNode& root = tf_stats_->PrintGraph("scope", opts); const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFProfNode expected; TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString( CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
"0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: " "0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: "
@ -141,9 +141,9 @@ TEST_F(TFProfStatsTest, TestGraph) {
{""}, {".*"}, {""}, false, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
false); false);
const TFProfNode& root = tf_stats_->PrintGraph("graph", opts); const TFGraphNodeProto& root = tf_stats_->PrintGraph("graph", opts);
TFProfNode expected; TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString( CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: 0\ninputs: " "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: 0\ninputs: "
"0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: " "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: "
@ -155,9 +155,9 @@ TEST_F(TFProfStatsTest, TestGraph) {
TEST_F(TFProfStatsTest, TestFloatOps) { TEST_F(TFProfStatsTest, TestFloatOps) {
Options opts(10, 0, 0, 0, 1, {".*"}, "name", {".*"}, {".*"}, {""}, {".*"}, Options opts(10, 0, 0, 0, 1, {".*"}, "name", {".*"}, {".*"}, {""}, {".*"},
{""}, false, {"float_ops"}, false); {""}, false, {"float_ops"}, false);
const TFProfNode& root = tf_stats_->PrintGraph("scope", opts); const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFProfNode expected; TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString( CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
"0\ntotal_exec_micros: 96\ntotal_requested_bytes: " "0\ntotal_exec_micros: 96\ntotal_requested_bytes: "
@ -187,9 +187,9 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
{"unit_2_1.*DW"}, // show_name_regexes. {"unit_2_1.*DW"}, // show_name_regexes.
{""}, true, // account_displayed_op_only. {""}, true, // account_displayed_op_only.
{"params"}, false); {"params"}, false);
const TFProfNode& root = tf_stats_->PrintGraph("scope", opts); const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFProfNode expected; TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString( CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
"0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: " "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: "
@ -203,8 +203,8 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) {
{"unit_1_0.*gamma"}, {""}, false, {"unit_1_0.*gamma"}, {""}, false,
{"tensor_value"}, // Show tensor value from checkpoint. {"tensor_value"}, // Show tensor value from checkpoint.
false); false);
const TFProfNode& root = tf_stats_->PrintGraph("scope", opts); const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFProfNode expected; TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString( CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
"0\ntotal_exec_micros: 96\ntotal_requested_bytes: " "0\ntotal_exec_micros: 96\ntotal_requested_bytes: "

View File

@ -58,9 +58,9 @@ TEST_F(TFProfTensorTest, Basics) {
Options opts(3, 0, 0, 0, 0, {".*"}, "name", {"VariableV2"}, {".*"}, {""}, Options opts(3, 0, 0, 0, 0, {".*"}, "name", {"VariableV2"}, {".*"}, {""},
{".*"}, {""}, false, {"tensor_value"}, // show the tensor value. {".*"}, {""}, false, {"tensor_value"}, // show the tensor value.
false); false);
const TFProfNode& root = tf_stats_->PrintGraph("scope", opts); const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFProfNode expected; TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString( CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
"0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: " "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: "

View File

@ -2,6 +2,17 @@ syntax = "proto2";
package tensorflow.tfprof; package tensorflow.tfprof;
// It specifies the Python callstack that creates an op.
message CodeDef {
repeated Trace traces = 1;
message Trace {
optional string file = 1;
optional int32 lineno = 2;
optional string function = 3;
optional string line = 4;
}
}
message OpLogEntry { message OpLogEntry {
// op name. // op name.
optional string name = 1; optional string name = 1;
@ -12,6 +23,8 @@ message OpLogEntry {
// User can define extra op type information for an op. This allows the user // User can define extra op type information for an op. This allows the user
// to select a group of ops precisely using op_type as a key. // to select a group of ops precisely using op_type as a key.
repeated string types = 3; repeated string types = 3;
// Used to support tfprof "code" view.
optional CodeDef code_def = 4;
} }
message OpLog { message OpLog {

View File

@ -160,12 +160,13 @@ int main(int argc, char** argv) {
"Profiling everything!\n"); "Profiling everything!\n");
return 0; return 0;
} else if (argc > 1) { } else if (argc > 1) {
if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[3]) { if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[4]) {
tensorflow::tfprof::PrintHelp(); tensorflow::tfprof::PrintHelp();
return 0; return 0;
} }
if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[0] || if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[0] ||
tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[1]) { tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[1] ||
tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[2]) {
cmd = argv[1]; cmd = argv[1];
} }
} }
@ -214,7 +215,10 @@ int main(int argc, char** argv) {
hide_name_regexes, FLAGS_account_displayed_op_only, select, FLAGS_viz, hide_name_regexes, FLAGS_account_displayed_op_only, select, FLAGS_viz,
FLAGS_dump_to_file); FLAGS_dump_to_file);
if (!cmd.empty()) { if (cmd == tensorflow::tfprof::kCmds[2]) {
tf_stat.PrintCode(opts);
return 0;
} else if (!cmd.empty()) {
tf_stat.PrintGraph(cmd, opts); tf_stat.PrintGraph(cmd, opts);
return 0; return 0;
} }
@ -240,10 +244,12 @@ int main(int argc, char** argv) {
fprintf(stderr, "E: %s\n", s.ToString().c_str()); fprintf(stderr, "E: %s\n", s.ToString().c_str());
continue; continue;
} }
if (cmd == tensorflow::tfprof::kCmds[2]) { if (cmd == tensorflow::tfprof::kCmds[3]) {
opts = new_opts; opts = new_opts;
} else if (cmd == tensorflow::tfprof::kCmds[3]) { } else if (cmd == tensorflow::tfprof::kCmds[4]) {
tensorflow::tfprof::PrintHelp(); tensorflow::tfprof::PrintHelp();
} else if (cmd == tensorflow::tfprof::kCmds[2]) {
tf_stat.PrintCode(new_opts);
} else { } else {
tf_stat.PrintGraph(cmd, new_opts); tf_stat.PrintGraph(cmd, new_opts);
} }

View File

@ -21,4 +21,4 @@ message OptionsProto {
repeated string select = 14; repeated string select = 14;
optional bool viz = 15; optional bool viz = 15;
optional string dump_to_file = 16; optional string dump_to_file = 16;
} }

View File

@ -14,7 +14,8 @@ message TFProfTensorProto {
repeated string value_str = 4; repeated string value_str = 4;
} }
message TFProfNode { // A node in TensorFlow graph. Used by scope/graph view.
message TFGraphNodeProto {
// op name. // op name.
optional string name = 1; optional string name = 1;
// tensor value restored from checkpoint. // tensor value restored from checkpoint.
@ -45,5 +46,34 @@ message TFProfNode {
repeated TensorShapeProto shapes = 11; repeated TensorShapeProto shapes = 11;
// Descendants of the graph. The actual descendants depend on the data // Descendants of the graph. The actual descendants depend on the data
// structure used (scope, graph). // structure used (scope, graph).
repeated TFProfNode children = 12; repeated TFGraphNodeProto children = 12;
}
// A node in TensorFlow Python call trace stack. Used by code view.
message TFCodeNodeProto {
// A trace in the trace stack.
optional string name = 1;
// code execution time.
optional int64 exec_micros = 2;
// Total requested bytes by the code.
optional int64 requested_bytes = 3;
// Number of parameters if available.
optional int64 parameters = 4;
// Number of float operations.
optional int64 float_ops = 5;
// The following are the aggregated stats from called descendents and the
// trace itself. The actual descendants depend on the data structure used.
optional int64 total_exec_micros = 6;
optional int64 total_requested_bytes = 7;
optional int64 total_parameters = 8;
optional int64 total_float_ops = 9;
// A set of graph nodes created by the leaf of the call stack.
// 'children' field should be empty if graph_nodes is non-empty.
repeated TFGraphNodeProto graph_nodes = 10;
// Descendants of the graph. The actual descendants depend on the data
// structure used (scope, graph).
repeated TFCodeNodeProto children = 11;
} }