Extend tfprof to associate op stats with Python code.

It's backward compatible. Stats of a source code line are aggregated from all ops created by that line. An example:

_TFProfRoot (0us/22.44ms)
  model_analyzer_test.py:149:run_filename_as_m...:none (0us/22.44ms)
    model_analyzer_test.py:33:_run_code_in_main:none (0us/22.44ms)
      model_analyzer_test.py:208:<module>:test.main() (0us/22.44ms)
        model_analyzer_test.py:132:testComplexCodeView:x = lib.BuildFull... (0us/22.44ms)
          model_analyzer_testlib.py:63:BuildFullModel:return sgd_op.min... (0us/21.83ms)
          model_analyzer_testlib.py:54:BuildFullModel:seq.append(array_... (0us/254us)
            model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0us/134us)
            ...
          model_analyzer_testlib.py:61:BuildFullModel:loss = nn_ops.l2_... (0us/28us)
        model_analyzer_test.py:134:testComplexCodeView:sess.run(variable... (0us/0us)

Change: 155393864
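For reference, a minimal sketch (not part of the commit itself) of how the new view is invoked from the Python API added here; the option values and regex are illustrative, and `sess` is assumed to be an existing session that has already been run with full tracing:

```python
import tensorflow as tf

# Assumed setup: the graph was run once with full tracing so run_metadata
# carries per-op execution stats (see the tests in this commit).
run_metadata = tf.RunMetadata()
# sess.run(fetches,
#          options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#          run_metadata=run_metadata)

opts = tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
opts['show_name_regexes'] = ['.*my_code.py.*']  # illustrative filter

# tfprof_cmd='code' groups stats by the Python call stack that created the
# ops, instead of by op name scope or graph structure.
tf.contrib.tfprof.model_analyzer.print_model_analysis(
    tf.get_default_graph(),
    run_meta=run_metadata,
    tfprof_cmd='code',
    tfprof_options=opts)
```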
This commit is contained in:
parent ec8ffb9eaf
commit 697f34ca82

Changed paths:
tensorflow/contrib/tfprof: README.md, python/tools/tfprof
tensorflow/tools/tfprof: README.md, tfprof_log.proto, tfprof_main.cc, tfprof_options.proto, tfprof_output.proto
tensorflow/tools/tfprof/internal: BUILD, print_model_analysis.cc, tfprof_code.cc, tfprof_code.h, tfprof_graph.cc, tfprof_graph.h, tfprof_node.cc, tfprof_node.h, tfprof_options.h, tfprof_scope.cc, tfprof_scope.h, tfprof_show.cc, tfprof_show.h, tfprof_show_code.cc, tfprof_show_code.h, tfprof_stats.cc, tfprof_stats.h, tfprof_stats_test.cc, tfprof_tensor_test.cc

tensorflow/contrib/tfprof/README.md
@ -11,7 +11,12 @@ Consultants: Jon Shlens, Pete Warden
1. Measure model parameters, float operations, tensor shapes.
2. Measure op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values.
4. Explore model based on name scope or graph structure.
4. 3 ways to view and explore TensorFlow model profiles

    * Organize by Python code call stack.
    * Organize by TensorFlow operation name scope hierarchies.
    * Organize by TensorFlow operation inputs/outputs graph.

5. Selectively grouping/filtering/accounting/ordering ops.

tfprof can be used as Python API, Interactive CLI and One-shot Script.

@ -28,7 +33,8 @@ param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
    tfprof_options=tf.contrib.tfprof.model_analyzer.
        TRAINABLE_VARS_PARAMS_STAT_OPTIONS)

# param_stats is tensorflow.tfprof.TFProfNode proto. It organize the statistics
# param_stats is tensorflow.tfprof.TFGraphNodeProto proto.
# It organize the statistics
# of each graph node in tree scructure. Let's print the root below.
sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
```

tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
@ -21,16 +21,34 @@ py_test(
    name = "model_analyzer_test",
    srcs = ["model_analyzer_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":model_analyzer",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        ":model_analyzer_testlib",
        "//tensorflow/python:client",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:platform",
        "//tensorflow/python:variables",
    ],
)

py_library(
    name = "model_analyzer_testlib",
    srcs = ["model_analyzer_testlib.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":model_analyzer",
        "//tensorflow/contrib/rnn:rnn_py",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:rnn",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@ -123,7 +123,7 @@ def print_model_analysis(graph,
  """Print model statistics.

  Prints the model statistics to stdout. Also returns the results
  in a TFProfNode proto. See go/tfprof or run tfprof tool:
  in a TFGraphNodeProto proto. See go/tfprof or run tfprof tool:
  'bazel run third_party/tensorflow/tools/tfprof help'

  Examples:
@ -142,15 +142,19 @@ def print_model_analysis(graph,
        'micros' and 'bytes'.
    op_log: tensorflow::tfprof::OpLog proto. users can use this proto to
        group together ops and use a op_type to select the group.
    tfprof_cmd: string. Either 'scope' or 'graph'. 'scope' view organize
        ops using their name scopes. 'graph' view organize ops using
        their graph inputs.
    tfprof_cmd: string. Either 'scope', 'graph', 'code'.
        'scope' view organize outputs using ops' name scope.
        'graph' view organize outputs using op's inputs/outputs.
        'code' view organize outputs using Python call stack.
    tfprof_options: See 'tfprof help' for details.
  Returns:
    TFProfNode proto. Side effect: a formatted output to stdout.
    If tfprof_cmd is 'scope' or 'graph', returns TFGraphNodeProto proto.
    If tfprof_cmd is 'code', returns TFCodeNodeProto proto.
    Side effect: a formatted output to stdout.
  """
  # pylint: disable=protected-access
  op_log = tfprof_logger._merge_default_with_oplog(graph, op_log, run_meta)
  op_log = tfprof_logger._merge_default_with_oplog(
      graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
  # pylint: enable=protected-access
  opts = tfprof_options_pb2.OptionsProto()
  opts.max_depth = tfprof_options['max_depth']
@ -178,11 +182,24 @@ def print_model_analysis(graph,
  opts.dump_to_file = tfprof_options['dump_to_file']

  run_meta_str = run_meta.SerializeToString() if run_meta else b''
  op_log_str = op_log.SerializeToString() if op_log else b''

  tfprof_node = tfprof_output_pb2.TFProfNode()
  tfprof_node.ParseFromString(
      print_mdl.PrintModelAnalysis(
          graph.as_graph_def().SerializeToString(), run_meta_str, op_log_str,
          tfprof_cmd.encode('utf-8'), opts.SerializeToString()))
  if tfprof_cmd == 'code':
    tfprof_node = tfprof_output_pb2.TFCodeNodeProto()
    tfprof_node.ParseFromString(
        print_mdl.PrintModelAnalysis(
            graph.as_graph_def().SerializeToString(),
            run_meta_str,
            op_log.SerializeToString(),
            tfprof_cmd.encode('utf-8'),
            opts.SerializeToString()))
  else:
    tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
    tfprof_node.ParseFromString(
        print_mdl.PrintModelAnalysis(
            graph.as_graph_def().SerializeToString(),
            run_meta_str,
            op_log.SerializeToString(),
            tfprof_cmd.encode('utf-8'),
            opts.SerializeToString()))

  return tfprof_node
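Since the wrapper now returns a different proto depending on `tfprof_cmd`, a caller that consumes the result has to pick fields accordingly. A minimal sketch (not from this commit; `run_metadata` and `opts` are assumed from the earlier example, and the field names are the ones exercised by the tests below):

```python
# 'code' returns a TFCodeNodeProto; 'scope' and 'graph' return a
# TFGraphNodeProto. Both carry aggregated totals and a repeated `children`
# field.
tfprof_node = tf.contrib.tfprof.model_analyzer.print_model_analysis(
    tf.get_default_graph(),
    run_meta=run_metadata,
    tfprof_cmd='code',
    tfprof_options=opts)

print(tfprof_node.total_exec_micros, tfprof_node.total_parameters)
for child in tfprof_node.children:
  print(child.name)  # e.g. 'model_analyzer_testlib.py:56:BuildFullModel:...'
```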
tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@ -18,49 +18,27 @@ from __future__ import division
from __future__ import print_function

import os

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test

# XXX: this depends on pywrap_tensorflow and must come later
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer_testlib as lib


class PrintModelAnalysisTest(test.TestCase):

  def _BuildSmallModel(self):
    image = array_ops.zeros([2, 6, 6, 3])
    _ = variable_scope.get_variable(
        'ScalarW', [],
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))
    kernel = variable_scope.get_variable(
        'DW', [3, 3, 3, 6],
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))
    x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
    kernel = variable_scope.get_variable(
        'DW2', [2, 2, 6, 12],
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))
    x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
    return x

  def testDumpToFile(self):
    ops.reset_default_graph()
    opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
    opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')

    with session.Session() as sess, ops.device('/cpu:0'):
      _ = self._BuildSmallModel()
      _ = lib.BuildSmallModel()
      model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts)

      with gfile.Open(opts['dump_to_file'], 'r') as f:
@ -71,6 +49,7 @@ class PrintModelAnalysisTest(test.TestCase):
            f.read())

  def testSelectEverything(self):
    ops.reset_default_graph()
    opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
    opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
    opts['account_type_regexes'] = ['.*']
@ -78,8 +57,10 @@ class PrintModelAnalysisTest(test.TestCase):
        'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types'
    ]

    with session.Session() as sess, ops.device('/cpu:0'):
      x = self._BuildSmallModel()
    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(build_cost_model=1))
    with session.Session(config=config) as sess, ops.device('/cpu:0'):
      x = lib.BuildSmallModel()

      sess.run(variables.global_variables_initializer())
      run_meta = config_pb2.RunMetadata()
@ -98,6 +79,118 @@ class PrintModelAnalysisTest(test.TestCase):
            f.read())
      # pylint: enable=line-too-long

  def testSimpleCodeView(self):
    ops.reset_default_graph()
    opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
    opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
    opts['account_type_regexes'] = ['.*']
    opts['show_name_regexes'] = ['.*model_analyzer_testlib.*']
    opts['account_displayed_op_only'] = False
    # TODO(xpan): Test 'micros'. Since the execution time changes each run,
    # it's a bit difficult to test it now.
    opts['select'] = [
        'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
    ]

    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(build_cost_model=1))
    with session.Session(config=config) as sess, ops.device('/cpu:0'):
      x = lib.BuildSmallModel()

      sess.run(variables.global_variables_initializer())
      run_meta = config_pb2.RunMetadata()
      _ = sess.run(x,
                   options=config_pb2.RunOptions(
                       trace_level=config_pb2.RunOptions.FULL_TRACE),
                   run_metadata=run_meta)

      model_analyzer.print_model_analysis(
          sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)

      with gfile.Open(opts['dump_to_file'], 'r') as f:
        # pylint: disable=line-too-long
        self.assertEqual(
            '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n',
            f.read())
        # pylint: enable=line-too-long

  def testComplexCodeView(self):
    ops.reset_default_graph()
    opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
    opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
    opts['account_type_regexes'] = ['.*']
    opts['show_name_regexes'] = ['.*model_analyzer_testlib.py.*']
    opts['account_displayed_op_only'] = False
    opts['select'] = ['params', 'float_ops']

    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(build_cost_model=1))
    with session.Session(config=config) as sess, ops.device('/cpu:0'):
      x = lib.BuildFullModel()

      sess.run(variables.global_variables_initializer())
      run_meta = config_pb2.RunMetadata()
      _ = sess.run(x,
                   options=config_pb2.RunOptions(
                       trace_level=config_pb2.RunOptions.FULL_TRACE),
                   run_metadata=run_meta)

      tfprof_node = model_analyzer.print_model_analysis(
          sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)

      # pylint: disable=line-too-long
      with gfile.Open(opts['dump_to_file'], 'r') as f:
        self.assertEqual(
            '_TFProfRoot (0/2.84k params, 0/54.08k flops)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops)\n',
            f.read())

      self.assertLess(0, tfprof_node.total_exec_micros)
      self.assertEqual(2844, tfprof_node.total_parameters)
      self.assertEqual(54080, tfprof_node.total_float_ops)
      self.assertEqual(5, len(tfprof_node.children))
      self.assertEqual('_TFProfRoot', tfprof_node.name)
      self.assertEqual('model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_...',
                       tfprof_node.children[0].name)
      self.assertEqual('model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c...',
                       tfprof_node.children[1].name)
      self.assertEqual('model_analyzer_testlib.py:62:BuildFullModel:target = array_op...',
                       tfprof_node.children[2].name)
      self.assertEqual('model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_...',
                       tfprof_node.children[3].name)
      self.assertEqual('model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min...',
                       tfprof_node.children[4].name)
      # pylint: enable=line-too-long

  def testCodeViewLeafGraphNode(self):
    ops.reset_default_graph()
    opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
    opts['account_type_regexes'] = ['.*']
    opts['account_displayed_op_only'] = False
    opts['select'] = [
        'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device'
    ]

    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(build_cost_model=1))
    with session.Session(config=config) as sess, ops.device('/cpu:0'):
      x = lib.BuildSmallModel()

      sess.run(variables.global_variables_initializer())
      run_meta = config_pb2.RunMetadata()
      _ = sess.run(x,
                   options=config_pb2.RunOptions(
                       trace_level=config_pb2.RunOptions.FULL_TRACE),
                   run_metadata=run_meta)

      tfprof_node = model_analyzer.print_model_analysis(
          sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)

      leaf = tfprof_node
      while leaf.children:
        self.assertEqual(0, len(leaf.graph_nodes))
        leaf = leaf.children[0]
      self.assertEqual(1, len(leaf.graph_nodes))


if __name__ == '__main__':
  test.main()
tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py (new file)
@ -0,0 +1,67 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A test lib that defines some models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import gradient_descent


def BuildSmallModel():
  """Build a small forward conv model."""
  image = array_ops.zeros([2, 6, 6, 3])
  _ = variable_scope.get_variable(
      'ScalarW', [],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  kernel = variable_scope.get_variable(
      'DW', [3, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
  kernel = variable_scope.get_variable(
      'DW2', [2, 2, 6, 12],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
  return x


def BuildFullModel():
  """Build the full model with conv,rnn,opt."""
  seq = []
  for i in range(4):
    with variable_scope.variable_scope('inp_%d' % i):
      seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))

  cell = BasicRNNCell(16, 48)
  out = rnn.dynamic_rnn(
      cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]

  target = array_ops.ones_like(out)
  loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
  sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
  return sgd_op.minimize(loss)
tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
@ -96,12 +96,13 @@ class PrintModelAnalysisTest(test.TestCase):

    with session.Session() as sess, ops.device('/cpu:0'):
      _ = self._BuildSmallModel()
      tfprof_pb = tfprof_output_pb2.TFProfNode()
      tfprof_pb = tfprof_output_pb2.TFGraphNodeProto()
      tfprof_pb.ParseFromString(
          print_mdl.PrintModelAnalysis(sess.graph.as_graph_def(
          ).SerializeToString(), b'', b'', b'scope', opts.SerializeToString()))
          print_mdl.PrintModelAnalysis(
              sess.graph.as_graph_def().SerializeToString(),
              b'', b'', b'scope', opts.SerializeToString()))

      expected_pb = tfprof_output_pb2.TFProfNode()
      expected_pb = tfprof_output_pb2.TFGraphNodeProto()
      text_format.Merge(r"""name: "_TFProfRoot"
                        exec_micros: 0
                        requested_bytes: 0
tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
@ -62,12 +62,13 @@ def _fill_missing_graph_shape(graph, run_meta):
  return graph


def _get_logged_ops(graph, run_meta=None):
def _get_logged_ops(graph, run_meta=None, add_trace=False):
  """Extract trainable model parameters and FLOPs for ops from a Graph.

  Args:
    graph: tf.Graph.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
  Returns:
    logged_ops: dict mapping from op_name to OpLogEntry.
  """
@ -76,21 +77,32 @@ def _get_logged_ops(graph, run_meta=None):

  op_missing_shape = 0
  logged_ops = {}
  graph_def = graph.as_graph_def()
  for node in graph_def.node:
  for op in graph.get_operations():
    try:
      stats = ops.get_stats_for_node_def(graph, node, REGISTERED_FLOP_STATS)
      stats = ops.get_stats_for_node_def(
          graph, op.node_def, REGISTERED_FLOP_STATS)
    except ValueError:
      # Catch Exception When shape is incomplete. Skip it.
      op_missing_shape += 1
      stats = None

    if not stats or not stats.value:
      continue
    if node.name not in logged_ops:
      entry = tfprof_log_pb2.OpLogEntry()
      entry.name = node.name
    entry = tfprof_log_pb2.OpLogEntry()
    entry.name = op.name
    add_entry = False
    if stats and stats.value:
      entry.float_ops = int(stats.value)
      add_entry = True

    if add_trace:
      for tb in op.traceback:
        trace = entry.code_def.traces.add()
        trace.file = tb[0] if tb[0] else 'none'
        trace.lineno = tb[1] if tb[1] else -1
        trace.function = tb[2] if tb[2] else 'none'
        trace.line = tb[3] if tb[3] else 'none'
      add_entry = True

    if add_entry:
      logged_ops[entry.name] = entry

  for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
@ -108,18 +120,21 @@ def _get_logged_ops(graph, run_meta=None):
  return logged_ops


def _merge_default_with_oplog(graph, op_log=None, run_meta=None):
def _merge_default_with_oplog(graph, op_log=None,
                              run_meta=None,
                              add_trace=False):
  """Merge the tfprof default extra info with caller's op_log.

  Args:
    graph: tf.Graph.
    op_log: OpLog proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
  Returns:
    tmp_op_log: Merged OpLog proto.
  """
  tmp_op_log = tfprof_log_pb2.OpLog()
  logged_ops = _get_logged_ops(graph, run_meta)
  logged_ops = _get_logged_ops(graph, run_meta, add_trace=add_trace)
  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
@ -131,13 +146,16 @@ def _merge_default_with_oplog(graph, op_log=None, run_meta=None):
        all_ops[op_name].types.extend(entry.types)
        if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
          all_ops[op_name].float_ops = entry.float_ops
        if entry.code_def.traces and not all_ops[op_name].code_def.traces:
          all_ops[op_name].code_def.MergeFrom(entry.code_def)
      else:
        all_ops[op_name] = entry
    tmp_op_log.log_entries.extend(all_ops.values())
  return tmp_op_log


def write_op_log(graph, log_dir, op_log=None, run_meta=None):
def write_op_log(graph, log_dir, op_log=None, run_meta=None,
                 add_trace=False):
  """Log provided 'op_log', and add additional model information below.

  The API also assigns ops in tf.trainable_variables() an op type called
@ -154,8 +172,9 @@ def write_op_log(graph, log_dir, op_log=None, run_meta=None):
      one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
        run time shape information.
    add_trace: Whether to add op trace information. Used to support "code" view.
  """
  op_log = _merge_default_with_oplog(graph, op_log, run_meta)
  op_log = _merge_default_with_oplog(graph, op_log, run_meta, add_trace)

  with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
    log.write(op_log.SerializeToString())
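The new `add_trace` flag also matters for the CLI workflow: the op log written by `write_op_log` is what the `code` view reads its traces from. A hedged sketch of producing it with traces included (the log directory is illustrative, and `tfprof_logger` refers to the module in this diff):

```python
# Write tfprof_log with Python code traces so the CLI 'code' command
# (--op_log_path) can group stats by call stack.
tfprof_logger.write_op_log(
    tf.get_default_graph(),
    log_dir='/tmp/tfprof_logdir',  # illustrative path
    run_meta=run_metadata,         # optional; helps flops computation
    add_trace=True)
```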
tensorflow/tools/tfprof/README.md
@ -10,12 +10,17 @@ Consultants: Jon Shlens, Pete Warden
1. Measure model parameters, float operations, tensor shapes.
2. Measure op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values.
4. Explore model based on name scope or graph structure.
4. 3 ways to view and explore TensorFlow model profiles

    * Organize by Python code call stack.
    * Organize by TensorFlow operation name scope hierarchies.
    * Organize by TensorFlow operation inputs/outputs graph.

5. Selectively grouping/filtering/accounting/ordering ops.

[Python API Tutorials](#python-api-tutorials): It can be called directly from
Python codes. Results are either printed
to stdout or dumped to file. tensorflow.tfprof.TFProfNode proto is returned from
to stdout or dumped to file. tensorflow.tfprof.TFGraphNodeProto proto is returned from
the API to allow users to perform further analysis.

[CLI Tutorials](#cli-tutorials):
@ -33,13 +38,23 @@ tfprof is part of TensorFlow core. Simply ```import tensorflow as tf```.
### Examine the shapes and sizes of all trainable Variables.
```python
# Print trainable variable parameter statistics to stdout.
# By default, statistics are associated with each graph node.
param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
    tf.get_default_graph(),
    tfprof_options=tf.contrib.tfprof.model_analyzer.
        TRAINABLE_VARS_PARAMS_STAT_OPTIONS)

# param_stats is tensorflow.tfprof.TFProfNode proto. It organize the statistics
# of each graph node in tree scructure. Let's print the root below.

# Set tfprof_cmd='code' to associate statistics with Python codes.
opts = tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
opts['show_name_regexes'] = ['.*my_code1.py.*', '.*my_code2.py.*']
param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
    tf.get_default_graph(),
    tfprof_cmd='code'
    tfprof_options=opts)

# param_stats is tensorflow.tfprof.TFGraphNodeProto proto.
# Let's print the root below.
sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
```

@ -84,8 +99,20 @@ Finally, you may run `print_model_analysis` to explore the timing and memory
demands of the model.

``` python
# See model_analyzer_test.py for more examples.
#
# Print to stdout an analysis of the memory usage and the timing information
# from running the graph broken down by operations.
# broken down by python codes.
opts = tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
opts['show_name_regexes'] = ['.*my_code.py.*']
tf.contrib.tfprof.model_analyzer.print_model_analysis(
    tf.get_default_graph(),
    run_meta=run_metadata,
    tfprof_cmd='code',
    tfprof_options=opts)

# Print to stdout an analysis of the memory usage and the timing information
# broken down by operations.
tf.contrib.tfprof.model_analyzer.print_model_analysis(
    tf.get_default_graph(),
    run_meta=run_metadata,
@ -138,9 +165,9 @@ bazel-bin/tensorflow/tools/tfprof/tfprof \
    --run_meta_path=run_meta \
    --checkpoint_path=model.ckpt
#
# tfprof_log is used to define customized op types and float ops.
# tfprof_log is used to define customized op types, float ops and code traces.
# Use tfprof_logger.write_op_log() to create tfprof_log.
# See 11) in Examples section on generating tfprof_log file.
# See 12) in Examples section on generating tfprof_log file.
bazel-bin/tensorflow/tools/tfprof/tfprof \
    --graph_path=graph.pbtxt \
    --run_meta_path=run_meta \
@ -174,7 +201,28 @@ tfprof>
-dump_to_file
```

3) I want to see the `BatchNorm`'s gamma value in checkpoint.
3) I want to see which line of my python codes costs most time!

```shell
# Requires --graph_path --op_log_path
tfprof> code -max_depth 1000 -show_name_regexes .*model_analyzer.*py.* -select micros -account_type_regexes .* -order_by micros
_TFProfRoot (0us/22.44ms)
  model_analyzer_test.py:149:run_filename_as_m...:none (0us/22.44ms)
    model_analyzer_test.py:33:_run_code_in_main:none (0us/22.44ms)
      model_analyzer_test.py:208:<module>:test.main() (0us/22.44ms)
        model_analyzer_test.py:132:testComplexCodeView:x = lib.BuildFull... (0us/22.44ms)
          model_analyzer_testlib.py:63:BuildFullModel:return sgd_op.min... (0us/21.83ms)
          model_analyzer_testlib.py:58:BuildFullModel:cell, array_ops.c... (0us/333us)
          model_analyzer_testlib.py:54:BuildFullModel:seq.append(array_... (0us/254us)
            model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0us/134us)
            model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0us/40us)
            ...
          model_analyzer_testlib.py:61:BuildFullModel:loss = nn_ops.l2_... (0us/28us)
          model_analyzer_testlib.py:60:BuildFullModel:target = array_op... (0us/0us)
        model_analyzer_test.py:134:testComplexCodeView:sess.run(variable... (0us/0us)
```
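For context, the three inputs that command expects (`--graph_path`, `--run_meta_path`, `--op_log_path`) can be produced from Python roughly as sketched below; this assumes the `sess` and `run_metadata` from the earlier examples, and the paths are illustrative:

```python
# Serialize the graph, the run metadata, and an op log with code traces
# into the files the tfprof CLI reads.
tf.train.write_graph(sess.graph_def, '/tmp/tfprof', 'graph.pbtxt')
with tf.gfile.GFile('/tmp/tfprof/run_meta', 'w') as f:
  f.write(run_metadata.SerializeToString())
tfprof_logger.write_op_log(sess.graph, '/tmp/tfprof',
                           run_meta=run_metadata, add_trace=True)
```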

4) I want to see the `BatchNorm`'s gamma value in checkpoint.

```shell
# Requires --graph_path, --checkpoint_path.
@ -186,7 +234,7 @@ _TFProfRoot ()
[1.57 1.83 1.30 1.25 1.59 1.14 1.26 0.82 1.19 1.10 1.48 1.01 0.82 1.23 1.21 1.14 ],
```

4) I want to see my checkpoint tensors shape and number of parameters.
5) I want to see my checkpoint tensors shape and number of parameters.

```shell
# Requires --graph_path, --checkpoint_path.
@ -205,7 +253,7 @@ _TFProfRoot (--/930.58k params)
  unit_last/final_bn/moving_variance (64, 64/64 params)
```

5) I defined an op named ‘cost’ to calculate the loss. I want to know what ops
6) I defined an op named ‘cost’ to calculate the loss. I want to know what ops
it depends on take a long time to run. Hint: Use the ‘graph’ command to explore
graph dependencies.

@ -221,7 +269,7 @@ _TFProfRoot (0us/3.61sec)
  unit_3_3/sub2/conv2/Conv2D (10.26ms/3.60sec)
```

6) I want to know the expensive operations during the back propagation.
7) I want to know the expensive operations during the back propagation.
Hint: tensorflow prepend ‘gradient’ to your defined name scopes. Use the ‘scope’
command to explore based on name scope hierarchies.

@ -238,7 +286,7 @@ _TFProfRoot (0us/2.29sec)
...
```

7) Show the number of float operations in the model.
8) Show the number of float operations in the model.
Note: float operations calculation depends on
1) op.RegisterStatistics. If an op doesn’t
have RegisterStatistics defined, its float operations cannot be counted.
@ -263,7 +311,7 @@ _TFProfRoot (0/17.63b flops)
...
```

8) Show the number of parameters of all `tf.trainable_variables()` in the model.
9) Show the number of parameters of all `tf.trainable_variables()` in the model.

```shell
# Requires --graph_path --op_log_path.
@ -283,7 +331,7 @@ generated by write_op_log() Python API. write_op_log() help users create some
common op types implicitly. Users can define their own op types and log it
through the write_op_log() API.

9) What if I’m lazy and don’t want to define op type? I have given my ops
10) What if I’m lazy and don’t want to define op type? I have given my ops
well-defined names in my model’s code. And want to use names to select a group
of ops. Let’s try it!

@ -301,7 +349,7 @@ in terminal. Otherwise, tfprof accounts all ops matched by
`-account_type_regexes` recursively even if they are hidden due to some
options such as -max_depth.

10) TensorFlow has built-in op types. For example, built-in op type `Variable`
11) TensorFlow has built-in op types. For example, built-in op type `Variable`
seems to include `Variable's` created by your model. However, be careful when
depending on it because TensorFlow creates extra `Variable` ops implicitly and
the implicitly created ops can have the same prefix as the `Variable's` you
@ -327,7 +375,7 @@ _TFProfRoot (--/930.58k params)
```

11) A example of defining extra op type for ops using `OpLog`
12) A example of defining extra op type for ops using `OpLog`

First, in Python code, create an `OpLog` proto and add op type
information to it:
tensorflow/tools/tfprof/internal/BUILD
@ -15,6 +15,7 @@ cc_library(
    srcs = ["tfprof_stats.cc"],
    hdrs = ["tfprof_stats.h"],
    deps = [
        ":tfprof_code",
        ":tfprof_graph",
        ":tfprof_node",
        ":tfprof_options",
@ -61,6 +62,27 @@ cc_library(
    ],
)

cc_library(
    name = "tfprof_code",
    srcs = ["tfprof_code.cc"],
    hdrs = ["tfprof_code.h"],
    deps = [
        ":tfprof_constants",
        ":tfprof_node",
        ":tfprof_options",
        ":tfprof_show_code",
        ":tfprof_tensor",
        ":tfprof_utils",
        "//tensorflow/c:c_api",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:regexp_internal",
        "//tensorflow/tools/tfprof:protos_all_cc",
    ],
)

cc_library(
    name = "tfprof_graph",
    srcs = ["tfprof_graph.cc"],
@ -98,6 +120,26 @@ cc_library(
    ],
)

cc_library(
    name = "tfprof_show_code",
    srcs = ["tfprof_show_code.cc"],
    hdrs = ["tfprof_show_code.h"],
    deps = [
        ":tfprof_constants",
        ":tfprof_node",
        ":tfprof_options",
        ":tfprof_scope",
        ":tfprof_show",
        ":tfprof_tensor",
        ":tfprof_utils",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:regexp_internal",
        "//tensorflow/tools/tfprof:protos_all_cc",
    ],
)

tf_cc_test(
    name = "tfprof_show_test",
    srcs = ["tfprof_show_test.cc"],
tensorflow/tools/tfprof/internal/print_model_analysis.cc
@ -40,13 +40,13 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
  graph_ptr->ParseFromString(*graph);

  std::unique_ptr<RunMetadata> run_meta_ptr;
  if (run_meta) {
  if (run_meta && !run_meta->empty()) {
    run_meta_ptr.reset(new RunMetadata());
    run_meta_ptr->ParseFromString(*run_meta);
  }

  std::unique_ptr<OpLog> op_log_ptr;
  if (op_log) {
  if (op_log && !op_log->empty()) {
    op_log_ptr.reset(new OpLog());
    op_log_ptr->ParseFromString(*op_log);
  }
@ -58,16 +58,27 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,

  Options opts = Options::FromProtoStr(*options);

  // TODO(xpan): We should have dump_to_file/print_stdout/etc to control
  // side-effects independently instead of one controlling the other.
  if (opts.dump_to_file.empty()) {
    printf("\n=========================Options=============================\n");
    printf("%s", opts.ToString().c_str());
    printf("\n==================Model Analysis Report======================\n");
    TFProfNode root(tf_stats.PrintGraph(*command, opts));
    string ret = "";
    if (*command == kCmds[2]) {
      ret = tf_stats.PrintCode(opts).SerializeAsString();
    } else {
      ret = tf_stats.PrintGraph(*command, opts).SerializeAsString();
    }
    printf("\n======================End of Report==========================\n");
    fflush(stdout);
    return root.SerializeAsString();
    return ret;
  }
  if (*command == kCmds[2]) {
    return tf_stats.PrintCode(opts).SerializeAsString();
  } else {
    return tf_stats.PrintGraph(*command, opts).SerializeAsString();
  }
  return tf_stats.PrintGraph(*command, opts).SerializeAsString();
}
}  // namespace tfprof
}  // namespace tensorflow

tensorflow/tools/tfprof/internal/tfprof_code.cc (new file, 215 lines)
@ -0,0 +1,215 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/tools/tfprof/internal/tfprof_code.h"

#include <stdio.h>
#include <utility>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/tools/tfprof/internal/tfprof_tensor.h"

namespace tensorflow {
namespace tfprof {
namespace {
// Convert to Trace proto into a short readable string.
string GetTraceString(const CodeDef::Trace& trace) {
  string ntrace = "";
  if (trace.file().find_last_of('/') != trace.file().npos) {
    ntrace += trace.file().substr(trace.file().find_last_of('/') + 1);
  } else {
    ntrace += trace.file();
  }
  ntrace += strings::StrCat(":", trace.lineno());
  if (trace.function().length() < 20) {
    ntrace += ":" + trace.function();
  } else {
    ntrace += ":" + trace.function().substr(0, 17) + "...";
  }
  if (trace.line().length() < 20) {
    ntrace += ":" + trace.line();
  } else {
    ntrace += ":" + trace.line().substr(0, 17) + "...";
  }
  return ntrace;
}
}  // namespace

void TFCode::AddNode(TFGraphNode* node) {
  if (!node->code()) {
    return;
  }
  TFCodeNode* pre_trace_node = nullptr;
  for (int i = 0; i < node->code()->traces_size(); ++i) {
    // Unlike op name, which is globally unique, trace name is only unique
    // w.r.t. it's parent.
    const string& trace = GetTraceString(node->code()->traces(i));
    if (i == 0) {
      if (!trace_root_) {
        trace_root_.reset(new TFCodeNode(trace));
      }
      CHECK(trace_root_->name() == trace) << "Different trace root";
      pre_trace_node = trace_root_.get();
      continue;
    }
    pre_trace_node->AddChildren(trace);
    TFCodeNode* trace_node = pre_trace_node->children()[trace].get();

    if (i == node->code()->traces_size() - 1) {
      trace_node->AddGraphNode(node);
    }
    pre_trace_node = trace_node;
  }
}

void TFCode::Build() {
  if (!trace_root_) {
    return;
  }
  code_root_ = BuildCodeNodes(trace_root_.get());
}

CodeNode* TFCode::BuildCodeNodes(TFCodeNode* root) {
  auto code_root = std::unique_ptr<CodeNode>(new CodeNode(root));
  CodeNode* code_root_ptr = code_root.get();
  code_nodes_.insert(std::move(code_root));

  for (auto it = root->children().cbegin(); it != root->children().cend();
       ++it) {
    code_root_ptr->children.push_back(BuildCodeNodes(it->second.get()));
  }
  return code_root_ptr;
}

const ShowCodeNode* TFCode::ShowInternal(const Options& opts) {
  // Search from roots recursively to find start node, if start_name_regexes
  // is specified.
  tfprof_trace_root_.reset(new TFCodeNode(kTFProfRoot));
  tfprof_code_root_.reset(new CodeNode(tfprof_trace_root_.get()));
  if (!code_root_) {
    return tfprof_code_root_.get();
  }

  std::vector<CodeNode*> roots = {code_root_};
  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  tfprof_code_root_->children.assign(roots.begin(), roots.end());
  Account({tfprof_code_root_.get()}, opts);

  return PrintScope({tfprof_code_root_.get()}, opts, 1, 0)[0];
}

std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
                                          const std::vector<string>& regexes) {
  std::vector<CodeNode*> res;
  if (roots.empty()) {
    return res;
  }
  for (CodeNode* root : roots) {
    bool match_start_node = false;
    for (const string& regex : regexes) {
      if (RE2::FullMatch(root->name(), regex)) {
        res.push_back(root);
        match_start_node = true;
        break;
      }
    }
    if (match_start_node) {
      // Found a start node at this branch, no need to continue.
      continue;
    }
    std::vector<CodeNode*> nroots = SearchRoot(root->children, regexes);
    res.insert(res.end(), nroots.begin(), nroots.end());
  }
  return res;
}

std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
                                          const Options& opts, int depth,
                                          int last_ident) {
  std::vector<CodeNode*> show_nodes;

  for (CodeNode* node : roots) {
    int nlast_ident = last_ident;
    bool show = ShouldShow(node, opts, depth);
    if (show) {
      node->formatted_str.clear();
      if (opts.account_displayed_op_only) {
        node->ResetTotalStats();
        node->AddSelfToTotalStats();
      }
      nlast_ident += 2;
    }

    std::vector<CodeNode*> show_cnodes;
    if (!ShouldTrim(node, opts.trim_name_regexes)) {
      show_cnodes = PrintScope(node->children, opts, depth + 1, nlast_ident);
    }
    if (show) {
      show_cnodes = SortNodes(show_cnodes, opts);
      string children_str;
      for (CodeNode* sc : show_cnodes) {
        children_str += sc->formatted_str;
        node->mutable_proto()->add_children()->MergeFrom(sc->proto());
        if (opts.account_displayed_op_only) {
          node->AggregateTotalStats(sc);
        }
      }

      node->formatted_str =
          strings::Printf("%s%s\n", string(last_ident, ' ').c_str(),
                          node->Format(opts).c_str());

      if (opts.select.find(kShown[5]) != opts.select.end()) {
        fprintf(stderr, "code view has no tensor value to show\n");
      }

      node->formatted_str += children_str;
      show_nodes.push_back(node);
    } else {
      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
                        show_cnodes.end());
    }
  }
  return show_nodes;
}

void TFCode::Account(const std::vector<CodeNode*>& roots, const Options& opts) {
  if (roots.empty()) return;

  for (CodeNode* node : roots) {
    node->ResetTotalStats();
    Account(node->children, opts);

    node->account = ShouldAccount(node, opts);
    if (node->account) {
      node->AddSelfToTotalStats();
    }
    for (CodeNode* c : node->children) {
      node->AggregateTotalStats(c);
    }
  }
}
}  // namespace tfprof
}  // namespace tensorflow
tensorflow/tools/tfprof/internal/tfprof_code.h (new file, 88 lines)
@ -0,0 +1,88 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Build a tree structure based on the TensorFlow model's python code stacks.
// Stats are aggregated from descendants from ancestors.

#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_CODE_H_
#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_CODE_H_

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_show_code.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"

namespace tensorflow {
namespace tfprof {

class CodeNode : public ShowCodeNode {
 public:
  explicit CodeNode(const TFCodeNode* node) : ShowCodeNode(node) {}
  ~CodeNode() override {}

  void AggregateTotalStats(CodeNode* node) {
    ShowCodeNode::AggregateTotalStats(node);
  }

  void AddSelfToTotalStats() { ShowCodeNode::AddSelfToTotalStats(); }

  void ResetTotalStats() { ShowCodeNode::ResetTotalStats(); }

  std::vector<CodeNode*> children;
};

class TFCode : public TFShowCode {
 public:
  explicit TFCode() : code_root_(nullptr), trace_root_(nullptr) {}
  ~TFCode() override {}

  void AddNode(TFGraphNode* node) override;

  void Build() override;

 private:
  CodeNode* BuildCodeNodes(TFCodeNode* root);

  const ShowCodeNode* ShowInternal(const Options& opts) override;

  std::vector<CodeNode*> SearchRoot(std::vector<CodeNode*> roots,
                                    const std::vector<string>& regexes);

  std::vector<CodeNode*> PrintScope(const std::vector<CodeNode*> roots,
                                    const Options& opts, int depth,
                                    int last_ident);

  void Account(const std::vector<CodeNode*>& roots, const Options& opts);

  CodeNode* code_root_;
  std::unique_ptr<TFCodeNode> trace_root_;
  std::unique_ptr<TFCodeNode> tfprof_trace_root_;
  std::unique_ptr<CodeNode> tfprof_code_root_;
  std::set<std::unique_ptr<CodeNode>> code_nodes_;
};
}  // namespace tfprof
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_CODE_H_
tensorflow/tools/tfprof/internal/tfprof_graph.cc
@ -31,14 +31,14 @@ GraphNode* TFGraph::CreateParentNode(const string& name) {
  node_defs_.back()->set_name(name);
  node_defs_.back()->set_op(kTFGraphParent);
  parent_nodes_[name] =
      std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get()));
      std::unique_ptr<TFGraphNode>(new TFGraphNode(node_defs_.back().get()));
  nodes_map_[name] =
      std::unique_ptr<GraphNode>(new GraphNode(parent_nodes_[name].get()));
  return nodes_map_[name].get();
}

void TFGraph::AddNode(TFNode* node) {
  string name = node->node_def()->name();
void TFGraph::AddNode(TFGraphNode* node) {
  string name = node->name();
  nodes_map_[name] = std::unique_ptr<GraphNode>(new GraphNode(node));
}

@ -49,7 +49,7 @@ void TFGraph::Build() {
  // Filter out the root nodes (node not input of any other node).
  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
    GraphNode* node = it->second.get();
    const std::map<string, TFNode*>& inputs = node->node->inputs();
    const std::map<string, TFGraphNode*>& inputs = node->node->inputs();
    for (auto inputs_it = inputs.cbegin(); inputs_it != inputs.cend();
         inputs_it++) {
      nonroots.insert(inputs_it->first);
tensorflow/tools/tfprof/internal/tfprof_graph.h
@ -39,7 +39,7 @@ namespace tensorflow {
namespace tfprof {
class GraphNode : public ShowNode {
 public:
  explicit GraphNode(TFNode* node) : ShowNode(node) {
  explicit GraphNode(TFGraphNode* node) : ShowNode(node) {
    mutable_proto()->set_inputs(node->inputs().size());
    mutable_proto()->set_total_inputs(0);
  }
@ -72,7 +72,7 @@ class TFGraph : public TFShow {
      : TFShow(ckpt_reader) {}
  ~TFGraph() override {}

  void AddNode(TFNode* node) override;
  void AddNode(TFGraphNode* node) override;

  void Build() override;

@ -99,14 +99,14 @@ class TFGraph : public TFShow {
  std::vector<GraphNode*> GenerateGraphDot(
      GraphNode* root, GraphNode* last_shown, const Options& opts, int depth,
      int hidden, std::set<string>* declared_nodes,
      std::set<string>* declared_edges, TFProfNode* parent);
      std::set<string>* declared_edges, TFGraphNodeProto* parent);

  void Account(const std::vector<GraphNode*>& roots, const Options& opts,
               std::map<string, int64>* visits);

  std::vector<GraphNode*> roots_;
  std::vector<std::unique_ptr<NodeDef>> node_defs_;
  std::map<string, std::unique_ptr<TFNode>> parent_nodes_;
  std::map<string, std::unique_ptr<TFGraphNode>> parent_nodes_;
  std::map<string, std::unique_ptr<GraphNode>> nodes_map_;
};
tensorflow/tools/tfprof/internal/tfprof_node.cc
@ -20,7 +20,8 @@ limitations under the License.

namespace tensorflow {
namespace tfprof {
void TFNode::AddStepStat(const string& device, const NodeExecStats* step_stat) {
void TFGraphNode::AddStepStat(const string& device,
                              const NodeExecStats* step_stat) {
  if (!device.empty()) {
    // This might override device from GraphDef.
    device_ = device;
@ -44,7 +45,7 @@ void TFNode::AddStepStat(const string& device, const NodeExecStats* step_stat) {
  }
}

void TFNode::AddNodeStat(const CostGraphDef::Node* cost_node) {
void TFGraphNode::AddNodeStat(const CostGraphDef::Node* cost_node) {
  kernel_compute_micros_ = cost_node->compute_cost();
}
}  // namespace tfprof
@ -30,14 +30,16 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
|
||||
#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfprof {
|
||||
|
||||
class TFNode {
|
||||
class TFGraphNode {
|
||||
public:
|
||||
TFNode(const NodeDef* node)
|
||||
TFGraphNode(const NodeDef* node)
|
||||
: node_(node),
|
||||
code_(nullptr),
|
||||
step_stat_(nullptr),
|
||||
op_start_micros_(0),
|
||||
op_schedule_micros_(0),
|
||||
@ -70,9 +72,9 @@ class TFNode {
|
||||
device_ = node->device();
|
||||
}
|
||||
|
||||
TFNode() : TFNode(nullptr) {}
|
||||
TFGraphNode() : TFGraphNode(nullptr) {}
|
||||
|
||||
void AddInput(TFNode* input) { inputs_[input->node_def()->name()] = input; }
|
||||
void AddInput(TFGraphNode* input) { inputs_[input->name()] = input; }
|
||||
|
||||
void AddOpType(const string& op_type) { op_types_.insert(op_type); }
|
||||
|
||||
@ -83,27 +85,32 @@ class TFNode {
|
||||
|
||||
void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; }
|
||||
|
||||
void AddCode(const CodeDef* code) { code_ = code; }

  const string& name() const { return node_->name(); }
  const NodeDef* node_def() { return node_; }
  const std::map<string, TFNode*>& inputs() { return inputs_; }
  const std::map<string, TFGraphNode*>& inputs() const { return inputs_; }
  int64 op_start_micros() { return op_start_micros_; }
  // This is time spent in Op::Compute(), which is GPU kernel schedule time.
  // Currently not used.
  int64 op_schedule_micros() { return op_schedule_micros_; }
  // This is time spent in kernel execution.
  int64 kernel_compute_micros() { return kernel_compute_micros_; }
  int64 kernel_compute_micros() const { return kernel_compute_micros_; }
  int64 all_spent_micros() { return all_spent_micros_; }
  int64 requested_byptes() { return requested_bytes_; }
  int64 float_ops() { return float_ops_; }
  string device() { return device_; }
  const std::set<string>& op_types() { return op_types_; }
  int64 requested_bytes() const { return requested_bytes_; }
  int64 float_ops() const { return float_ops_; }
  const CodeDef* code() { return code_; }
  string device() const { return device_; }
  const std::set<string>& op_types() const { return op_types_; }

  const std::vector<int64>& shape() { return shape_; }
  const std::vector<int64>& shape() const { return shape_; }

 private:
  void update_shape(const std::vector<int64>& shape) { shape_ = shape; }

  std::map<string, TFNode*> inputs_;
  std::map<string, TFGraphNode*> inputs_;
  const NodeDef* node_;
  const CodeDef* code_;
  const NodeExecStats* step_stat_;

  std::vector<int64> shape_;
@ -117,6 +124,71 @@ class TFNode {
  int64 float_ops_;
};

class TFCodeNode {
 public:
  TFCodeNode(const string& trace)
      : trace_(trace),
        kernel_compute_micros_(0),
        requested_bytes_(0),
        float_ops_(0) {}

  void AddGraphNode(const TFGraphNode* node) {
    if (nodes_.find(node->name()) != nodes_.end()) {
      return;
    }
    nodes_[node->name()] = node;

    kernel_compute_micros_ += node->kernel_compute_micros();
    requested_bytes_ += node->requested_bytes();
    float_ops_ += node->float_ops();
    op_types_.insert(node->op_types().begin(), node->op_types().end());
    if (node->shape().size() > 0) {
      shapes_.push_back(node->shape());
    }
    if (!node->device().empty()) {
      devices_.insert(node->device());
    }
  }
  const std::map<string, const TFGraphNode*>& graph_nodes() const {
    return nodes_;
  }

  void AddChildren(const string& trace) {
    if (children_.find(trace) != children_.end()) {
      return;
    }
    children_[trace].reset(new TFCodeNode(trace));
  }
  std::map<string, std::unique_ptr<TFCodeNode>>& children() {
    return children_;
  }

  const string& name() const { return trace_; }

  int64 kernel_compute_micros() const { return kernel_compute_micros_; }

  int64 requested_bytes() const { return requested_bytes_; }

  int64 float_ops() const { return float_ops_; }

  const std::set<string>& devices() const { return devices_; }

  const std::set<string>& op_types() const { return op_types_; }

  const std::vector<std::vector<int64>>& shapes() const { return shapes_; }

 private:
  const string trace_;
  std::set<string> op_types_;
  int64 kernel_compute_micros_;
  int64 requested_bytes_;
  int64 float_ops_;

  std::set<string> devices_;
  std::vector<std::vector<int64>> shapes_;
  std::map<string, const TFGraphNode*> nodes_;
  std::map<string, std::unique_ptr<TFCodeNode>> children_;
};
} // namespace tfprof
} // namespace tensorflow
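The code view introduced here is essentially a trie keyed by Python trace strings: each TFCodeNode owns its children by trace name and accumulates the stats of every graph node attributed to that trace. A minimal sketch of how a caller could assemble such a tree (the trace string, the node pointers, and the surrounding setup are hypothetical, not part of this commit):

// Sketch only (not from the commit): `conv_node` and `relu_node` stand for
// existing TFGraphNode instances that already carry their op stats.
tensorflow::tfprof::TFCodeNode root("_TFProfRoot");
root.AddChildren("my_model.py:42:build_model");  // hypothetical trace string
tensorflow::tfprof::TFCodeNode* line =
    root.children()["my_model.py:42:build_model"].get();
line->AddGraphNode(conv_node);  // stats of the conv op roll into this trace
line->AddGraphNode(relu_node);  // adding the same node name twice is a no-op
// line->kernel_compute_micros(), line->requested_bytes() and
// line->float_ops() now hold the aggregate over all ops created by that line.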
@ -55,7 +55,7 @@ static const char* const kShown[] = {
};

static const char* const kCmds[] = {
    "scope", "graph", "set", "help",
    "scope", "graph", "code", "set", "help",
};

struct Options {
@ -35,15 +35,15 @@ ScopeNode* TFScope::CreateParentNode(const string& name) {
  node_defs_.back()->set_name(name);
  node_defs_.back()->set_op(kTFScopeParent);
  parent_nodes_[name] =
      std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get()));
      std::unique_ptr<TFGraphNode>(new TFGraphNode(node_defs_.back().get()));
  nodes_map_[name] =
      std::unique_ptr<ScopeNode>(new ScopeNode(parent_nodes_[name].get()));
  return nodes_map_[name].get();
}

void TFScope::AddNode(TFNode* node) {
  string name = node->node_def()->name();
  if (nodes_map_.find(node->node_def()->name()) == nodes_map_.end()) {
void TFScope::AddNode(TFGraphNode* node) {
  string name = node->name();
  if (nodes_map_.find(node->name()) == nodes_map_.end()) {
    nodes_map_[name] = std::unique_ptr<ScopeNode>(new ScopeNode(node));
  }

@ -39,7 +39,7 @@ namespace tfprof {

class ScopeNode : public ShowNode {
 public:
  explicit ScopeNode(TFNode* node) : ShowNode(node) {}
  explicit ScopeNode(const TFGraphNode* node) : ShowNode(node) {}
  ~ScopeNode() override {}

  void AggregateTotalStats(ScopeNode* node) {
@ -59,7 +59,7 @@ class TFScope : public TFShow {
      : TFShow(ckpt_reader) {}
  ~TFScope() override {}

  void AddNode(TFNode* node) override;
  void AddNode(TFGraphNode* node) override;

  void Build() override;

@ -79,7 +79,7 @@ class TFScope : public TFShow {

  std::vector<ScopeNode*> roots_;
  std::vector<std::unique_ptr<NodeDef>> node_defs_;
  std::map<string, std::unique_ptr<TFNode>> parent_nodes_;
  std::map<string, std::unique_ptr<TFGraphNode>> parent_nodes_;
  std::map<string, std::unique_ptr<ScopeNode>> nodes_map_;
};
} // namespace tfprof
@ -25,13 +25,13 @@ limitations under the License.

namespace tensorflow {
namespace tfprof {
ShowNode::ShowNode(TFNode* node) : node(node), account(true) {
ShowNode::ShowNode(const TFGraphNode* node) : node(node), account(true) {
  mutable_proto()->set_name(name());
  if (!node->device().empty()) {
    mutable_proto()->set_device(node->device());
  }
  mutable_proto()->set_exec_micros(node->kernel_compute_micros());
  mutable_proto()->set_requested_bytes(node->requested_byptes());
  mutable_proto()->set_requested_bytes(node->requested_bytes());
  mutable_proto()->set_float_ops(node->float_ops());

  if (!node->shape().empty()) {
@ -119,12 +119,12 @@ string ShowNode::FormatMeta(const Options& opts) {
  return str_util::Join(info, ", ");
}

TFProfNode* ShowNode::mutable_proto() { return &proto_; }
TFGraphNodeProto* ShowNode::mutable_proto() { return &proto_; }

const TFProfNode& ShowNode::proto() const { return proto_; }
const TFGraphNodeProto& ShowNode::proto() const { return proto_; }

void ShowNode::AggregateTotalStats(ShowNode* node) {
  TFProfNode* node_pb = node->mutable_proto();
  TFGraphNodeProto* node_pb = node->mutable_proto();
  mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
                                         node_pb->total_exec_micros());
  mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
@ -151,9 +151,10 @@ void ShowNode::ResetTotalStats() {
  mutable_proto()->set_total_requested_bytes(0);
  mutable_proto()->set_total_parameters(0);
  mutable_proto()->set_total_float_ops(0);
  mutable_proto()->mutable_children()->Clear();
}

const TFProfNode& TFShow::Show(const Options& opts) {
const TFGraphNodeProto& TFShow::Show(const Options& opts) {
  const ShowNode* root = ShowInternal(opts);
  if (opts.dump_to_file.empty()) {
    printf("%s", root->formatted_str.c_str());
@ -37,18 +37,18 @@ namespace tensorflow {
namespace tfprof {
class ShowNode {
 public:
  explicit ShowNode(TFNode* node);
  explicit ShowNode(const TFGraphNode* node);
  virtual ~ShowNode() {}

  const string& name() const { return node->node_def()->name(); }
  TFProfNode* mutable_proto();
  const TFProfNode& proto() const;
  const string& name() const { return node->name(); }
  TFGraphNodeProto* mutable_proto();
  const TFGraphNodeProto& proto() const;

  string Format(const Options& opts);

  string FormatMeta(const Options& opts);

  TFNode* node;
  const TFGraphNode* node;
  bool account;
  string formatted_str;

@ -59,7 +59,7 @@ class ShowNode {

  void ResetTotalStats();

  TFProfNode proto_;
  TFGraphNodeProto proto_;
};

class TFShow {
@ -67,9 +67,9 @@ class TFShow {
  explicit TFShow(checkpoint::CheckpointReader* ckpt_reader)
      : ckpt_reader_(ckpt_reader) {}
  virtual ~TFShow() {}
  virtual void AddNode(TFNode* node) = 0;
  virtual void AddNode(TFGraphNode* node) = 0;
  virtual void Build() = 0;
  const TFProfNode& Show(const Options& opts);
  const TFGraphNodeProto& Show(const Options& opts);

 protected:
  virtual const ShowNode* ShowInternal(const Options& opts) = 0;
tensorflow/tools/tfprof/internal/tfprof_show_code.cc (new file, 273 lines)
@ -0,0 +1,273 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/tools/tfprof/internal/tfprof_show_code.h"

#include <memory>
#include <set>

#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/tools/tfprof/internal/tfprof_scope.h"

namespace tensorflow {
namespace tfprof {
ShowCodeNode::ShowCodeNode(const TFCodeNode* node) : node(node), account(true) {
  std::vector<ScopeNode> snodes;
  for (auto it : node->graph_nodes()) {
    ScopeNode snode(it.second);
    snodes.push_back(snode);
    snodes[snodes.size() - 1].AddSelfToTotalStats();
    *mutable_proto()->mutable_graph_nodes()->Add() =
        snodes[snodes.size() - 1].proto();
  }

  mutable_proto()->set_name(name());
  mutable_proto()->set_exec_micros(node->kernel_compute_micros());
  mutable_proto()->set_requested_bytes(node->requested_bytes());
  mutable_proto()->set_float_ops(node->float_ops());

  if (!node->shapes().empty()) {
    for (const std::vector<int64>& shape : node->shapes()) {
      int64 params = 1;
      bool complete_shape = true;
      for (int64 d : shape) {
        // Sometimes parameters could be <0 when a dim is unknown.
        if (d < 0) {
          complete_shape = false;
          break;
        }
        params *= d;
      }
      if (complete_shape) {
        mutable_proto()->set_parameters(proto().parameters() + params);
      } else {
        fprintf(stderr, "Incomplete shape.");
      }
    }
  }
}

string ShowCodeNode::Format(const Options& opts) {
  if (opts.select.empty()) {
    return name();
  }
  return strings::Printf("%s (%s)", name().c_str(), FormatMeta(opts).c_str());
}

string ShowCodeNode::FormatMeta(const Options& opts) {
  std::vector<string> info;
  std::vector<string> shapes;
  if (opts.select.find(kShown[2]) != opts.select.end()) {
    for (const std::vector<int64>& shape : node->shapes()) {
      if (!shape.empty()) {
        shapes.push_back(FormatShapes(shape));
      }
    }
    if (!shapes.empty()) {
      info.push_back(str_util::Join(shapes, "|"));
    }
    string params = FormatNumber(proto().total_parameters()) + " params";
    if (account) {
      params = FormatNumber(proto().parameters()) + "/" + params;
    } else {
      params = "--/" + params;
    }
    info.push_back(params);
  }
  if (opts.select.find(kShown[3]) != opts.select.end()) {
    string fops = FormatNumber(proto().total_float_ops()) + " flops";
    if (account) {
      fops = FormatNumber(proto().float_ops()) + "/" + fops;
    } else {
      fops = "--/" + fops;
    }
    info.push_back(fops);
  }
  if (opts.select.find(kShown[0]) != opts.select.end()) {
    string memory = FormatMemory(proto().total_requested_bytes());
    if (account) {
      memory = FormatMemory(proto().requested_bytes()) + "/" + memory;

    } else {
      memory = "--/" + memory;
    }
    info.push_back(memory);
  }
  if (opts.select.find(kShown[1]) != opts.select.end()) {
    string time = FormatTime(proto().total_exec_micros());
    if (account) {
      time = FormatTime(proto().exec_micros()) + "/" + time;
    } else {
      time = "--/" + time;
    }
    info.push_back(time);
  }
  if (opts.select.find(kShown[6]) != opts.select.end()) {
    if (!node->devices().empty()) {
      info.push_back(str_util::Join(node->devices(), "|"));
    }
  }
  if (opts.select.find(kShown[7]) != opts.select.end()) {
    std::set<string> op_types = node->op_types();
    // Device is considered a type.
    op_types.insert(node->devices().cbegin(), node->devices().cend());
    info.push_back(str_util::Join(op_types, "|"));
  }
  return str_util::Join(info, ", ");
}
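For orientation, FormatMeta renders each selected metric as an accounted/total pair and joins the metrics with commas, so a code-view line ends up annotated roughly like the sketch below. The numbers are invented, and the exact strings depend on FormatNumber/FormatTime/FormatMemory from tfprof_utils:

// e.g. with a selection like {"params", "micros"} and account == true:
//   "1.37k/2.05k params, 180us/3.16ms"
// with account == false the accounted half collapses to "--":
//   "--/2.05k params, --/3.16ms"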

TFCodeNodeProto* ShowCodeNode::mutable_proto() { return &proto_; }

const TFCodeNodeProto& ShowCodeNode::proto() const { return proto_; }

void ShowCodeNode::AggregateTotalStats(ShowCodeNode* node) {
  TFCodeNodeProto* node_pb = node->mutable_proto();
  mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
                                         node_pb->total_exec_micros());
  mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
                                             node_pb->total_requested_bytes());
  mutable_proto()->set_total_parameters(proto().total_parameters() +
                                        node_pb->total_parameters());
  mutable_proto()->set_total_float_ops(proto().total_float_ops() +
                                       node_pb->total_float_ops());
}

void ShowCodeNode::AddSelfToTotalStats() {
  mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
                                         proto().exec_micros());
  mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
                                             proto().requested_bytes());
  mutable_proto()->set_total_parameters(proto().total_parameters() +
                                        proto().parameters());
  mutable_proto()->set_total_float_ops(proto().total_float_ops() +
                                       proto().float_ops());
}

void ShowCodeNode::ResetTotalStats() {
  mutable_proto()->set_total_exec_micros(0);
  mutable_proto()->set_total_requested_bytes(0);
  mutable_proto()->set_total_parameters(0);
  mutable_proto()->set_total_float_ops(0);
  mutable_proto()->mutable_children()->Clear();
}

const TFCodeNodeProto& TFShowCode::Show(const Options& opts) {
  const ShowCodeNode* root = ShowInternal(opts);
  if (opts.dump_to_file.empty()) {
    printf("%s", root->formatted_str.c_str());
    fflush(stdout);
  } else {
    Status s = WriteStringToFile(Env::Default(), opts.dump_to_file,
                                 root->formatted_str);
    if (!s.ok()) {
      fprintf(stderr, "%s\n", s.ToString().c_str());
    }
  }
  return root->proto();
}

bool TFShowCode::ShouldShow(ShowCodeNode* node, const Options& opts,
                            int depth) {
  // Always show kTFProfRoot.
  if (node->name() == kTFProfRoot) return true;

  if (!node->account) return false;
  // TODO(xpan): Think more carefully about node filtering in code view.
  // Unlike graph/scope view, where users want to see the exact leaf ops,
  // in code view users mostly want to see the intermediate code traces
  // they wrote themselves, rather than TensorFlow's internal code traces.
  if (node->proto().total_requested_bytes() < opts.min_bytes ||
      node->proto().total_exec_micros() < opts.min_micros ||
      node->proto().total_parameters() < opts.min_params ||
      node->proto().total_float_ops() < opts.min_float_ops ||
      depth > opts.max_depth || !ShouldShowIfExtra(node, opts, depth)) {
    return false;
  }

  bool show = false;
  if (opts.device_regexes.size() == 1 && opts.device_regexes[0] == ".*") {
    show = true;
  } else {
    for (const string& regex : opts.device_regexes) {
      for (const string& device : node->node->devices()) {
        if (RE2::FullMatch(device, regex)) {
          show = true;
          break;
        }
      }
      if (show) break;
    }
  }
  // Don't show if device_regexes don't cover it.
  if (!show) return false;

  show = false;
  if (opts.show_name_regexes.size() == 1 && opts.show_name_regexes[0] == ".*") {
    show = true;
  } else {
    for (const string& regex : opts.show_name_regexes) {
      if (RE2::FullMatch(node->name(), regex)) {
        show = true;
        break;
      }
    }
  }
  // Don't show if show_name_regexes don't cover it.
  if (!show) return false;
  // Don't show if hide_name_regexes cover it.
  for (const string& regex : opts.hide_name_regexes) {
    if (RE2::FullMatch(node->name(), regex)) return false;
  }
  return true;
}

bool TFShowCode::ShouldTrim(ShowCodeNode* node,
                            const std::vector<string>& regexes) {
  for (const string& regex : regexes) {
    if (RE2::FullMatch(node->name(), regex)) {
      return true;
    }
  }
  return false;
}

bool TFShowCode::ShouldAccount(ShowCodeNode* node, const Options& opts) {
  if (opts.account_type_regexes.size() == 1 &&
      opts.account_type_regexes[0] == ".*") {
    return true;
  }
  for (const string& regex : opts.account_type_regexes) {
    for (const string& type : node->node->op_types()) {
      if (RE2::FullMatch(type, regex)) {
        return true;
      }
    }
    for (const string& device : node->node->devices()) {
      if (RE2::FullMatch(device, regex)) {
        return true;
      }
    }
  }
  return false;
}

} // namespace tfprof
} // namespace tensorflow
tensorflow/tools/tfprof/internal/tfprof_show_code.h (new file, 126 lines)
@ -0,0 +1,126 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Parent class and utilities for tfprof_code.

#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_
#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_tensor.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"

namespace tensorflow {
namespace tfprof {
class ShowCodeNode {
 public:
  explicit ShowCodeNode(const TFCodeNode* node);
  virtual ~ShowCodeNode() {}

  const string& name() const { return node->name(); }
  TFCodeNodeProto* mutable_proto();
  const TFCodeNodeProto& proto() const;

  string Format(const Options& opts);

  string FormatMeta(const Options& opts);

  const TFCodeNode* node;
  bool account;
  string formatted_str;

 protected:
  void AggregateTotalStats(ShowCodeNode* node);

  void AddSelfToTotalStats();

  void ResetTotalStats();

  TFCodeNodeProto proto_;
};

class TFShowCode {
 public:
  explicit TFShowCode() {}
  virtual ~TFShowCode() {}
  virtual void AddNode(TFGraphNode* node) = 0;
  virtual void Build() = 0;
  const TFCodeNodeProto& Show(const Options& opts);

 protected:
  virtual const ShowCodeNode* ShowInternal(const Options& opts) = 0;

  bool LookUpCheckPoint(const string& name,
                        std::unique_ptr<TFProfTensor>* tensor);

  // Overridden by subclass if extra requirements need to be met.
  virtual bool ShouldShowIfExtra(ShowCodeNode* node, const Options& opts,
                                 int depth) {
    return true;
  }

  bool ShouldShow(ShowCodeNode* node, const Options& opts, int depth);

  bool ShouldTrim(ShowCodeNode* node, const std::vector<string>& regexes);

  bool ShouldAccount(ShowCodeNode* node, const Options& opts);

  template <typename T>
  std::vector<T*> SortNodes(const std::vector<T*>& nodes, const Options& opts) {
    if (opts.order_by.empty() || nodes.empty()) {
      return nodes;
    }
    std::vector<T*> sorted_nodes = nodes;
    std::sort(sorted_nodes.begin(), sorted_nodes.end(),
              [&opts](const T* n1, const T* n2) {
                if (n1->name() == kTFProfRoot) return true;
                if (n2->name() == kTFProfRoot) return false;
                bool name_cmp = n1->name() < n2->name();
                if (opts.order_by == kOrderBy[0]) {
                  return name_cmp;
                } else if (opts.order_by == kOrderBy[1]) {
                  return n1->proto().total_requested_bytes() >
                         n2->proto().total_requested_bytes();
                } else if (opts.order_by == kOrderBy[2]) {
                  return n1->proto().total_exec_micros() >
                         n2->proto().total_exec_micros();
                } else if (opts.order_by == kOrderBy[3]) {
                  return n1->proto().total_parameters() >
                         n2->proto().total_parameters();
                } else if (opts.order_by == kOrderBy[4]) {
                  return n1->proto().total_float_ops() >
                         n2->proto().total_float_ops();
                }
                return name_cmp;
              });
    return sorted_nodes;
  }
};

} // namespace tfprof
} // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_
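TFShowCode mirrors TFShow but is parameterized on code-trace nodes; the concrete view (TFCode, declared in tfprof_code.h) supplies AddNode, Build and ShowInternal. A bare-bones skeleton of that contract, for illustration only (the class name and the empty bodies below are hypothetical, not the real TFCode implementation):

// Hypothetical subclass skeleton; the shipped implementation is TFCode.
class ExampleCodeView : public tensorflow::tfprof::TFShowCode {
 public:
  void AddNode(tensorflow::tfprof::TFGraphNode* node) override {
    // Bucket the graph node under its Python trace (omitted here).
  }
  void Build() override {
    // Link TFCodeNode children into a trace tree (omitted here).
  }

 protected:
  const tensorflow::tfprof::ShowCodeNode* ShowInternal(
      const tensorflow::tfprof::Options& opts) override {
    return nullptr;  // The real view formats and returns the root ShowCodeNode.
  }
};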
@ -56,29 +56,38 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
  printf("Preparing Views...\n");
  scope_view_ = std::unique_ptr<TFScope>(new TFScope(ckpt_reader_.get()));
  graph_view_ = std::unique_ptr<TFGraph>(new TFGraph(ckpt_reader_.get()));
  code_view_ = std::unique_ptr<TFCode>(new TFCode());

  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
    scope_view_->AddNode(&it->second);
    graph_view_->AddNode(&it->second);
    code_view_->AddNode(&it->second);
  }
  scope_view_->Build();
  graph_view_->Build();
  code_view_->Build();
}

const TFProfNode& TFStats::PrintGraph(const string& cmd, const Options& opts) {
const TFGraphNodeProto& TFStats::PrintGraph(const string& cmd,
                                            const Options& opts) {
  if (cmd == kCmds[0]) {
    return scope_view_->Show(opts);
  } else if (cmd == kCmds[1]) {
    return graph_view_->Show(opts);
  } else {
    fprintf(stderr, "Unknown command: %s\n", cmd.c_str());
    return empty_node_;
    return empty_graph_node_;
  }
}

const TFCodeNodeProto& TFStats::PrintCode(const Options& opts) {
  return code_view_->Show(opts);
}

void TFStats::ParseGraph() {
  for (const NodeDef& node : graph_->node()) {
    CHECK(nodes_map_.find(node.name()) == nodes_map_.end());
    nodes_map_[node.name()] = TFNode(&node);
    nodes_map_[node.name()] = TFGraphNode(&node);
  }
  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
    const NodeDef* node_def = it->second.node_def();
@ -110,6 +119,9 @@ void TFStats::ParseOpLog() {
    if (entry.float_ops()) {
      node->second.AddFloatOps(entry.float_ops());
    }
    if (entry.has_code_def()) {
      node->second.AddCode(&entry.code_def());
    }
  }
}

@ -131,13 +143,14 @@ void TFStats::ParseRunMeta() {
        "Missing CostGraphDef in RunMetadata.\nMaybe you forget to"
        "set tf.ConfigProto(graph_options=tf.GraphOptions("
        "build_cost_model=1)) to Session()\n");
  }
  for (const auto& node_pb : run_meta_->cost_graph().node()) {
    auto node = nodes_map_.find(node_pb.name());
    if (node == nodes_map_.end()) {
      continue;
  } else {
    for (const auto& node_pb : run_meta_->cost_graph().node()) {
      auto node = nodes_map_.find(node_pb.name());
      if (node == nodes_map_.end()) {
        continue;
      }
      node->second.AddNodeStat(&node_pb);
    }
    node->second.AddNodeStat(&node_pb);
  }
}
} // namespace tfprof
@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/tools/tfprof/internal/tfprof_code.h"
#include "tensorflow/tools/tfprof/internal/tfprof_graph.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
@ -56,7 +57,8 @@ class TFStats {

  // Prints the results to stdout. Also returns the printed output in
  // a proto.
  const TFProfNode& PrintGraph(const string& cmd, const Options& opts);
  const TFGraphNodeProto& PrintGraph(const string& cmd, const Options& opts);
  const TFCodeNodeProto& PrintCode(const Options& opts);

 private:
  void ParseGraph();
@ -67,13 +69,16 @@ class TFStats {

  std::unique_ptr<TFScope> scope_view_;
  std::unique_ptr<TFGraph> graph_view_;
  std::unique_ptr<TFCode> code_view_;
  std::unique_ptr<GraphDef> graph_;
  std::unique_ptr<RunMetadata> run_meta_;
  std::unique_ptr<OpLog> op_log_;
  std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader_;
  // Store TFNode instead of TFNode* to avoid large number of dynamic alloc.
  std::map<string, TFNode> nodes_map_;
  TFProfNode empty_node_;
  // Store TFGraphNode instead of TFGraphNode* to avoid large number of
  // dynamic alloc.
  std::map<string, TFGraphNode> nodes_map_;
  TFGraphNodeProto empty_graph_node_;
  TFCodeNodeProto empty_code_node_;
};

} // namespace tfprof
@ -76,9 +76,9 @@ TEST_F(TFProfStatsTest, CustomOpType) {
               {".*"}, {""}, {".*"}, {""}, false,
               {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
               false);
  const TFProfNode& root = tf_stats_->PrintGraph("scope", opts);
  const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);

  TFProfNode expected;
  TFGraphNodeProto expected;
  CHECK(protobuf::TextFormat::ParseFromString(
      "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
      "0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: "
@ -108,9 +108,9 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
      3, 0, 0, 0, 0, {".*"}, "name", {kCkptVarType},  // accout_type_regexes
      {".*"}, {""}, {".*"}, {""}, false,
      {"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, false);
  const TFProfNode& root = tf_stats_->PrintGraph("scope", opts);
  const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);

  TFProfNode expected;
  TFGraphNodeProto expected;
  CHECK(protobuf::TextFormat::ParseFromString(
      "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
      "0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: "
@ -141,9 +141,9 @@ TEST_F(TFProfStatsTest, TestGraph) {
               {""}, {".*"}, {""}, false,
               {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
               false);
  const TFProfNode& root = tf_stats_->PrintGraph("graph", opts);
  const TFGraphNodeProto& root = tf_stats_->PrintGraph("graph", opts);

  TFProfNode expected;
  TFGraphNodeProto expected;
  CHECK(protobuf::TextFormat::ParseFromString(
      "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: 0\ninputs: "
      "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: "
@ -155,9 +155,9 @@ TEST_F(TFProfStatsTest, TestGraph) {
TEST_F(TFProfStatsTest, TestFloatOps) {
  Options opts(10, 0, 0, 0, 1, {".*"}, "name", {".*"}, {".*"}, {""}, {".*"},
               {""}, false, {"float_ops"}, false);
  const TFProfNode& root = tf_stats_->PrintGraph("scope", opts);
  const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);

  TFProfNode expected;
  TFGraphNodeProto expected;
  CHECK(protobuf::TextFormat::ParseFromString(
      "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
      "0\ntotal_exec_micros: 96\ntotal_requested_bytes: "
@ -187,9 +187,9 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
               {"unit_2_1.*DW"},  // show_name_regexes.
               {""}, true,  // account_displayed_op_only.
               {"params"}, false);
  const TFProfNode& root = tf_stats_->PrintGraph("scope", opts);
  const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);

  TFProfNode expected;
  TFGraphNodeProto expected;
  CHECK(protobuf::TextFormat::ParseFromString(
      "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
      "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: "
@ -203,8 +203,8 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) {
               {"unit_1_0.*gamma"}, {""}, false,
               {"tensor_value"},  // Show tensor value from checkpoint.
               false);
  const TFProfNode& root = tf_stats_->PrintGraph("scope", opts);
  TFProfNode expected;
  const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
  TFGraphNodeProto expected;
  CHECK(protobuf::TextFormat::ParseFromString(
      "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
      "0\ntotal_exec_micros: 96\ntotal_requested_bytes: "
@ -58,9 +58,9 @@ TEST_F(TFProfTensorTest, Basics) {
  Options opts(3, 0, 0, 0, 0, {".*"}, "name", {"VariableV2"}, {".*"}, {""},
               {".*"}, {""}, false, {"tensor_value"},  // show the tensor value.
               false);
  const TFProfNode& root = tf_stats_->PrintGraph("scope", opts);
  const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);

  TFProfNode expected;
  TFGraphNodeProto expected;
  CHECK(protobuf::TextFormat::ParseFromString(
      "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: "
      "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: "
@ -2,6 +2,17 @@ syntax = "proto2";

package tensorflow.tfprof;

// It specifies the Python callstack that creates an op.
message CodeDef {
  repeated Trace traces = 1;
  message Trace {
    optional string file = 1;
    optional int32 lineno = 2;
    optional string function = 3;
    optional string line = 4;
  }
}

message OpLogEntry {
  // op name.
  optional string name = 1;
@ -12,6 +23,8 @@ message OpLogEntry {
  // User can define extra op type information for an op. This allows the user
  // to select a group of ops precisely using op_type as a key.
  repeated string types = 3;
  // Used to support tfprof "code" view.
  optional CodeDef code_def = 4;
}
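Taken together, CodeDef and the new code_def field let a client attach a Python call stack to an op entry in the OpLog. A hedged sketch of populating one entry in C++ (the op name, file, line number and source text below are invented for illustration):

// Assumes the generated header for tfprof_log.proto is available.
tensorflow::tfprof::OpLogEntry entry;
entry.set_name("conv1/Conv2D");  // hypothetical op name
tensorflow::tfprof::CodeDef::Trace* trace =
    entry.mutable_code_def()->add_traces();
trace->set_file("my_model.py");  // hypothetical values below
trace->set_lineno(42);
trace->set_function("build_model");
trace->set_line("x = tf.nn.conv2d(...)");
// The entry would then go into an OpLog handed to tfprof, where ParseOpLog()
// calls AddCode() on the matching TFGraphNode, as shown above.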

message OpLog {
@ -160,12 +160,13 @@ int main(int argc, char** argv) {
            "Profiling everything!\n");
    return 0;
  } else if (argc > 1) {
    if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[3]) {
    if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[4]) {
      tensorflow::tfprof::PrintHelp();
      return 0;
    }
    if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[0] ||
        tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[1]) {
        tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[1] ||
        tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[2]) {
      cmd = argv[1];
    }
  }
@ -214,7 +215,10 @@ int main(int argc, char** argv) {
      hide_name_regexes, FLAGS_account_displayed_op_only, select, FLAGS_viz,
      FLAGS_dump_to_file);

  if (!cmd.empty()) {
  if (cmd == tensorflow::tfprof::kCmds[2]) {
    tf_stat.PrintCode(opts);
    return 0;
  } else if (!cmd.empty()) {
    tf_stat.PrintGraph(cmd, opts);
    return 0;
  }
@ -240,10 +244,12 @@ int main(int argc, char** argv) {
      fprintf(stderr, "E: %s\n", s.ToString().c_str());
      continue;
    }
    if (cmd == tensorflow::tfprof::kCmds[2]) {
    if (cmd == tensorflow::tfprof::kCmds[3]) {
      opts = new_opts;
    } else if (cmd == tensorflow::tfprof::kCmds[3]) {
    } else if (cmd == tensorflow::tfprof::kCmds[4]) {
      tensorflow::tfprof::PrintHelp();
    } else if (cmd == tensorflow::tfprof::kCmds[2]) {
      tf_stat.PrintCode(new_opts);
    } else {
      tf_stat.PrintGraph(cmd, new_opts);
    }
@ -21,4 +21,4 @@ message OptionsProto {
  repeated string select = 14;
  optional bool viz = 15;
  optional string dump_to_file = 16;
}
}
@ -14,7 +14,8 @@ message TFProfTensorProto {
  repeated string value_str = 4;
}

message TFProfNode {
// A node in TensorFlow graph. Used by scope/graph view.
message TFGraphNodeProto {
  // op name.
  optional string name = 1;
  // tensor value restored from checkpoint.
@ -45,5 +46,34 @@ message TFProfNode {
  repeated TensorShapeProto shapes = 11;
  // Descendants of the graph. The actual descendants depend on the data
  // structure used (scope, graph).
  repeated TFProfNode children = 12;
  repeated TFGraphNodeProto children = 12;
}

// A node in TensorFlow Python call trace stack. Used by code view.
message TFCodeNodeProto {
  // A trace in the trace stack.
  optional string name = 1;

  // code execution time.
  optional int64 exec_micros = 2;
  // Total requested bytes by the code.
  optional int64 requested_bytes = 3;
  // Number of parameters if available.
  optional int64 parameters = 4;
  // Number of float operations.
  optional int64 float_ops = 5;

  // The following are the aggregated stats from called descendants and the
  // trace itself. The actual descendants depend on the data structure used.
  optional int64 total_exec_micros = 6;
  optional int64 total_requested_bytes = 7;
  optional int64 total_parameters = 8;
  optional int64 total_float_ops = 9;

  // A set of graph nodes created by the leaf of the call stack.
  // 'children' field should be empty if graph_nodes is non-empty.
  repeated TFGraphNodeProto graph_nodes = 10;
  // Descendants of this code trace. The actual descendants depend on the
  // data structure used.
repeated TFCodeNodeProto children = 11;
}
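Since PrintCode() returns this proto, a caller can also walk the code view programmatically instead of reading the printed tree. A small hedged sketch, assuming a TFStats instance named tf_stat and an Options object named opts as in tfprof_main.cc:

// Iterate the first level of the code-view tree returned by the new command.
const tensorflow::tfprof::TFCodeNodeProto& root = tf_stat.PrintCode(opts);
for (const tensorflow::tfprof::TFCodeNodeProto& child : root.children()) {
  printf("%s: %lld us total, %lld bytes total\n", child.name().c_str(),
         static_cast<long long>(child.total_exec_micros()),
         static_cast<long long>(child.total_requested_bytes()));
}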