tfprof multi-step profiling.
This allows users to fill in RunMetadata across different steps. 1. It is useful for RL model which runs a subset of graph each step. 2. It also gets averages of multi-step stats. PiperOrigin-RevId: 157552388
This commit is contained in:
parent
7aac2395ce
commit
a7fff05e05
@ -33,6 +33,22 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "profiler_test",
|
||||
srcs = ["profiler_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":model_analyzer",
|
||||
":model_analyzer_testlib",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_analyzer_testlib",
|
||||
srcs = ["model_analyzer_testlib.py"],
|
||||
|
@ -112,6 +112,181 @@ PRINT_ALL_TIMING_MEMORY = {
|
||||
# pylint: enable=bad-continuation
|
||||
|
||||
|
||||
def _build_options(tfprof_options):
  """Translate an options dictionary into a tfprof.OptionsProto.

  Every key is optional: a key that is absent from the dictionary falls back
  to the default listed in this function.

  Args:
    tfprof_options: A dictionary of options.
  Returns:
    tfprof.OptionsProto.
  """
  proto = tfprof_options_pb2.OptionsProto()

  # Scalar fields, paired with the value used when the key is missing.
  scalar_defaults = (
      ('max_depth', 10),
      ('min_bytes', 0),
      ('min_micros', 0),
      ('min_params', 0),
      ('min_float_ops', 0),
      ('min_occurrence', 0),
      ('step', -1),
      ('order_by', 'name'),
  )
  for key, default in scalar_defaults:
    setattr(proto, key, tfprof_options.get(key, default))

  # Repeated regex fields; a missing key leaves the field empty.
  regex_fields = (
      'account_type_regexes',
      'start_name_regexes',
      'trim_name_regexes',
      'show_name_regexes',
      'hide_name_regexes',
  )
  for key in regex_fields:
    repeated = getattr(proto, key)
    for pattern in tfprof_options.get(key, []):
      repeated.append(pattern)

  proto.account_displayed_op_only = tfprof_options.get(
      'account_displayed_op_only', False)

  for column in tfprof_options.get('select', []):
    proto.select.append(column)

  proto.output = tfprof_options.get('output', 'stdout')
  proto.dump_to_file = tfprof_options.get('dump_to_file', '')

  return proto
|
||||
|
||||
|
||||
class Profiler(object):
  """TensorFlow multi-step profiler.

  See go/tfprof or README for details.

  Typical use case:
    # Currently we are only allowed to create 1 profiler per process.
    profiler = Profiler(sess.graph)

    for i in xrange(total_steps):
      if i % 10000 == 0:
        run_meta = tf.RunMetadata()
        _ = sess.run(...,
                     options=tf.RunOptions(
                         trace_level=tf.RunOptions.FULL_TRACE),
                     run_metadata=run_meta)
        profiler.add_step(i, run_meta)

        # Profile the parameters of your model.
        profiler.profile_name_scope(options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS)

        # Or profile the timing of your model operations.
        opts = PRINT_ALL_TIMING_MEMORY.copy()
        opts['order_by'] = 'micros'
        opts['select'] = ['micros', 'occurrence']
        opts['max_depth'] = 20
        profiler.profile_operations(options=opts)

        # Or you can generate a timeline:
        opts = PRINT_ALL_TIMING_MEMORY.copy()
        opts['output'] = 'timeline:outfile=' + filename
        opts['step'] = i
        profiler.profile_graph(options=opts)
      else:
        _ = sess.run(...)
  """

  def __init__(self, graph, op_log=None):
    """Constructor.

    Args:
      graph: tf.Graph.
      op_log: optional. tensorflow::tfprof::OpLog proto. Used to define
          extra op types.
    """
    self._graph = graph
    # pylint: disable=protected-access
    op_log = tfprof_logger._merge_default_with_oplog(
        self._graph, op_log=op_log)
    # pylint: enable=protected-access

    print_mdl.NewProfiler(
        self._graph.as_graph_def().SerializeToString(),
        op_log.SerializeToString())

  def __del__(self):
    # Releases the process-wide profiler created in __init__.
    print_mdl.DeleteProfiler()

  def add_step(self, step, run_meta):
    """Add statistics of a step.

    Args:
      step: An integer step id used to identify the RunMetadata. Must be
          different across different add_step() calls.
      run_meta: RunMetadata proto that contains statistics of a session run.
    """
    # pylint: disable=protected-access
    op_log = tfprof_logger._merge_default_with_oplog(
        self._graph, run_meta=run_meta, add_trace=False,
        add_trainable_var=False)
    # pylint: enable=protected-access
    print_mdl.AddStep(
        step, run_meta.SerializeToString(), op_log.SerializeToString())

  def _profile(self, cmd, options, result_cls):
    """Runs one tfprof command and parses the serialized result.

    Args:
      cmd: The tfprof view to run, e.g. 'code', 'op', 'scope' or 'graph'.
      options: A dict of profiler options.
      result_cls: The proto class used to parse the returned bytes.
    Returns:
      An instance of result_cls that records the results.
    """
    opts = _build_options(options)
    tfprof_node = result_cls()
    tfprof_node.ParseFromString(
        print_mdl.Profile(cmd.encode('utf-8'), opts.SerializeToString()))
    return tfprof_node

  def profile_python_codes(self, options):
    """Profile the statistics of the Python codes.

      Hint: set options['show_name_regexes'] = ['.*my_code.py.*']

    Args:
      options: A dict of profiler options.
    Returns:
      a TFMultiGraphNodeProto that records the results.
    """
    return self._profile(
        'code', options, tfprof_output_pb2.TFMultiGraphNodeProto)

  def profile_operations(self, options):
    """Profile the statistics of the Operation types (e.g. MatMul, Conv2D).

    Args:
      options: A dict of profiler options.
    Returns:
      a TFMultiGraphNodeProto that records the results.
    """
    return self._profile(
        'op', options, tfprof_output_pb2.TFMultiGraphNodeProto)

  def profile_name_scope(self, options):
    """Profile the statistics of graph nodes, organized by name scope.

    Args:
      options: A dict of profiler options.
    Returns:
      a TFGraphNodeProto that records the results.
    """
    return self._profile('scope', options, tfprof_output_pb2.TFGraphNodeProto)

  def profile_graph(self, options):
    """Profile the statistics of graph nodes, organized by dataflow graph.

    Args:
      options: A dict of profiler options.
    Returns:
      a TFGraphNodeProto that records the results.
    """
    return self._profile('graph', options, tfprof_output_pb2.TFGraphNodeProto)
|
||||
|
||||
|
||||
def print_model_analysis(graph,
|
||||
run_meta=None,
|
||||
op_log=None,
|
||||
@ -145,33 +320,8 @@ def print_model_analysis(graph,
|
||||
op_log = tfprof_logger._merge_default_with_oplog(
|
||||
graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
|
||||
# pylint: enable=protected-access
|
||||
opts = tfprof_options_pb2.OptionsProto()
|
||||
opts.max_depth = tfprof_options['max_depth']
|
||||
opts.min_bytes = tfprof_options['min_bytes']
|
||||
opts.min_micros = tfprof_options['min_micros']
|
||||
opts.min_params = tfprof_options['min_params']
|
||||
opts.min_float_ops = tfprof_options['min_float_ops']
|
||||
if 'min_occurrence' in tfprof_options:
|
||||
opts.min_occurrence = tfprof_options['min_occurrence']
|
||||
else:
|
||||
opts.min_occurrence = 0
|
||||
|
||||
opts.order_by = tfprof_options['order_by']
|
||||
for p in tfprof_options['account_type_regexes']:
|
||||
opts.account_type_regexes.append(p)
|
||||
for p in tfprof_options['start_name_regexes']:
|
||||
opts.start_name_regexes.append(p)
|
||||
for p in tfprof_options['trim_name_regexes']:
|
||||
opts.trim_name_regexes.append(p)
|
||||
for p in tfprof_options['show_name_regexes']:
|
||||
opts.show_name_regexes.append(p)
|
||||
for p in tfprof_options['hide_name_regexes']:
|
||||
opts.hide_name_regexes.append(p)
|
||||
opts.account_displayed_op_only = tfprof_options['account_displayed_op_only']
|
||||
for p in tfprof_options['select']:
|
||||
opts.select.append(p)
|
||||
opts.output = tfprof_options['output']
|
||||
opts.dump_to_file = tfprof_options['dump_to_file']
|
||||
opts = _build_options(tfprof_options)
|
||||
|
||||
run_meta_str = run_meta.SerializeToString() if run_meta else b''
|
||||
|
||||
|
@ -199,6 +199,7 @@ class PrintModelAnalysisTest(test.TestCase):
|
||||
opts['output'] = 'timeline:outfile=' + outfile
|
||||
opts['account_type_regexes'] = ['.*']
|
||||
opts['max_depth'] = 100000
|
||||
opts['step'] = 0
|
||||
|
||||
with session.Session() as sess, ops.device('/cpu:0'):
|
||||
x = lib.BuildFullModel()
|
||||
|
@ -65,3 +65,33 @@ def BuildFullModel():
|
||||
loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
|
||||
sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
|
||||
return sgd_op.minimize(loss)
|
||||
|
||||
|
||||
def BuildSplitableModel():
  """Build a small model whose branches can be run separately per step.

  Returns:
    A tuple (r1, r2, r3): the two convolution outputs and their sum.
  """
  strides = [1, 2, 2, 1]
  image = array_ops.zeros([2, 6, 6, 3])

  # First branch: variable 'DW' feeding the first convolution.
  dw = variable_scope.get_variable(
      'DW', [3, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  branch1 = nn_ops.conv2d(image, dw, strides, padding='SAME')

  # Second branch: variable 'DW2' feeding the second convolution.
  dw2 = variable_scope.get_variable(
      'DW2', [2, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  branch2 = nn_ops.conv2d(image, dw2, strides, padding='SAME')

  # The sum depends on both branches.
  merged = branch1 + branch2
  return branch1, branch2, merged
|
||||
|
||||
|
||||
def SearchTFProfNode(node, name):
  """Search a node in the tree.

  Walks the tree depth-first, in the order of `children`, and stops at the
  first node whose name matches.

  Args:
    node: Root of the (sub)tree; must expose `name` and `children`.
    name: The node name to look for.
  Returns:
    The first matching node, or None if no node has that name.
  """
  if node.name == name:
    return node
  for child in node.children:
    found = SearchTFProfNode(child, name)
    # Compare against None explicitly so that a match is never discarded
    # just because the node object happens to evaluate as false.
    if found is not None:
      return found
  return None
|
||||
|
184
tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
Normal file
184
tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
Normal file
@ -0,0 +1,184 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer_testlib as lib
|
||||
|
||||
|
||||
class ProfilerTest(test.TestCase):
  """Tests for model_analyzer.Profiler."""

  def _readFile(self, path):
    """Returns the whole content of the file at `path`."""
    with gfile.Open(path, 'r') as f:
      return f.read()

  def _runWithFullTrace(self, sess, fetches, run_meta):
    """Runs `fetches` with FULL_TRACE and stores the stats in `run_meta`."""
    sess.run(fetches,
             options=config_pb2.RunOptions(
                 trace_level=config_pb2.RunOptions.FULL_TRACE),
             run_metadata=run_meta)

  def testProfileBasic(self):
    ops.reset_default_graph()
    opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
    opts['account_type_regexes'] = ['.*']
    opts['select'] = ['params', 'float_ops', 'micros', 'bytes',
                      'device', 'op_types', 'occurrence']
    outfile = os.path.join(test.get_temp_dir(), 'dump')
    opts['output'] = 'file:outfile=' + outfile

    # The session is used as a context manager so it is closed when the test
    # finishes.
    with session.Session() as sess:
      # Test the output without run_meta.
      r = lib.BuildFullModel()
      sess.run(variables.global_variables_initializer())

      profiler = model_analyzer.Profiler(sess.graph)
      profiler.profile_name_scope(opts)
      profiler_str = self._readFile(outfile)

      model_analyzer.print_model_analysis(
          sess.graph, tfprof_cmd='scope', tfprof_options=opts)
      pma_str = self._readFile(outfile)
      self.assertEqual(pma_str, profiler_str)

      # Test the output with run_meta.
      run_meta = config_pb2.RunMetadata()
      self._runWithFullTrace(sess, r, run_meta)

      profiler.add_step(1, run_meta)
      profiler.profile_graph(opts)
      profiler_str = self._readFile(outfile)

      model_analyzer.print_model_analysis(
          sess.graph, tfprof_cmd='graph', run_meta=run_meta,
          tfprof_options=opts)
      pma_str = self._readFile(outfile)
      self.assertEqual(pma_str, profiler_str)

      profiler.profile_python_codes(opts)
      profiler_str = self._readFile(outfile)

      model_analyzer.print_model_analysis(
          sess.graph, tfprof_cmd='code', run_meta=run_meta,
          tfprof_options=opts)
      pma_str = self._readFile(outfile)
      self.assertEqual(pma_str, profiler_str)

      profiler.profile_operations(opts)
      profiler_str = self._readFile(outfile)

      model_analyzer.print_model_analysis(
          sess.graph, tfprof_cmd='op', run_meta=run_meta, tfprof_options=opts)
      pma_str = self._readFile(outfile)
      self.assertEqual(pma_str, profiler_str)

      # Test the output difference between multi-step profile and 1-step
      # profile. The same run_meta object is refilled by this second run.
      self._runWithFullTrace(sess, r, run_meta)

      profiler.add_step(2, run_meta)
      profiler.profile_name_scope(opts)
      profiler_str = self._readFile(outfile)

      model_analyzer.print_model_analysis(
          sess.graph, tfprof_cmd='scope', run_meta=run_meta,
          tfprof_options=opts)
      pma_str = self._readFile(outfile)
      self.assertNotEqual(pma_str, profiler_str)

      # With only 'params' and 'float_ops' selected, the multi-step and
      # 1-step outputs are expected to match again.
      opts2 = opts.copy()
      opts2['select'] = ['params', 'float_ops']
      profiler.profile_name_scope(opts2)
      profiler_str = self._readFile(outfile)

      model_analyzer.print_model_analysis(
          sess.graph, tfprof_cmd='scope', run_meta=run_meta,
          tfprof_options=opts2)
      pma_str = self._readFile(outfile)
      self.assertEqual(pma_str, profiler_str)

  def testMultiStepProfile(self):
    ops.reset_default_graph()
    opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
    opts['account_type_regexes'] = ['.*']

    with session.Session() as sess, ops.device('/cpu:0'):
      r1, r2, r3 = lib.BuildSplitableModel()
      sess.run(variables.global_variables_initializer())

      profiler = model_analyzer.Profiler(sess.graph)
      pb0 = profiler.profile_name_scope(opts)

      # Step 1 only runs the first convolution branch.
      run_meta = config_pb2.RunMetadata()
      self._runWithFullTrace(sess, r1, run_meta)
      profiler.add_step(1, run_meta)
      pb1 = profiler.profile_name_scope(opts)

      self.assertIsNotNone(lib.SearchTFProfNode(pb1, 'DW'))
      self.assertIsNone(lib.SearchTFProfNode(pb1, 'DW2'))
      self.assertIsNone(lib.SearchTFProfNode(pb1, 'add'))

      # Step 2 only runs the second convolution branch.
      run_meta2 = config_pb2.RunMetadata()
      self._runWithFullTrace(sess, r2, run_meta2)
      profiler.add_step(2, run_meta2)
      pb2 = profiler.profile_name_scope(opts)

      self.assertIsNotNone(lib.SearchTFProfNode(pb2, 'DW'))
      self.assertIsNotNone(lib.SearchTFProfNode(pb2, 'DW2'))
      self.assertIsNone(lib.SearchTFProfNode(pb2, 'add'))

      # Step 3 runs the sum, which needs both branches.
      run_meta3 = config_pb2.RunMetadata()
      self._runWithFullTrace(sess, r3, run_meta3)
      profiler.add_step(3, run_meta3)
      pb3 = profiler.profile_name_scope(opts)

      self.assertIsNotNone(lib.SearchTFProfNode(pb3, 'DW'))
      self.assertIsNotNone(lib.SearchTFProfNode(pb3, 'DW2'))
      self.assertIsNotNone(lib.SearchTFProfNode(pb3, 'add'))

      self.assertIsNone(lib.SearchTFProfNode(pb0, 'Conv2D'))
      self.assertGreater(lib.SearchTFProfNode(pb1, 'Conv2D').exec_micros, 0)
      self.assertIsNone(lib.SearchTFProfNode(pb1, 'Conv2D_1'))
      self.assertGreater(lib.SearchTFProfNode(pb2, 'Conv2D_1').exec_micros, 0)
      self.assertIsNone(lib.SearchTFProfNode(pb2, 'add'))
      self.assertGreater(lib.SearchTFProfNode(pb3, 'add').exec_micros, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
%{
|
||||
#include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
using tensorflow::int64;
|
||||
%}
|
||||
|
||||
%typemap(typecheck) const string & = char *;
|
||||
@ -37,6 +39,10 @@ limitations under the License.
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::tfprof;
|
||||
%unignore tensorflow::tfprof::PrintModelAnalysis;
|
||||
%unignore tensorflow::tfprof::NewProfiler;
|
||||
%unignore tensorflow::tfprof::DeleteProfiler;
|
||||
%unignore tensorflow::tfprof::AddStep;
|
||||
%unignore tensorflow::tfprof::Profile;
|
||||
|
||||
%include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
|
||||
|
||||
|
@ -62,13 +62,16 @@ def _fill_missing_graph_shape(graph, run_meta):
|
||||
return graph
|
||||
|
||||
|
||||
def _get_logged_ops(graph, run_meta=None, add_trace=True):
|
||||
def _get_logged_ops(graph, run_meta=None, add_trace=True,
|
||||
add_trainable_var=True):
|
||||
"""Extract trainable model parameters and FLOPs for ops from a Graph.
|
||||
|
||||
Args:
|
||||
graph: tf.Graph.
|
||||
run_meta: RunMetadata proto used to complete shape information.
|
||||
add_trace: Whether to add op trace information.
|
||||
add_trainable_var: Whether to assign tf.trainable_variables() op type
|
||||
'_trainable_variables'.
|
||||
Returns:
|
||||
logged_ops: dict mapping from op_name to OpLogEntry.
|
||||
"""
|
||||
@ -77,6 +80,7 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True):
|
||||
|
||||
op_missing_shape = 0
|
||||
logged_ops = {}
|
||||
# TODO(xpan): Work with Profiler more efficiently.
|
||||
for op in graph.get_operations():
|
||||
try:
|
||||
stats = ops.get_stats_for_node_def(
|
||||
@ -105,23 +109,24 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True):
|
||||
if add_entry:
|
||||
logged_ops[entry.name] = entry
|
||||
|
||||
for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
|
||||
if v.op.name not in logged_ops:
|
||||
entry = tfprof_log_pb2.OpLogEntry()
|
||||
entry.name = v.op.name
|
||||
entry.types.append(TRAINABLE_VARIABLES)
|
||||
logged_ops[entry.name] = entry
|
||||
else:
|
||||
logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
|
||||
if add_trainable_var:
|
||||
for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
|
||||
if v.op.name not in logged_ops:
|
||||
entry = tfprof_log_pb2.OpLogEntry()
|
||||
entry.name = v.op.name
|
||||
entry.types.append(TRAINABLE_VARIABLES)
|
||||
logged_ops[entry.name] = entry
|
||||
else:
|
||||
logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
|
||||
|
||||
if op_missing_shape > 0 and not run_meta:
|
||||
sys.stderr.write('%d ops no flops stats due to incomplete shapes. '
|
||||
'Consider passing run_meta to use run_time shapes.\n' %
|
||||
sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
|
||||
op_missing_shape)
|
||||
return logged_ops
|
||||
|
||||
|
||||
def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
|
||||
add_trace=True):
|
||||
add_trace=True, add_trainable_var=True):
|
||||
"""Merge the tfprof default extra info with caller's op_log.
|
||||
|
||||
Args:
|
||||
@ -129,11 +134,15 @@ def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
|
||||
op_log: OpLog proto.
|
||||
run_meta: RunMetadata proto used to complete shape information.
|
||||
add_trace: Whether to add op trace information.
|
||||
add_trainable_var: Whether to assign tf.trainable_variables() op type
|
||||
'_trainable_variables'.
|
||||
Returns:
|
||||
tmp_op_log: Merged OpLog proto.
|
||||
"""
|
||||
tmp_op_log = tfprof_log_pb2.OpLog()
|
||||
logged_ops = _get_logged_ops(graph, run_meta, add_trace=add_trace)
|
||||
logged_ops = _get_logged_ops(
|
||||
graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)
|
||||
|
||||
if not op_log:
|
||||
tmp_op_log.log_entries.extend(logged_ops.values())
|
||||
else:
|
||||
|
@ -259,7 +259,8 @@ tfprof>
|
||||
-min_micros 0
|
||||
-min_params 0
|
||||
-min_float_ops 0
|
||||
-min_occurrence 0
|
||||
-min_occurrence 0
|
||||
-step -1
|
||||
-order_by name
|
||||
-account_type_regexes Variable,VariableV2
|
||||
-start_name_regexes .*
|
||||
@ -598,6 +599,8 @@ provides checkpointed tensors' values.
|
||||
|
||||
`-min_occurrence`: Show ops that appear at least this number of times. Only available in "op" view.
|
||||
|
||||
`-step`: Show the stats of this step when multiple steps of RunMetadata were added. By default, show the average of all steps.
|
||||
|
||||
`-order_by`: Order the results by [name|depth|bytes|micros|params|float_ops|occurrence]
|
||||
|
||||
`-account_type_regexes`: Account and display the ops whose types match one of the type regexes specified. tfprof allows users to define extra op types for ops through the tensorflow.tfprof.OpLog proto. Regexes are comma-separated.
|
||||
|
@ -30,6 +30,89 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfprof {
|
||||
namespace {
|
||||
TFStats* tf_stat = nullptr;
|
||||
|
||||
// Parses the serialized options, runs `command` against `tf_stats` and
// returns the serialized result proto. Returns an empty string if the options
// cannot be parsed or the command is not recognized.
string RunProfile(const string& command, const string& options,
                  TFStats* tf_stats) {
  Options opts;
  tensorflow::Status s = Options::FromProtoStr(options, &opts);
  if (!s.ok()) {
    fprintf(stderr, "%s\n", s.ToString().c_str());
    return "";
  }

  // When the output type is kOutput[1] (presumably stdout -- confirm against
  // the kOutput definition), the report is framed by a header and a footer.
  const bool framed_report = opts.output_type == kOutput[1];
  if (framed_report) {
    printf("\n=========================Options=============================\n");
    printf("%s", opts.ToString().c_str());
    printf("\n==================Model Analysis Report======================\n");
  }

  // kCmds[2] and kCmds[3] produce a multi-graph-node result; kCmds[0] and
  // kCmds[1] produce a graph-node result.
  const bool multi_node_view = command == kCmds[2] || command == kCmds[3];
  const bool graph_node_view = command == kCmds[0] || command == kCmds[1];

  string serialized = "";
  if (multi_node_view) {
    serialized = tf_stats->ShowMultiGraphNode(command, opts).SerializeAsString();
  } else if (graph_node_view) {
    serialized = tf_stats->ShowGraphNode(command, opts).SerializeAsString();
  } else {
    fprintf(stderr, "Unknown command: %s\n", command.c_str());
  }

  if (framed_report) {
    printf("\n======================End of Report==========================\n");
    fflush(stdout);
  }
  return serialized;
}
|
||||
} // namespace
|
||||
|
||||
bool NewProfiler(const string* graph, const string* op_log) {
|
||||
CHECK(!tf_stat) << "Currently only 1 living tfprof profiler is allowed";
|
||||
CHECK(graph) << "graph mustn't be null";
|
||||
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
|
||||
graph_ptr->ParseFromString(*graph);
|
||||
|
||||
std::unique_ptr<OpLog> op_log_ptr;
|
||||
if (op_log && !op_log->empty()) {
|
||||
op_log_ptr.reset(new OpLog());
|
||||
op_log_ptr->ParseFromString(*op_log);
|
||||
}
|
||||
tf_stat = new TFStats(std::move(graph_ptr), nullptr, std::move(op_log_ptr),
|
||||
nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
void DeleteProfiler() {
|
||||
delete tf_stat;
|
||||
tf_stat = nullptr;
|
||||
}
|
||||
|
||||
// Feeds the serialized RunMetadata of one step, plus an optional serialized
// OpLog, to the process-wide profiler. CHECK-fails if no profiler exists or
// if `run_meta` is null or empty.
void AddStep(int64 step, const string* run_meta, const string* op_log) {
  CHECK(tf_stat);
  CHECK(run_meta && !run_meta->empty());
  // TODO(xpan): Better error handling.
  std::unique_ptr<RunMetadata> step_meta(new RunMetadata());
  step_meta->ParseFromString(*run_meta);
  tf_stat->ParseRunMeta(step, std::move(step_meta));

  // The op log pointer stays unset when no (or an empty) op log is given;
  // ParseOpLog is called either way.
  std::unique_ptr<OpLog> step_op_log;
  const bool has_op_log = op_log != nullptr && !op_log->empty();
  if (has_op_log) {
    step_op_log.reset(new OpLog());
    step_op_log->ParseFromString(*op_log);
  }
  tf_stat->ParseOpLog(std::move(step_op_log));
}
|
||||
|
||||
// Runs `command` with the serialized `options` against the process-wide
// profiler and returns the serialized result. CHECK-fails if no profiler
// exists or if either argument is null.
string Profile(const string* command, const string* options) {
  CHECK(tf_stat);
  CHECK(command) << "command mustn't be null";
  CHECK(options) << "options mustn't be null";
  const string& cmd = *command;
  const string& serialized_options = *options;
  return RunProfile(cmd, serialized_options, tf_stat);
}
|
||||
|
||||
string PrintModelAnalysis(const string* graph, const string* run_meta,
|
||||
const string* op_log, const string* command,
|
||||
const string* options) {
|
||||
@ -51,42 +134,13 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
|
||||
op_log_ptr->ParseFromString(*op_log);
|
||||
}
|
||||
|
||||
// TODO(xpan): Maybe need to init the checkpoint reader?
|
||||
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
|
||||
|
||||
TFStats tf_stats(std::move(graph_ptr), std::move(run_meta_ptr),
|
||||
std::move(op_log_ptr), std::move(ckpt_reader));
|
||||
|
||||
Options opts;
|
||||
tensorflow::Status s = Options::FromProtoStr(*options, &opts);
|
||||
if (!s.ok()) {
|
||||
fprintf(stderr, "%s\n", s.ToString().c_str());
|
||||
return "";
|
||||
}
|
||||
|
||||
if (opts.output_type == kOutput[1]) {
|
||||
printf("\n=========================Options=============================\n");
|
||||
printf("%s", opts.ToString().c_str());
|
||||
printf("\n==================Model Analysis Report======================\n");
|
||||
string ret = "";
|
||||
if (*command == kCmds[2] || *command == kCmds[3]) {
|
||||
ret = tf_stats.ShowMultiGraphNode(*command, opts).SerializeAsString();
|
||||
} else if (*command == kCmds[0] || *command == kCmds[1]) {
|
||||
ret = tf_stats.ShowGraphNode(*command, opts).SerializeAsString();
|
||||
} else {
|
||||
fprintf(stderr, "Unknown command: %s\n", command->c_str());
|
||||
}
|
||||
printf("\n======================End of Report==========================\n");
|
||||
fflush(stdout);
|
||||
return ret;
|
||||
}
|
||||
if (*command == kCmds[2] || *command == kCmds[3]) {
|
||||
return tf_stats.ShowMultiGraphNode(*command, opts).SerializeAsString();
|
||||
} else if (*command == kCmds[0] || *command == kCmds[1]) {
|
||||
return tf_stats.ShowGraphNode(*command, opts).SerializeAsString();
|
||||
} else {
|
||||
fprintf(stderr, "Unknown command: %s\n", command->c_str());
|
||||
return "";
|
||||
}
|
||||
return RunProfile(*command, *options, &tf_stats);
|
||||
}
|
||||
} // namespace tfprof
|
||||
} // namespace tensorflow
|
||||
|
@ -23,8 +23,19 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tfprof {
|
||||
struct Options;
|
||||
// ***This API is only for swig. Don't user it directory!***
|
||||
//
|
||||
|
||||
// **********************
|
||||
// APIs in this file are only for swig.
|
||||
// Talk to xpan@ if you want to call it directly!
|
||||
// *********************
|
||||
bool NewProfiler(const string* graph, const string* op_log);
|
||||
|
||||
void DeleteProfiler();
|
||||
|
||||
void AddStep(int64 step, const string* run_meta, const string* op_log);
|
||||
|
||||
string Profile(const string* command, const string* options);
|
||||
|
||||
// Interface defined for Python API swig. Calls the tfprof core API.
|
||||
// 'graph', 'run_meta', 'op_log' are serialized GraphDef, RunMetadata,
|
||||
// OpLog strings, respectively.
|
||||
|
@ -53,14 +53,16 @@ string GetTraceString(const CodeDef::Trace& trace) {
|
||||
} // namespace
|
||||
|
||||
void TFCode::AddNode(TFGraphNode* node) {
|
||||
if (!node->code()) {
|
||||
if (node->code().traces_size() == 0) {
|
||||
return;
|
||||
}
|
||||
TFMultiGraphNode* pre_trace_node = nullptr;
|
||||
for (int i = 0; i < node->code()->traces_size(); ++i) {
|
||||
// TODO(xpan): Consider to release CodeDef after TFCode is built. It
|
||||
// takes a lot of memory.
|
||||
for (int i = 0; i < node->code().traces_size(); ++i) {
|
||||
// Unlike op name, which is globally unique, trace name is only unique
|
||||
// w.r.t. it's parent.
|
||||
const string& trace = GetTraceString(node->code()->traces(i));
|
||||
const string& trace = GetTraceString(node->code().traces(i));
|
||||
if (i == 0) {
|
||||
if (!trace_root_) {
|
||||
trace_root_.reset(new TFMultiGraphNode(trace));
|
||||
@ -72,7 +74,7 @@ void TFCode::AddNode(TFGraphNode* node) {
|
||||
pre_trace_node->AddChildren(trace);
|
||||
TFMultiGraphNode* trace_node = pre_trace_node->children().at(trace).get();
|
||||
|
||||
if (i == node->code()->traces_size() - 1) {
|
||||
if (i == node->code().traces_size() - 1) {
|
||||
trace_node->AddGraphNode(node);
|
||||
}
|
||||
pre_trace_node = trace_node;
|
||||
|
@ -70,12 +70,19 @@ void TFGraph::Build() {
|
||||
}
|
||||
|
||||
const ShowNode* TFGraph::ShowInternal(const Options& opts, Timeline* timeline) {
|
||||
root_->ResetTotalStats();
|
||||
root_->show_children.clear();
|
||||
if (timeline && timeline->step() < 0) {
|
||||
// TODO(xpan): Maybe pick a default step for users.
|
||||
fprintf(stderr,
|
||||
"Must specify -step option to generate timeline in graph view.\n");
|
||||
return root_;
|
||||
}
|
||||
// 1. Account and aggregate the stats based on the graph structure.
|
||||
// Returns a graph consists of accounted nodes.
|
||||
std::set<string> visits;
|
||||
std::vector<GraphNode*> roots = Account(root_->children, opts, &visits);
|
||||
root_->ResetTotalStats();
|
||||
root_->show_children.clear();
|
||||
std::vector<GraphNode*> roots =
|
||||
Account(root_->children, opts, timeline, &visits);
|
||||
for (GraphNode* n : roots) {
|
||||
root_->AggregateTotalStats(n);
|
||||
}
|
||||
@ -98,7 +105,7 @@ const ShowNode* TFGraph::ShowInternal(const Options& opts, Timeline* timeline) {
|
||||
Format(root->show_children, &root->formatted_str, root->mutable_proto());
|
||||
|
||||
if (timeline) {
|
||||
timeline->GenerateGraphTimeline(root, memory_tracker_);
|
||||
timeline->GenerateGraphTimeline(root);
|
||||
}
|
||||
return root;
|
||||
}
|
||||
@ -201,26 +208,28 @@ std::vector<GraphNode*> TFGraph::PrintGraph(const std::vector<GraphNode*> roots,
|
||||
|
||||
std::vector<GraphNode*> TFGraph::Account(const std::vector<GraphNode*>& roots,
|
||||
const Options& opts,
|
||||
Timeline* timeline,
|
||||
std::set<string>* visits) {
|
||||
std::vector<GraphNode*> act_nodes;
|
||||
for (GraphNode* node : roots) {
|
||||
if (visits->find(node->name()) != visits->end()) continue;
|
||||
visits->insert(node->name());
|
||||
// Depth-first.
|
||||
std::vector<GraphNode*> act_cnodes = Account(node->children, opts, visits);
|
||||
std::vector<GraphNode*> act_cnodes =
|
||||
Account(node->children, opts, timeline, visits);
|
||||
|
||||
node->account = ShouldAccount(node, opts);
|
||||
node->account = ReAccount(node, opts);
|
||||
if (node->account) {
|
||||
node->show_children.clear();
|
||||
node->ResetTotalStats();
|
||||
node->AddSelfToTotalStats();
|
||||
if (node->trackable) {
|
||||
memory_tracker_.TrackNode(node);
|
||||
if (timeline) {
|
||||
timeline->TrackNode(node);
|
||||
}
|
||||
// Aggregate its accounted children stats.
|
||||
for (GraphNode* c : act_cnodes) {
|
||||
if (node->trackable && c->trackable) {
|
||||
memory_tracker_.TrackNodeConnection(node, c);
|
||||
if (timeline) {
|
||||
timeline->TrackNodeConnection(node, c);
|
||||
}
|
||||
node->AggregateTotalStats(c);
|
||||
node->show_children.push_back(c);
|
||||
|
@ -70,7 +70,7 @@ class TFGraph : public TFShow {
|
||||
int last_ident, std::set<string>* visits);
|
||||
|
||||
std::vector<GraphNode*> Account(const std::vector<GraphNode*>& roots,
|
||||
const Options& opts,
|
||||
const Options& opts, Timeline* timeline,
|
||||
std::set<string>* visits);
|
||||
|
||||
void Format(const std::vector<GraphNode*> roots, string* display_str,
|
||||
|
@ -27,86 +27,31 @@ namespace tfprof {
|
||||
// For CPU, op_end_rel is the kernel time, while all_end_rel_micros includes
|
||||
// some post-processing.
|
||||
// Here, we only consider kernel time for simplicity.
|
||||
void TFGraphNode::AddStepStat(const string& device,
|
||||
const NodeExecStats* step_stat) {
|
||||
step_stat_ = step_stat;
|
||||
CHECK(step_stat_);
|
||||
|
||||
void TFGraphNode::AddStepStat(int64 step, const string& device,
|
||||
const NodeExecStats& step_stat) {
|
||||
string dev = str_util::Lowercase(device);
|
||||
|
||||
// TODO(xpan): Test it.
|
||||
if (RE2::FullMatch(dev, "/job:.*/replica:\\d+/task:\\d+/[a-z]+:\\d+")) {
|
||||
canonical_device_ = dev;
|
||||
// TODO(xpan): Support things other than gpu?
|
||||
host_device_ = StringReplace(dev, "gpu:\\d+", "cpu:0");
|
||||
AddOpType(canonical_device_);
|
||||
}
|
||||
|
||||
devices_.insert(dev);
|
||||
if (step_stat_->all_start_micros() > 0) {
|
||||
if (all_start_micros_ > 0) {
|
||||
all_start_micros_ =
|
||||
std::min(all_start_micros_,
|
||||
static_cast<int64>(step_stat_->all_start_micros()));
|
||||
if (!canonical_device_.empty()) {
|
||||
if (canonical_device_ != dev) {
|
||||
fprintf(stderr, "Unexpected: graph node changed device: %s->%s.\n",
|
||||
canonical_device_.c_str(), dev.c_str());
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
all_start_micros_ = step_stat_->all_start_micros();
|
||||
}
|
||||
int64 op_end_rel_micros = step_stat_->op_end_rel_micros();
|
||||
// Round quick execution to 1 micro to be semantically robust.
|
||||
if (op_end_rel_micros == 0) {
|
||||
++op_end_rel_micros;
|
||||
}
|
||||
latest_end_rel_micros_ =
|
||||
std::max(latest_end_rel_micros_, op_end_rel_micros);
|
||||
|
||||
op_execs_[dev].push_back(
|
||||
std::make_pair(step_stat_->all_start_micros(), op_end_rel_micros));
|
||||
|
||||
if (dev.find("stream") != dev.npos && dev.find("stream:all") == dev.npos) {
|
||||
gpu_kernel_execs_[dev].push_back(
|
||||
std::make_pair(step_stat_->all_start_micros(), op_end_rel_micros));
|
||||
canonical_device_ = dev;
|
||||
// TODO(xpan): Support things other than gpu?
|
||||
host_device_ = StringReplace(dev, "gpu:\\d+", "cpu:0");
|
||||
AddOpType(canonical_device_);
|
||||
}
|
||||
}
|
||||
|
||||
ExecStep& exec = execs_[step];
|
||||
exec.AddTimeStats(dev, step_stat);
|
||||
|
||||
if (dev == canonical_device_) {
|
||||
for (const auto& mem : step_stat_->memory()) {
|
||||
// TODO(xpan): Fix this hack. Currently the allocator name seems quite
|
||||
// ad-hoc.
|
||||
if (mem.allocator_name().find("GPU") == mem.allocator_name().npos) {
|
||||
continue;
|
||||
}
|
||||
if (dev == canonical_device_) {
|
||||
allocator_bytes_in_use_ =
|
||||
std::max(allocator_bytes_in_use_,
|
||||
static_cast<int64>(mem.allocator_bytes_in_use()));
|
||||
}
|
||||
}
|
||||
int64 total_output_bytes = 0;
|
||||
for (const auto& output : step_stat_->output()) {
|
||||
if (output.has_tensor_description() &&
|
||||
output.tensor_description().has_allocation_description()) {
|
||||
// TODO(xpan): Maybe allocated_bytes.
|
||||
int64 output_bytes = std::max(output.tensor_description()
|
||||
.allocation_description()
|
||||
.allocated_bytes(),
|
||||
output.tensor_description()
|
||||
.allocation_description()
|
||||
.requested_bytes());
|
||||
uint64 output_ptr =
|
||||
output.tensor_description().allocation_description().ptr();
|
||||
total_output_bytes += output_bytes;
|
||||
output_bytes_[output.slot()] = std::make_pair(output_bytes, output_ptr);
|
||||
}
|
||||
}
|
||||
if (step_stat_->has_memory_stats()) {
|
||||
host_temp_bytes_ += step_stat_->memory_stats().host_temp_memory_size();
|
||||
host_persistent_bytes_ +=
|
||||
step_stat_->memory_stats().host_persistent_memory_size();
|
||||
accelerator_temp_bytes_ +=
|
||||
step_stat_->memory_stats().device_temp_memory_size();
|
||||
accelerator_persistent_bytes_ +=
|
||||
step_stat_->memory_stats().device_persistent_memory_size();
|
||||
}
|
||||
requested_bytes_ = total_output_bytes;
|
||||
exec.AddMemoryStats(dev, step_stat);
|
||||
}
|
||||
}
|
||||
} // namespace tfprof
|
||||
|
@ -37,22 +37,162 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tfprof {
|
||||
|
||||
class TFGraphNode {
|
||||
class ExecStep {
|
||||
public:
|
||||
TFGraphNode(const NodeDef* node)
|
||||
: node_(node),
|
||||
code_(nullptr),
|
||||
step_stat_(nullptr),
|
||||
all_start_micros_(0),
|
||||
ExecStep()
|
||||
: all_start_micros_(0),
|
||||
latest_end_rel_micros_(0),
|
||||
mem_initiated_(false),
|
||||
requested_bytes_(0),
|
||||
host_temp_bytes_(0),
|
||||
host_persistent_bytes_(0),
|
||||
accelerator_temp_bytes_(0),
|
||||
accelerator_persistent_bytes_(0),
|
||||
allocator_bytes_in_use_(0),
|
||||
float_ops_(0),
|
||||
op_(node->op()) {
|
||||
allocator_bytes_in_use_(0) {}
|
||||
|
||||
void AddTimeStats(const string& dev, const NodeExecStats& step_stat) {
|
||||
devices_.insert(dev);
|
||||
if (step_stat.all_start_micros() > 0) {
|
||||
if (all_start_micros_ > 0) {
|
||||
all_start_micros_ =
|
||||
std::min(all_start_micros_,
|
||||
static_cast<int64>(step_stat.all_start_micros()));
|
||||
} else {
|
||||
all_start_micros_ = step_stat.all_start_micros();
|
||||
}
|
||||
int64 op_end_rel_micros = step_stat.op_end_rel_micros();
|
||||
// Round quick execution to 1 micro to be semantically robust.
|
||||
if (op_end_rel_micros == 0) {
|
||||
++op_end_rel_micros;
|
||||
}
|
||||
latest_end_rel_micros_ =
|
||||
std::max(latest_end_rel_micros_, op_end_rel_micros);
|
||||
|
||||
op_execs_[dev].push_back(
|
||||
std::make_pair(step_stat.all_start_micros(), op_end_rel_micros));
|
||||
|
||||
if (dev.find("stream") != dev.npos &&
|
||||
dev.find("stream:all") == dev.npos) {
|
||||
gpu_kernel_execs_[dev].push_back(
|
||||
std::make_pair(step_stat.all_start_micros(), op_end_rel_micros));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AddMemoryStats(const string& dev, const NodeExecStats& step_stat) {
|
||||
if (mem_initiated_) {
|
||||
// fprintf(stderr, "Memory initiated twice on %s", dev.c_str());
|
||||
return;
|
||||
}
|
||||
mem_initiated_ = true;
|
||||
|
||||
for (const auto& mem : step_stat.memory()) {
|
||||
// TODO(xpan): Fix this hack. Currently the allocator name seems quite
|
||||
// ad-hoc.
|
||||
if (mem.allocator_name().find("GPU") == mem.allocator_name().npos) {
|
||||
continue;
|
||||
}
|
||||
allocator_bytes_in_use_ =
|
||||
std::max(allocator_bytes_in_use_,
|
||||
static_cast<int64>(mem.allocator_bytes_in_use()));
|
||||
}
|
||||
int64 total_output_bytes = 0;
|
||||
for (const auto& output : step_stat.output()) {
|
||||
if (output.has_tensor_description() &&
|
||||
output.tensor_description().has_allocation_description()) {
|
||||
// TODO(xpan): Maybe allocated_bytes.
|
||||
int64 output_bytes = std::max(output.tensor_description()
|
||||
.allocation_description()
|
||||
.allocated_bytes(),
|
||||
output.tensor_description()
|
||||
.allocation_description()
|
||||
.requested_bytes());
|
||||
uint64 output_ptr =
|
||||
output.tensor_description().allocation_description().ptr();
|
||||
total_output_bytes += output_bytes;
|
||||
output_bytes_[output.slot()] = std::make_pair(output_bytes, output_ptr);
|
||||
}
|
||||
}
|
||||
if (step_stat.has_memory_stats()) {
|
||||
host_temp_bytes_ += step_stat.memory_stats().host_temp_memory_size();
|
||||
host_persistent_bytes_ +=
|
||||
step_stat.memory_stats().host_persistent_memory_size();
|
||||
accelerator_temp_bytes_ +=
|
||||
step_stat.memory_stats().device_temp_memory_size();
|
||||
accelerator_persistent_bytes_ +=
|
||||
step_stat.memory_stats().device_persistent_memory_size();
|
||||
}
|
||||
requested_bytes_ = total_output_bytes;
|
||||
}
|
||||
|
||||
// Returns the execution time recorded for this step.
//
// Sums the second element of every pair in gpu_kernel_execs_. When that sum
// is not positive, the second elements of every pair in op_execs_ are added
// to it and the result is returned instead.
int64 exec_micros() const {
  typedef std::map<string, std::vector<std::pair<int64, int64>>> DeviceExecs;
  const auto sum_durations = [](const DeviceExecs& by_device) {
    int64 sum = 0;
    for (const auto& device_and_execs : by_device) {
      for (const auto& start_and_duration : device_and_execs.second) {
        sum += start_and_duration.second;
      }
    }
    return sum;
  };

  int64 total = sum_durations(gpu_kernel_execs_);
  if (total > 0) {
    return total;
  }

  // If there is no gpu kernel time, fall back to assume it runs on cpu.
  // TODO(xpan): No way to track CPU async op timing accurately?
  total += sum_durations(op_execs_);
  return total;
}
|
||||
|
||||
// Read-only accessors for the values accumulated by AddTimeStats and
// AddMemoryStats. None of them modifies the object.

// device -> vector of {start_micros, kernel_exec_micros} pairs.
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs()
|
||||
const {
|
||||
return op_execs_;
|
||||
}
|
||||
// Earliest positive start time recorded so far; 0 if none was recorded.
int64 all_start_micros() const { return all_start_micros_; }
|
||||
// Largest op_end_rel_micros recorded so far (each rounded up to at least 1).
int64 latest_end_rel_micros() const { return latest_end_rel_micros_; }
|
||||
|
||||
// Sum of the per-output byte sizes recorded by AddMemoryStats.
int64 requested_bytes() const { return requested_bytes_; }
|
||||
// Accumulated device_temp_memory_size from the step's memory_stats.
int64 accelerator_temp_bytes() const { return accelerator_temp_bytes_; }
|
||||
// Accumulated host_temp_memory_size from the step's memory_stats.
int64 host_temp_bytes() const { return host_temp_bytes_; }
|
||||
// Accumulated device_persistent_memory_size from the step's memory_stats.
int64 accelerator_persistent_bytes() const {
|
||||
return accelerator_persistent_bytes_;
|
||||
}
|
||||
// Accumulated host_persistent_memory_size from the step's memory_stats.
int64 host_persistent_bytes() const { return host_persistent_bytes_; }
|
||||
// output slot -> {output_bytes, memory_ptr}.
const std::map<int64, std::pair<int64, uint64>>& output_bytes() const {
|
||||
return output_bytes_;
|
||||
}
|
||||
// Largest allocator_bytes_in_use seen among allocators whose name contains
// "GPU".
int64 allocator_bytes_in_use() const { return allocator_bytes_in_use_; }
|
||||
|
||||
private:
|
||||
// The earliest/latest time including scheduling and kernel execution.
|
||||
int64 all_start_micros_;
|
||||
int64 latest_end_rel_micros_;
|
||||
// device -> vector of {op_start_micros, op_kernel_exec_micros} pairs.
|
||||
std::map<string, std::vector<std::pair<int64, int64>>> gpu_kernel_execs_;
|
||||
std::map<string, std::vector<std::pair<int64, int64>>> op_execs_;
|
||||
// All devices the op is associated with (e.g. gpu:0 (scheduling),
|
||||
// gpu:0:stream:xx (kernel exec), cpu:0 host)
|
||||
std::set<string> devices_;
|
||||
|
||||
bool mem_initiated_;
|
||||
// Total output bytes requested by the op.
|
||||
int64 requested_bytes_;
|
||||
// Total temporary bytes allocated and released by the op.
|
||||
int64 host_temp_bytes_;
|
||||
// Total persistent bytes (e.g. variable) allocated by the op.
|
||||
int64 host_persistent_bytes_;
|
||||
int64 accelerator_temp_bytes_;
|
||||
int64 accelerator_persistent_bytes_;
|
||||
// The total number of bytes currently allocated by the allocator if >0.
|
||||
int64 allocator_bytes_in_use_;
|
||||
// output_idx -> {output_bytes, memory_ptr}
|
||||
std::map<int64, std::pair<int64, uint64>> output_bytes_;
|
||||
};
|
||||
|
||||
class TFGraphNode {
|
||||
public:
|
||||
TFGraphNode(const NodeDef* node)
|
||||
: node_(node), float_ops_(0), op_(node->op()) {
|
||||
for (const auto& attr : node->attr()) {
|
||||
// TODO(xpan): Also consider _output_shapes.
|
||||
if (attr.first != "shape" || !attr.second.has_shape()) continue;
|
||||
@ -82,67 +222,123 @@ class TFGraphNode {
|
||||
|
||||
void AddOpType(const string& op_type) { op_types_.insert(op_type); }
|
||||
|
||||
void AddStepStat(const string& device, const NodeExecStats* step_stat);
|
||||
void AddStepStat(int64 step, const string& device,
|
||||
const NodeExecStats& step_stat);
|
||||
|
||||
void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; }
|
||||
|
||||
void AddCode(const CodeDef* code) { code_ = code; }
|
||||
// TODO(xpan): This could take a lot of memory.
|
||||
void AddCode(const CodeDef& code) { code_.MergeFrom(code); }
|
||||
|
||||
const string& name() const { return node_->name(); }
|
||||
const string& op() const { return op_; }
|
||||
const NodeDef* node_def() { return node_; }
|
||||
|
||||
const NodeExecStats* step_stats() const { return step_stat_; }
|
||||
// Reports whether this node can be tracked for `step`: it must have an
// ExecStep recorded for that step with a non-zero all_start_micros(), and
// both canonical_device_ and host_device_ must be set.
bool trackable(int64 step) const {
  const auto step_it = execs_.find(step);
  const bool has_timing = step_it != execs_.end() &&
                          step_it->second.all_start_micros() != 0;
  const bool has_devices =
      !canonical_device_.empty() && !host_device_.empty();
  return has_timing && has_devices;
}
|
||||
|
||||
const std::map<string, TFGraphNode*>& inputs() const { return inputs_; }
|
||||
const std::map<string, int64>& output_idx() const { return output_idx_; }
|
||||
|
||||
// This is time spent in kernel execution.
|
||||
int64 kernel_exec_micros() const {
|
||||
if (!step_stat_) return 0;
|
||||
int64 total = 0;
|
||||
for (const auto& execs : gpu_kernel_execs_) {
|
||||
for (const auto& exec : execs.second) {
|
||||
total += exec.second;
|
||||
}
|
||||
int64 kernel_exec_micros(int64 step) const {
|
||||
if (execs_.empty()) {
|
||||
return 0;
|
||||
}
|
||||
if (total > 0) return total;
|
||||
|
||||
// If there is no gpu kernel time, fall back to assume it runs on cpu.
|
||||
for (const auto& execs : op_execs_) {
|
||||
for (const auto& exec : execs.second) {
|
||||
total += exec.second;
|
||||
}
|
||||
if (step >= 0) {
|
||||
auto exec = execs_.find(step);
|
||||
CHECK(exec != execs_.end());
|
||||
return exec->second.exec_micros();
|
||||
}
|
||||
return total;
|
||||
|
||||
int64 total_micros = 0;
|
||||
for (const auto& exec : execs_) {
|
||||
total_micros += exec.second.exec_micros();
|
||||
}
|
||||
return total_micros / execs_.size();
|
||||
}
|
||||
|
||||
int64 all_start_micros() const { return all_start_micros_; }
|
||||
int64 latest_end_rel_micros() const { return latest_end_rel_micros_; }
|
||||
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs()
|
||||
const {
|
||||
return op_execs_;
|
||||
// Returns the output bytes requested by this node.
//
// Args:
//   step: when >= 0, the step whose ExecStep::requested_bytes() is returned;
//         when < 0, the integer average over all recorded steps is returned.
// Returns 0 if no step has been recorded for this node, or if `step` >= 0
// and this node has no stats for that particular step.
int64 requested_bytes(int64 step) const {
  if (execs_.empty()) {
    return 0;
  }
  if (step >= 0) {
    const auto exec = execs_.find(step);
    // A node may have stats for some steps but not the requested one (e.g.
    // only a subset of the graph ran in that step). Report no bytes instead
    // of aborting the process.
    if (exec == execs_.end()) {
      return 0;
    }
    return exec->second.requested_bytes();
  }

  int64 total_bytes = 0;
  for (const auto& exec : execs_) {
    total_bytes += exec.second.requested_bytes();
  }
  // Keep the division signed; execs_ is non-empty here.
  return total_bytes / static_cast<int64>(execs_.size());
}
|
||||
|
||||
int64 requested_bytes() const { return requested_bytes_; }
|
||||
int64 accelerator_temp_bytes() const { return accelerator_temp_bytes_; }
|
||||
int64 host_temp_bytes() const { return host_temp_bytes_; }
|
||||
int64 accelerator_persistent_bytes() const {
|
||||
return accelerator_persistent_bytes_;
|
||||
// Returns ExecStep::all_start_micros() for `step`.
// CHECK-fails if no stats were recorded for that step.
int64 all_start_micros(int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.all_start_micros();
}
|
||||
int64 host_persistent_bytes() const { return host_persistent_bytes_; }
|
||||
const std::map<int64, std::pair<int64, uint64>>& output_bytes() const {
|
||||
return output_bytes_;
|
||||
|
||||
// Returns ExecStep::latest_end_rel_micros() for `step`.
// CHECK-fails if no stats were recorded for that step.
int64 latest_end_rel_micros(int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.latest_end_rel_micros();
}
|
||||
|
||||
// Returns the device -> {start_micros, exec_micros} samples recorded for
// `step`. CHECK-fails if no stats were recorded for that step.
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs(
    int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.op_execs();
}
|
||||
|
||||
// Returns ExecStep::accelerator_temp_bytes() for `step`.
// CHECK-fails if no stats were recorded for that step.
int64 accelerator_temp_bytes(int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.accelerator_temp_bytes();
}
|
||||
// Returns ExecStep::host_temp_bytes() for `step`.
// CHECK-fails if no stats were recorded for that step.
int64 host_temp_bytes(int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.host_temp_bytes();
}
|
||||
// Returns ExecStep::accelerator_persistent_bytes() for `step`.
// CHECK-fails if no stats were recorded for that step.
int64 accelerator_persistent_bytes(int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.accelerator_persistent_bytes();
}
|
||||
// Returns ExecStep::host_persistent_bytes() for `step`.
// CHECK-fails if no stats were recorded for that step.
int64 host_persistent_bytes(int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.host_persistent_bytes();
}
|
||||
// Returns the output slot -> {output_bytes, memory_ptr} map recorded for
// `step`. CHECK-fails if no stats were recorded for that step.
const std::map<int64, std::pair<int64, uint64>>& output_bytes(
    int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.output_bytes();
}
|
||||
// Returns ExecStep::allocator_bytes_in_use() for `step`.
// CHECK-fails if no stats were recorded for that step.
int64 allocator_bytes_in_use(int64 step) const {
  const auto step_it = execs_.find(step);
  CHECK(step_it != execs_.end()) << "unknown step " << step;
  const ExecStep& step_stats = step_it->second;
  return step_stats.allocator_bytes_in_use();
}
|
||||
int64 allocator_bytes_in_use() const { return allocator_bytes_in_use_; }
|
||||
|
||||
int64 float_ops() const { return float_ops_; }
|
||||
const CodeDef* code() { return code_; }
|
||||
const CodeDef& code() { return code_; }
|
||||
string canonical_device() const { return canonical_device_; }
|
||||
string host_device() const { return host_device_; }
|
||||
std::set<string> devices() const { return devices_; }
|
||||
const std::set<string>& op_types() const { return op_types_; }
|
||||
|
||||
const std::vector<int64>& shape() const { return shape_; }
|
||||
|
||||
private:
|
||||
@ -152,39 +348,18 @@ class TFGraphNode {
|
||||
std::map<string, int64> output_idx_;
|
||||
|
||||
const NodeDef* node_;
|
||||
const CodeDef* code_;
|
||||
const NodeExecStats* step_stat_;
|
||||
|
||||
CodeDef code_;
|
||||
|
||||
std::vector<int64> shape_;
|
||||
std::set<string> op_types_;
|
||||
|
||||
// The earliest/latest time including scheduling and kernel execution.
|
||||
int64 all_start_micros_;
|
||||
int64 latest_end_rel_micros_;
|
||||
// device -> vector of {op_start_micros, op_kernel_exec_micros} pairs.
|
||||
std::map<string, std::vector<std::pair<int64, int64>>> gpu_kernel_execs_;
|
||||
std::map<string, std::vector<std::pair<int64, int64>>> op_execs_;
|
||||
std::map<int64, ExecStep> execs_;
|
||||
|
||||
// /j:#/t:#/r:#/device:#. A canonical device name without extra suffix.
|
||||
string canonical_device_;
|
||||
// The host device name.
|
||||
string host_device_;
|
||||
// All devices the op is associated with (e.g. gpu:0 (scheduling),
|
||||
// gpu:0:stream:xx (kernel exec), cpu:0 host)
|
||||
std::set<string> devices_;
|
||||
|
||||
// Total output bytes requested by the op.
|
||||
int64 requested_bytes_;
|
||||
// Total temporary bytes allocated and released by the op.
|
||||
int64 host_temp_bytes_;
|
||||
// Total persistent bytes (e.g. variable) allocated by the op.
|
||||
int64 host_persistent_bytes_;
|
||||
int64 accelerator_temp_bytes_;
|
||||
int64 accelerator_persistent_bytes_;
|
||||
// The total number of bytes currently allocated by the allocator if >0.
|
||||
int64 allocator_bytes_in_use_;
|
||||
// output_idx -> {output_bytes, memory_ptr}
|
||||
std::map<int64, std::pair<int64, uint64>> output_bytes_;
|
||||
|
||||
int64 float_ops_;
|
||||
|
||||
@ -199,7 +374,7 @@ class TFMultiGraphNode {
|
||||
requested_bytes_(0),
|
||||
float_ops_(0) {}
|
||||
|
||||
bool SnapshotNodes(const std::vector<string>& type_regexes) {
|
||||
bool SnapshotNodes(int64 step, const std::vector<string>& type_regexes) {
|
||||
kernel_exec_micros_ = 0;
|
||||
requested_bytes_ = 0;
|
||||
float_ops_ = 0;
|
||||
@ -208,30 +383,23 @@ class TFMultiGraphNode {
|
||||
devices_.clear();
|
||||
snapshot_nodes_.clear();
|
||||
|
||||
std::map<string, std::vector<const TFGraphNode*>> nodes =
|
||||
pick_nodes(type_regexes);
|
||||
std::vector<const TFGraphNode*> nodes = pick_nodes(type_regexes);
|
||||
|
||||
if (nodes.empty()) {
|
||||
return (type_regexes.size() == 1 && type_regexes[0] == ".*");
|
||||
}
|
||||
|
||||
std::set<string> visits;
|
||||
for (const auto& entry : nodes) {
|
||||
op_types_.insert(entry.first);
|
||||
for (const TFGraphNode* node : nodes) {
|
||||
op_types_.insert(node->op_types().begin(), node->op_types().end());
|
||||
|
||||
for (const TFGraphNode* node : entry.second) {
|
||||
if (visits.find(node->name()) != visits.end()) continue;
|
||||
visits.insert(node->name());
|
||||
|
||||
kernel_exec_micros_ += node->kernel_exec_micros();
|
||||
requested_bytes_ += node->requested_bytes();
|
||||
float_ops_ += node->float_ops();
|
||||
if (node->shape().size() > 0) {
|
||||
shapes_.push_back(node->shape());
|
||||
}
|
||||
devices_.insert(node->canonical_device());
|
||||
snapshot_nodes_[node->name()] = node;
|
||||
kernel_exec_micros_ += node->kernel_exec_micros(step);
|
||||
requested_bytes_ += node->requested_bytes(step);
|
||||
float_ops_ += node->float_ops();
|
||||
if (node->shape().size() > 0) {
|
||||
shapes_.push_back(node->shape());
|
||||
}
|
||||
devices_.insert(node->canonical_device());
|
||||
snapshot_nodes_[node->name()] = node;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -241,9 +409,6 @@ class TFMultiGraphNode {
|
||||
return;
|
||||
}
|
||||
nodes_[node->name()] = node;
|
||||
for (const string& type : node->op_types()) {
|
||||
nodes_by_type_[type].push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
const std::map<string, const TFGraphNode*>& graph_nodes() const {
|
||||
@ -275,19 +440,26 @@ class TFMultiGraphNode {
|
||||
const std::vector<std::vector<int64>>& shapes() const { return shapes_; }
|
||||
|
||||
private:
|
||||
std::map<string, std::vector<const TFGraphNode*>> pick_nodes(
|
||||
std::vector<const TFGraphNode*> pick_nodes(
|
||||
const std::vector<string>& type_regexes) {
|
||||
if (type_regexes.empty()) {
|
||||
return {};
|
||||
}
|
||||
std::vector<const TFGraphNode*> ret;
|
||||
if (type_regexes.size() == 1 && type_regexes[0] == ".*") {
|
||||
return nodes_by_type_;
|
||||
for (const auto& n : nodes_) {
|
||||
ret.push_back(n.second);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
std::map<string, std::vector<const TFGraphNode*>> ret;
|
||||
|
||||
for (const string& regex : type_regexes) {
|
||||
for (const auto& n : nodes_by_type_) {
|
||||
if (RE2::FullMatch(n.first, regex)) {
|
||||
ret[n.first] = n.second;
|
||||
for (const auto& n : nodes_) {
|
||||
for (const string& type : n.second->op_types()) {
|
||||
if (RE2::FullMatch(type, regex)) {
|
||||
ret.push_back(n.second);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -295,7 +467,7 @@ class TFMultiGraphNode {
|
||||
}
|
||||
|
||||
const string name_;
|
||||
// Snapshot micros based on type_regexes
|
||||
// Snapshot based on type_regexes
|
||||
std::set<string> op_types_;
|
||||
int64 kernel_exec_micros_;
|
||||
int64 requested_bytes_;
|
||||
@ -306,7 +478,6 @@ class TFMultiGraphNode {
|
||||
|
||||
// Overall data held by the TFMultiGraphNode.
|
||||
std::map<string, const TFGraphNode*> nodes_;
|
||||
std::map<string, std::vector<const TFGraphNode*>> nodes_by_type_;
|
||||
std::map<string, std::unique_ptr<TFMultiGraphNode>> children_;
|
||||
};
|
||||
} // namespace tfprof
|
||||
|
@ -22,18 +22,20 @@ namespace tfprof {
|
||||
namespace {}
|
||||
|
||||
ShowNode::ShowNode(const TFGraphNode* node) : node(node), account(false) {
|
||||
ReInit();
|
||||
ReInit(-1);
|
||||
}
|
||||
|
||||
void ShowNode::ReInit() {
|
||||
void ShowNode::ReInit(int64 step) {
|
||||
mutable_proto()->set_name(name());
|
||||
mutable_proto()->clear_devices();
|
||||
if (!node->canonical_device().empty()) {
|
||||
mutable_proto()->add_devices(node->canonical_device());
|
||||
}
|
||||
mutable_proto()->set_exec_micros(node->kernel_exec_micros());
|
||||
mutable_proto()->set_requested_bytes(node->requested_bytes());
|
||||
mutable_proto()->set_exec_micros(node->kernel_exec_micros(step));
|
||||
mutable_proto()->set_requested_bytes(node->requested_bytes(step));
|
||||
mutable_proto()->set_float_ops(node->float_ops());
|
||||
|
||||
proto_.clear_parameters();
|
||||
if (!node->shape().empty()) {
|
||||
int64 params = 1;
|
||||
bool complete_shape = true;
|
||||
@ -90,17 +92,19 @@ void ShowNode::ResetTotalStats() {
|
||||
|
||||
ShowMultiNode::ShowMultiNode(TFMultiGraphNode* node)
|
||||
: node(node), account(false), show(false) {
|
||||
ReInit({".*"});
|
||||
ReInit(-1, {".*"});
|
||||
}
|
||||
|
||||
bool ShowMultiNode::ReInit(const std::vector<string>& type_regexes) {
|
||||
bool has_matched_type = node->SnapshotNodes(type_regexes);
|
||||
bool ShowMultiNode::ReInit(int64 step,
|
||||
const std::vector<string>& type_regexes) {
|
||||
bool has_matched_type = node->SnapshotNodes(step, type_regexes);
|
||||
|
||||
std::vector<ShowNode> snodes;
|
||||
mutable_proto()->mutable_graph_nodes()->Clear();
|
||||
for (auto it : node->graph_nodes()) {
|
||||
ShowNode snode(it.second);
|
||||
snodes.push_back(snode);
|
||||
snodes.back().ReInit(step);
|
||||
snodes.back().AddSelfToTotalStats();
|
||||
mutable_proto()->add_graph_nodes()->MergeFrom(snodes.back().proto());
|
||||
}
|
||||
@ -110,7 +114,7 @@ bool ShowMultiNode::ReInit(const std::vector<string>& type_regexes) {
|
||||
mutable_proto()->set_requested_bytes(node->requested_bytes());
|
||||
mutable_proto()->set_float_ops(node->float_ops());
|
||||
|
||||
mutable_proto()->set_parameters(0);
|
||||
mutable_proto()->clear_parameters();
|
||||
if (!node->shapes().empty()) {
|
||||
for (const std::vector<int64>& shape : node->shapes()) {
|
||||
int64 params = 1;
|
||||
|
@ -48,7 +48,7 @@ class ShowNode {
|
||||
TFGraphNodeProto* mutable_proto();
|
||||
const TFGraphNodeProto& proto() const;
|
||||
|
||||
void ReInit();
|
||||
void ReInit(int64 step);
|
||||
|
||||
void AggregateTotalStats(ShowNode* node);
|
||||
|
||||
@ -66,24 +66,10 @@ class ShowNode {
|
||||
|
||||
class GraphNode : public ShowNode {
|
||||
public:
|
||||
explicit GraphNode(TFGraphNode* node) : ShowNode(node) {
|
||||
trackable = Trackable();
|
||||
}
|
||||
explicit GraphNode(TFGraphNode* node) : ShowNode(node) {}
|
||||
|
||||
void ReInit() {
|
||||
ShowNode::ReInit();
|
||||
}
|
||||
bool Trackable(int64 step) { return node->trackable(step); }
|
||||
|
||||
bool Trackable() {
|
||||
if (!node->step_stats()) return false;
|
||||
if (node->all_start_micros() == 0) return false;
|
||||
if (node->canonical_device().empty() || node->host_device().empty()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool trackable;
|
||||
std::vector<GraphNode*> children;
|
||||
std::vector<GraphNode*> show_children;
|
||||
};
|
||||
@ -102,7 +88,7 @@ class ShowMultiNode {
|
||||
explicit ShowMultiNode(TFMultiGraphNode* node);
|
||||
virtual ~ShowMultiNode() {}
|
||||
|
||||
bool ReInit(const std::vector<string>& type_regexes);
|
||||
bool ReInit(int64 step, const std::vector<string>& type_regexes);
|
||||
|
||||
const string& name() const { return node->name(); }
|
||||
TFMultiGraphNodeProto* mutable_proto();
|
||||
|
@ -146,7 +146,7 @@ tensorflow::Status Options::FromProtoStr(const string& opts_proto_str,
|
||||
*opts = Options(
|
||||
opts_pb.max_depth(), opts_pb.min_bytes(), opts_pb.min_micros(),
|
||||
opts_pb.min_params(), opts_pb.min_float_ops(), opts_pb.min_occurrence(),
|
||||
opts_pb.order_by(),
|
||||
opts_pb.step(), opts_pb.order_by(),
|
||||
std::vector<string>(opts_pb.account_type_regexes().begin(),
|
||||
opts_pb.account_type_regexes().end()),
|
||||
std::vector<string>(opts_pb.start_name_regexes().begin(),
|
||||
@ -171,6 +171,7 @@ string Options::ToString() const {
|
||||
"%-28s%lld\n"
|
||||
"%-28s%lld\n"
|
||||
"%-28s%lld\n"
|
||||
"%-28s%lld\n"
|
||||
"%-28s%s\n"
|
||||
"%-28s%s\n"
|
||||
"%-28s%s\n"
|
||||
@ -182,15 +183,15 @@ string Options::ToString() const {
|
||||
"%-28s%s:%s\n",
|
||||
kOptions[0], max_depth, kOptions[1], min_bytes, kOptions[2], min_micros,
|
||||
kOptions[3], min_params, kOptions[4], min_float_ops, kOptions[5],
|
||||
min_occurrence, kOptions[6], order_by.c_str(), kOptions[7],
|
||||
str_util::Join(account_type_regexes, ",").c_str(), kOptions[8],
|
||||
str_util::Join(start_name_regexes, ",").c_str(), kOptions[9],
|
||||
str_util::Join(trim_name_regexes, ",").c_str(), kOptions[10],
|
||||
str_util::Join(show_name_regexes, ",").c_str(), kOptions[11],
|
||||
str_util::Join(hide_name_regexes, ",").c_str(), kOptions[12],
|
||||
(account_displayed_op_only ? "true" : "false"), kOptions[13],
|
||||
str_util::Join(select, ",").c_str(), kOptions[14], output_type.c_str(),
|
||||
KeyValueToStr(output_options).c_str());
|
||||
min_occurrence, kOptions[6], step, kOptions[7], order_by.c_str(),
|
||||
kOptions[8], str_util::Join(account_type_regexes, ",").c_str(),
|
||||
kOptions[9], str_util::Join(start_name_regexes, ",").c_str(),
|
||||
kOptions[10], str_util::Join(trim_name_regexes, ",").c_str(),
|
||||
kOptions[11], str_util::Join(show_name_regexes, ",").c_str(),
|
||||
kOptions[12], str_util::Join(hide_name_regexes, ",").c_str(),
|
||||
kOptions[13], (account_displayed_op_only ? "true" : "false"),
|
||||
kOptions[14], str_util::Join(select, ",").c_str(), kOptions[15],
|
||||
output_type.c_str(), KeyValueToStr(output_options).c_str());
|
||||
return s;
|
||||
}
|
||||
|
||||
|
@ -33,6 +33,7 @@ static const char* const kOptions[] = {
|
||||
"-min_params",
|
||||
"-min_float_ops",
|
||||
"-min_occurrence",
|
||||
"-step",
|
||||
"-order_by",
|
||||
"-account_type_regexes",
|
||||
"-start_name_regexes",
|
||||
@ -81,12 +82,13 @@ struct Options {
|
||||
|
||||
virtual ~Options() {}
|
||||
Options()
|
||||
: Options(0, 0, 0, 0, 0, 0, "", {}, {}, {}, {}, {}, false, {}, "", {}) {}
|
||||
: Options(0, 0, 0, 0, 0, 0, 0, "", {}, {}, {}, {}, {}, false, {}, "",
|
||||
{}) {}
|
||||
|
||||
Options(int max_depth, tensorflow::int64 min_bytes,
|
||||
tensorflow::int64 min_micros, tensorflow::int64 min_params,
|
||||
tensorflow::int64 min_float_ops, tensorflow::int64 min_occurrence,
|
||||
const string& order_by,
|
||||
tensorflow::int64 step, const string& order_by,
|
||||
const std::vector<string>& account_type_regexes,
|
||||
const std::vector<string>& start_name_regexes,
|
||||
const std::vector<string>& trim_name_regexes,
|
||||
@ -101,6 +103,7 @@ struct Options {
|
||||
min_params(min_params),
|
||||
min_float_ops(min_float_ops),
|
||||
min_occurrence(min_occurrence),
|
||||
step(step),
|
||||
order_by(order_by),
|
||||
account_type_regexes(account_type_regexes),
|
||||
start_name_regexes(start_name_regexes),
|
||||
@ -120,6 +123,7 @@ struct Options {
|
||||
tensorflow::int64 min_params;
|
||||
tensorflow::int64 min_float_ops;
|
||||
tensorflow::int64 min_occurrence;
|
||||
tensorflow::int64 step;
|
||||
string order_by;
|
||||
|
||||
std::vector<string> account_type_regexes;
|
||||
|
@ -196,7 +196,7 @@ std::vector<ScopeNode*> TFScope::Account(const std::vector<ScopeNode*>& roots,
|
||||
node->ResetTotalStats();
|
||||
std::vector<ScopeNode*> act_cnodes = Account(node->children, opts);
|
||||
|
||||
node->account = ShouldAccount(node, opts);
|
||||
node->account = ReAccount(node, opts);
|
||||
if (node->account || !act_cnodes.empty()) {
|
||||
node->show_children.clear();
|
||||
node->ResetTotalStats();
|
||||
|
@ -27,7 +27,7 @@ namespace tfprof {
|
||||
|
||||
const TFGraphNodeProto& TFShow::Show(const Options& opts) {
|
||||
if (opts.output_type == kOutput[0]) {
|
||||
Timeline timeline(opts.output_options.at(kTimelineOpts[0]));
|
||||
Timeline timeline(opts.step, opts.output_options.at(kTimelineOpts[0]));
|
||||
return ShowInternal(opts, &timeline)->proto();
|
||||
} else if (opts.output_type == kOutput[2]) {
|
||||
const ShowNode* root = ShowInternal(opts, nullptr);
|
||||
@ -105,7 +105,8 @@ bool TFShow::ShouldTrim(ShowNode* node, const std::vector<string>& regexes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TFShow::ShouldAccount(ShowNode* node, const Options& opts) {
|
||||
bool TFShow::ReAccount(ShowNode* node, const Options& opts) {
|
||||
node->ReInit(opts.step);
|
||||
if (opts.account_type_regexes.size() == 1 &&
|
||||
opts.account_type_regexes[0] == ".*") {
|
||||
return true;
|
||||
|
@ -63,7 +63,7 @@ class TFShow {
|
||||
|
||||
bool ShouldTrim(ShowNode* node, const std::vector<string>& regexes);
|
||||
|
||||
bool ShouldAccount(ShowNode* node, const Options& opts);
|
||||
bool ReAccount(ShowNode* node, const Options& opts);
|
||||
|
||||
string FormatNode(ShowNode* node, const Options& opts);
|
||||
|
||||
|
@ -29,7 +29,7 @@ namespace tfprof {
|
||||
|
||||
const TFMultiGraphNodeProto& TFMultiShow::Show(const Options& opts) {
|
||||
if (opts.output_type == kOutput[0]) {
|
||||
Timeline timeline(opts.output_options.at(kTimelineOpts[0]));
|
||||
Timeline timeline(opts.step, opts.output_options.at(kTimelineOpts[0]));
|
||||
return ShowInternal(opts, &timeline)->proto();
|
||||
} else if (opts.output_type == kOutput[2]) {
|
||||
const ShowMultiNode* root = ShowInternal(opts, nullptr);
|
||||
@ -99,7 +99,7 @@ bool TFMultiShow::ShouldTrim(ShowMultiNode* node,
|
||||
}
|
||||
|
||||
bool TFMultiShow::ReAccount(ShowMultiNode* node, const Options& opts) {
|
||||
return node->ReInit(opts.account_type_regexes);
|
||||
return node->ReInit(opts.step, opts.account_type_regexes);
|
||||
}
|
||||
|
||||
string TFMultiShow::FormatLegend(const Options& opts) {
|
||||
|
@ -71,7 +71,8 @@ class TFProfShowTest : public ::testing::Test {
|
||||
|
||||
TEST_F(TFProfShowTest, DumpScopeMode) {
|
||||
string dump_file = io::JoinPath(testing::TmpDir(), "dump");
|
||||
Options opts(5, 0, 0, 0, 0, 0, "name", {"VariableV2"}, // accout_type_regexes
|
||||
Options opts(5, 0, 0, 0, 0, 0, -1, "name",
|
||||
{"VariableV2"}, // accout_type_regexes
|
||||
{".*"}, {""}, {".*"}, {""}, false,
|
||||
{"params", "bytes", "micros", "float_ops"}, "file",
|
||||
{{"outfile", dump_file}});
|
||||
@ -93,7 +94,7 @@ TEST_F(TFProfShowTest, DumpScopeMode) {
|
||||
|
||||
TEST_F(TFProfShowTest, DumpOpMode) {
|
||||
string dump_file = io::JoinPath(testing::TmpDir(), "dump");
|
||||
Options opts(5, 0, 0, 0, 0, 4, "params", {".*"}, // accout_type_regexes
|
||||
Options opts(5, 0, 0, 0, 0, 4, -1, "params", {".*"}, // accout_type_regexes
|
||||
{".*"}, {""}, {".*"}, {""}, false,
|
||||
{"params", "bytes", "micros", "float_ops", "occurrence"}, "file",
|
||||
{{"outfile", dump_file}});
|
||||
|
@ -30,24 +30,17 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
|
||||
std::unique_ptr<OpLog> op_log,
|
||||
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader)
|
||||
: graph_(std::move(graph)),
|
||||
run_meta_(std::move(run_meta)),
|
||||
op_log_(std::move(op_log)),
|
||||
ckpt_reader_(std::move(ckpt_reader)) {
|
||||
CHECK(graph_) << "Must at least have GraphDef";
|
||||
|
||||
printf("Parsing GraphDef...\n");
|
||||
printf("Parsing Inputs...\n");
|
||||
ParseGraph();
|
||||
if (run_meta_) {
|
||||
printf("Parsing RunMetadata...\n");
|
||||
ParseRunMeta();
|
||||
}
|
||||
if (op_log_) {
|
||||
printf("Parsing OpLog...\n");
|
||||
ParseOpLog();
|
||||
if (run_meta && run_meta->has_step_stats()) {
|
||||
ParseRunMeta(0, std::move(run_meta));
|
||||
}
|
||||
ParseOpLog(std::move(op_log));
|
||||
|
||||
if (ckpt_reader_) {
|
||||
printf("Parsing Checkpoint...\n");
|
||||
for (const auto& v : ckpt_reader_->GetVariableToShapeMap()) {
|
||||
auto node = nodes_map_.find(v.first);
|
||||
if (node != nodes_map_.end()) {
|
||||
@ -76,6 +69,9 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
|
||||
|
||||
const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
|
||||
const Options& opts) {
|
||||
if (!Validate(opts)) {
|
||||
return empty_graph_node_;
|
||||
}
|
||||
if (cmd == kCmds[0]) {
|
||||
return scope_view_->Show(opts);
|
||||
} else if (cmd == kCmds[1]) {
|
||||
@ -88,6 +84,9 @@ const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
|
||||
|
||||
const TFMultiGraphNodeProto& TFStats::ShowMultiGraphNode(const string& cmd,
|
||||
const Options& opts) {
|
||||
if (!Validate(opts)) {
|
||||
return empty_multi_graph_node_;
|
||||
}
|
||||
if (cmd == kCmds[2]) {
|
||||
return code_view_->Show(opts);
|
||||
} else if (cmd == kCmds[3]) {
|
||||
@ -130,8 +129,11 @@ void TFStats::ParseGraph() {
|
||||
}
|
||||
}
|
||||
|
||||
void TFStats::ParseOpLog() {
|
||||
for (const OpLogEntry& entry : op_log_->log_entries()) {
|
||||
void TFStats::ParseOpLog(std::unique_ptr<OpLog> op_log) {
|
||||
if (!op_log) {
|
||||
return;
|
||||
}
|
||||
for (const OpLogEntry& entry : op_log->log_entries()) {
|
||||
auto node = nodes_map_.find(entry.name());
|
||||
if (node == nodes_map_.end()) continue;
|
||||
for (const string& type : entry.types()) {
|
||||
@ -141,16 +143,24 @@ void TFStats::ParseOpLog() {
|
||||
node->second->AddFloatOps(entry.float_ops());
|
||||
}
|
||||
if (entry.has_code_def()) {
|
||||
node->second->AddCode(&entry.code_def());
|
||||
node->second->AddCode(entry.code_def());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TFStats::ParseRunMeta() {
|
||||
if (!run_meta_->has_step_stats()) return;
|
||||
void TFStats::ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) {
|
||||
if (!run_meta || !run_meta->has_step_stats()) {
|
||||
fprintf(stderr, "Invalid RunMetadata for step %lld\n", step);
|
||||
return;
|
||||
}
|
||||
if (steps_.find(step) != steps_.end()) {
|
||||
fprintf(stderr, "The same step %lld has been added before.\n", step);
|
||||
return;
|
||||
}
|
||||
steps_.insert(step);
|
||||
|
||||
for (const auto& dev_stat : run_meta_->step_stats().dev_stats()) {
|
||||
for (const auto& node_stat : dev_stat.node_stats()) {
|
||||
for (const auto& dev_stat : run_meta->step_stats().dev_stats()) {
|
||||
for (const NodeExecStats& node_stat : dev_stat.node_stats()) {
|
||||
string name = node_stat.node_name();
|
||||
// Sometimes the node_name is suffixed with unnecessary information.
|
||||
auto split_pos = node_stat.node_name().find(":");
|
||||
@ -159,10 +169,18 @@ void TFStats::ParseRunMeta() {
|
||||
}
|
||||
auto node = nodes_map_.find(name);
|
||||
if (node != nodes_map_.end()) {
|
||||
node->second->AddStepStat(dev_stat.device(), &node_stat);
|
||||
node->second->AddStepStat(step, dev_stat.device(), node_stat);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool TFStats::Validate(const Options& opts) {
|
||||
if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) {
|
||||
fprintf(stderr, "Options -step=%lld not found\n", opts.step);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace tfprof
|
||||
} // namespace tensorflow
|
||||
|
@ -62,20 +62,20 @@ class TFStats {
|
||||
const TFMultiGraphNodeProto& ShowMultiGraphNode(const string& cmd,
|
||||
const Options& opts);
|
||||
|
||||
void ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta);
|
||||
void ParseOpLog(std::unique_ptr<OpLog> op_log);
|
||||
|
||||
private:
|
||||
bool Validate(const Options& opts);
|
||||
|
||||
void ParseGraph();
|
||||
|
||||
void ParseOpLog();
|
||||
|
||||
void ParseRunMeta();
|
||||
|
||||
std::set<int64> steps_;
|
||||
std::unique_ptr<GraphDef> graph_;
|
||||
std::unique_ptr<TFScope> scope_view_;
|
||||
std::unique_ptr<TFGraph> graph_view_;
|
||||
std::unique_ptr<TFCode> code_view_;
|
||||
std::unique_ptr<TFOp> op_view_;
|
||||
std::unique_ptr<GraphDef> graph_;
|
||||
std::unique_ptr<RunMetadata> run_meta_;
|
||||
std::unique_ptr<OpLog> op_log_;
|
||||
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader_;
|
||||
// Store TFGraphNode instead of TFGraphNode* to avoid large number of
|
||||
// dynamic alloc.
|
||||
|
@ -71,7 +71,7 @@ class TFProfStatsTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
TEST_F(TFProfStatsTest, CustomOpType) {
|
||||
Options opts(3, 0, 0, 0, 0, 0, "name",
|
||||
Options opts(3, 0, 0, 0, 0, 0, -1, "name",
|
||||
{kTrainableVarType}, // accout_type_regexes
|
||||
{".*"}, {""}, {".*"}, {""}, false,
|
||||
{"params", "bytes", "micros", "float_ops"}, "", {});
|
||||
@ -113,7 +113,8 @@ TEST_F(TFProfStatsTest, CustomOpType) {
|
||||
}
|
||||
|
||||
TEST_F(TFProfStatsTest, CheckPointOpType) {
|
||||
Options opts(3, 0, 0, 0, 0, 0, "name", {kCkptVarType}, // accout_type_regexes
|
||||
Options opts(3, 0, 0, 0, 0, 0, -1, "name",
|
||||
{kCkptVarType}, // accout_type_regexes
|
||||
{".*"}, {""}, {".*"}, {""}, false,
|
||||
{"params", "bytes", "micros", "float_ops"}, "", {});
|
||||
const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts);
|
||||
@ -154,7 +155,7 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
|
||||
}
|
||||
|
||||
TEST_F(TFProfStatsTest, TestGraph) {
|
||||
Options opts(100, 0, 10000, 0, 0, 0, "name", {".*"},
|
||||
Options opts(100, 0, 10000, 0, 0, 0, -1, "name", {".*"},
|
||||
{"cost.*"}, // start_name_regexes
|
||||
{""}, {".*"}, {""}, false,
|
||||
{"params", "bytes", "micros", "float_ops"}, "", {});
|
||||
@ -171,8 +172,8 @@ TEST_F(TFProfStatsTest, TestGraph) {
|
||||
}
|
||||
|
||||
TEST_F(TFProfStatsTest, TestFloatOps) {
|
||||
Options opts(10, 0, 0, 0, 1, 0, "name", {".*"}, {".*"}, {""}, {".*"}, {""},
|
||||
false, {"float_ops"}, "", {});
|
||||
Options opts(10, 0, 0, 0, 1, 0, -1, "name", {".*"}, {".*"}, {""}, {".*"},
|
||||
{""}, false, {"float_ops"}, "", {});
|
||||
const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts);
|
||||
|
||||
TFGraphNodeProto expected;
|
||||
@ -201,7 +202,7 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
|
||||
}
|
||||
|
||||
TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
|
||||
Options opts(100, 0, 0, 0, 0, 0, "name", {".*"}, {".*"}, {""},
|
||||
Options opts(100, 0, 0, 0, 0, 0, -1, "name", {".*"}, {".*"}, {""},
|
||||
{"unit_2_1.*DW"}, // show_name_regexes.
|
||||
{""}, true, // account_displayed_op_only.
|
||||
{"params"}, "", {});
|
||||
@ -217,7 +218,7 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
|
||||
}
|
||||
|
||||
TEST_F(TFProfStatsTest, TestShowTensorValue) {
|
||||
Options opts(10, 0, 0, 0, 0, 0, "name", {".*"}, {".*"}, {""},
|
||||
Options opts(10, 0, 0, 0, 0, 0, -1, "name", {".*"}, {".*"}, {""},
|
||||
{"unit_1_0.*gamma"}, {""}, false,
|
||||
{"tensor_value"}, // Show tensor value from checkpoint.
|
||||
"", {});
|
||||
|
@ -76,7 +76,8 @@ class TFProfTensor {
|
||||
CHECK(strings::safe_strto64(sstream.str().c_str(), &int64_val));
|
||||
dim->add_value_int64(int64_val);
|
||||
formatted_str_ += strings::Printf(
|
||||
"%lld ", dim->value_int64(dim->value_int64_size() - 1));
|
||||
"%lld ", static_cast<int64>(
|
||||
dim->value_int64(dim->value_int64_size() - 1)));
|
||||
} else if (typeid(values[nstart]) == typeid(string)) {
|
||||
dim->add_value_str(sstream.str());
|
||||
formatted_str_ =
|
||||
|
@ -55,8 +55,8 @@ class TFProfTensorTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
TEST_F(TFProfTensorTest, Basics) {
|
||||
Options opts(3, 0, 0, 0, 0, 0, "name", {"VariableV2"}, {".*"}, {""}, {".*"},
|
||||
{""}, false, {"tensor_value"}, // show the tensor value.
|
||||
Options opts(3, 0, 0, 0, 0, 0, -1, "name", {"VariableV2"}, {".*"}, {""},
|
||||
{".*"}, {""}, false, {"tensor_value"}, // show the tensor value.
|
||||
"", {});
|
||||
const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts);
|
||||
|
||||
|
@ -114,34 +114,42 @@ string ChromeTraceFormatter::Format() {
|
||||
return trace_str;
|
||||
}
|
||||
|
||||
void MemoryTracker::TrackNode(GraphNode* node) {
|
||||
void MemoryTracker::TrackNode(int64 step, GraphNode* node) {
|
||||
if (!node->Trackable(step)) {
|
||||
return;
|
||||
}
|
||||
Device& dev = devices_[node->node->canonical_device()];
|
||||
int64 end_micros =
|
||||
node->node->all_start_micros() + node->node->latest_end_rel_micros();
|
||||
if (node->node->accelerator_persistent_bytes() != 0) {
|
||||
int64 end_micros = node->node->all_start_micros(step) +
|
||||
node->node->latest_end_rel_micros(step);
|
||||
if (node->node->accelerator_persistent_bytes(step) != 0) {
|
||||
string tensor_name = strings::StrCat(node->name(), ":", -1);
|
||||
dev.earliest_ref[tensor_name] = node->node->all_start_micros();
|
||||
dev.tensor_size[tensor_name] = node->node->accelerator_persistent_bytes();
|
||||
dev.earliest_ref[tensor_name] = node->node->all_start_micros(step);
|
||||
dev.tensor_size[tensor_name] =
|
||||
node->node->accelerator_persistent_bytes(step);
|
||||
// TODO(xpan): Need latest_ref?
|
||||
}
|
||||
if (node->node->accelerator_temp_bytes()) {
|
||||
if (node->node->accelerator_temp_bytes(step)) {
|
||||
string tensor_name = strings::StrCat(node->name(), ":", -2);
|
||||
dev.earliest_ref[tensor_name] = node->node->all_start_micros();
|
||||
dev.earliest_ref[tensor_name] = node->node->all_start_micros(step);
|
||||
dev.latest_ref[tensor_name] = end_micros;
|
||||
dev.tensor_size[tensor_name] = node->node->accelerator_temp_bytes();
|
||||
dev.tensor_size[tensor_name] = node->node->accelerator_temp_bytes(step);
|
||||
}
|
||||
if (node->node->allocator_bytes_in_use() > 0) {
|
||||
dev.allocator_stats[end_micros] = node->node->allocator_bytes_in_use();
|
||||
if (node->node->allocator_bytes_in_use(step) > 0) {
|
||||
dev.allocator_stats[end_micros] = node->node->allocator_bytes_in_use(step);
|
||||
}
|
||||
}
|
||||
|
||||
void MemoryTracker::TrackNodeConnection(GraphNode* node, GraphNode* src) {
|
||||
void MemoryTracker::TrackNodeConnection(int64 step, GraphNode* node,
|
||||
GraphNode* src) {
|
||||
if (!node->Trackable(step) || !src->Trackable(step)) {
|
||||
return;
|
||||
}
|
||||
const auto& output_idx = node->node->output_idx().find(src->name());
|
||||
if (output_idx == node->node->output_idx().end()) {
|
||||
return;
|
||||
}
|
||||
const auto& output = src->node->output_bytes().find(output_idx->second);
|
||||
if (output == src->node->output_bytes().end()) {
|
||||
const auto& output = src->node->output_bytes(step).find(output_idx->second);
|
||||
if (output == src->node->output_bytes(step).end()) {
|
||||
return;
|
||||
}
|
||||
int64 output_bytes = output->second.first;
|
||||
@ -155,14 +163,14 @@ void MemoryTracker::TrackNodeConnection(GraphNode* node, GraphNode* src) {
|
||||
}
|
||||
|
||||
src_dev.tensor_size[tensor_name] = output_bytes;
|
||||
src_dev.earliest_ref[tensor_name] = src->node->all_start_micros();
|
||||
src_dev.earliest_ref[tensor_name] = src->node->all_start_micros(step);
|
||||
|
||||
int64 src_end_micros =
|
||||
src->node->all_start_micros() + src->node->latest_end_rel_micros();
|
||||
int64 src_end_micros = src->node->all_start_micros(step) +
|
||||
src->node->latest_end_rel_micros(step);
|
||||
|
||||
if (src->node->canonical_device() != node->node->canonical_device()) {
|
||||
int64 transfer_micros =
|
||||
(src_end_micros + node->node->all_start_micros()) / 2;
|
||||
(src_end_micros + node->node->all_start_micros(step)) / 2;
|
||||
src_dev.latest_ref[tensor_name] =
|
||||
std::max(src_dev.latest_ref[tensor_name], transfer_micros);
|
||||
|
||||
@ -171,18 +179,19 @@ void MemoryTracker::TrackNodeConnection(GraphNode* node, GraphNode* src) {
|
||||
strings::StrCat(tensor_name, node->node->canonical_device());
|
||||
dest_dev.tensor_size[dest_tensor_name] = output_bytes;
|
||||
dest_dev.earliest_ref[dest_tensor_name] = transfer_micros;
|
||||
dest_dev.latest_ref[dest_tensor_name] = std::max(
|
||||
dest_dev.latest_ref[dest_tensor_name],
|
||||
node->node->all_start_micros() + node->node->latest_end_rel_micros());
|
||||
dest_dev.latest_ref[dest_tensor_name] =
|
||||
std::max(dest_dev.latest_ref[dest_tensor_name],
|
||||
node->node->all_start_micros(step) +
|
||||
node->node->latest_end_rel_micros(step));
|
||||
} else {
|
||||
src_dev.latest_ref[tensor_name] = std::max(
|
||||
src_dev.latest_ref[tensor_name],
|
||||
node->node->all_start_micros() + node->node->latest_end_rel_micros());
|
||||
src_dev.latest_ref[tensor_name] =
|
||||
std::max(src_dev.latest_ref[tensor_name],
|
||||
node->node->all_start_micros(step) +
|
||||
node->node->latest_end_rel_micros(step));
|
||||
}
|
||||
}
|
||||
|
||||
void Timeline::GenerateGraphTimeline(const GraphNode* gnode,
|
||||
const MemoryTracker& memory_tracker) {
|
||||
void Timeline::GenerateGraphTimeline(const GraphNode* gnode) {
|
||||
AddGraphNode(gnode);
|
||||
AllocateLanes();
|
||||
fprintf(stdout, "generating trace file.\n");
|
||||
@ -215,7 +224,7 @@ void Timeline::GenerateGraphTimeline(const GraphNode* gnode,
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& dev : memory_tracker.devices()) {
|
||||
for (const auto& dev : mem_tracker_.devices()) {
|
||||
int64 pid = AllocatePID();
|
||||
chrome_formatter_.EmitPID(GetMemoryLaneName(dev.first), pid);
|
||||
const MemoryTracker::Device& device = dev.second;
|
||||
@ -268,12 +277,12 @@ std::vector<TimeNode*> Timeline::AddGraphNode(const GraphNode* gnode) {
|
||||
std::vector<TimeNode*> inputs = AddGraphNode(schild);
|
||||
shown_cinputs.insert(shown_cinputs.end(), inputs.begin(), inputs.end());
|
||||
}
|
||||
if (!gnode->node->step_stats()) {
|
||||
if (!gnode->node->trackable(step_)) {
|
||||
return shown_cinputs;
|
||||
}
|
||||
|
||||
const TFGraphNode* node = gnode->node;
|
||||
for (const auto& kernel_execs : node->op_execs()) {
|
||||
for (const auto& kernel_execs : node->op_execs(step_)) {
|
||||
const string& device = kernel_execs.first;
|
||||
const std::vector<std::pair<int64, int64>>& execs = kernel_execs.second;
|
||||
|
||||
|
@ -101,9 +101,9 @@ class MemoryTracker {
|
||||
std::map<int64, int64> allocator_stats;
|
||||
};
|
||||
|
||||
void TrackNode(GraphNode* node);
|
||||
void TrackNode(int64 step, GraphNode* node);
|
||||
|
||||
void TrackNodeConnection(GraphNode* node, GraphNode* src);
|
||||
void TrackNodeConnection(int64 step, GraphNode* node, GraphNode* src);
|
||||
|
||||
const std::map<string, Device>& devices() const { return devices_; }
|
||||
|
||||
@ -113,16 +113,25 @@ class MemoryTracker {
|
||||
|
||||
class Timeline {
|
||||
public:
|
||||
Timeline(const string& outfile) : outfile_(outfile) {}
|
||||
Timeline(int64 step, const string& outfile)
|
||||
: step_(step), outfile_(outfile) {}
|
||||
~Timeline() {}
|
||||
|
||||
void GenerateGraphTimeline(const GraphNode* gnode,
|
||||
const MemoryTracker& memory_tracker);
|
||||
int64 step() const { return step_; }
|
||||
void SetStep(int64 step) { step_ = step; }
|
||||
|
||||
void GenerateGraphTimeline(const GraphNode* gnode);
|
||||
|
||||
void GenerateScopeTimeline(const ScopeNode* node);
|
||||
|
||||
void GenerateCodeTimeline(const CodeNode* node);
|
||||
|
||||
void TrackNode(GraphNode* node) { mem_tracker_.TrackNode(step_, node); }
|
||||
|
||||
void TrackNodeConnection(GraphNode* node, GraphNode* src) {
|
||||
mem_tracker_.TrackNodeConnection(step_, node, src);
|
||||
}
|
||||
|
||||
private:
|
||||
void OutputTimeline();
|
||||
|
||||
@ -162,9 +171,11 @@ class Timeline {
|
||||
|
||||
int64 AllocatePID();
|
||||
|
||||
int64 step_;
|
||||
const string outfile_;
|
||||
int64 next_pid_ = 0;
|
||||
int64 allocator_pid_ = -1;
|
||||
MemoryTracker mem_tracker_;
|
||||
ChromeTraceFormatter chrome_formatter_;
|
||||
std::map<string, int64> device_pids_;
|
||||
|
||||
|
@ -60,7 +60,7 @@ class TFProfTimelineTest : public ::testing::Test {
|
||||
// manually check it's correct
|
||||
TEST_F(TFProfTimelineTest, GraphView) {
|
||||
string dump_file = io::JoinPath(testing::TmpDir(), "dump");
|
||||
Options opts(10000, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes
|
||||
Options opts(10000, 0, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes
|
||||
{".*"}, {""}, {".*"}, {""}, false,
|
||||
{"params", "bytes", "micros", "float_ops"}, "timeline",
|
||||
{{"outfile", dump_file}});
|
||||
@ -73,7 +73,7 @@ TEST_F(TFProfTimelineTest, GraphView) {
|
||||
|
||||
TEST_F(TFProfTimelineTest, ScopeView) {
|
||||
string dump_file = io::JoinPath(testing::TmpDir(), "dump");
|
||||
Options opts(5, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes
|
||||
Options opts(5, 0, 0, 0, 0, 0, 0, "name", {".*"}, // accout_type_regexes
|
||||
{".*"}, {""}, {".*"}, {""}, false,
|
||||
{"params", "bytes", "micros", "float_ops"}, "timeline",
|
||||
{{"outfile", dump_file}});
|
||||
|
@ -176,6 +176,12 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
|
||||
}
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[6]) {
|
||||
if (pieces.size() <= i + 1 ||
|
||||
!strings::safe_strto64(pieces[i + 1], &opts->step)) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[7]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
@ -187,42 +193,42 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
|
||||
}
|
||||
opts->order_by = *order_by;
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[7]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[8]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
opts->account_type_regexes = str_util::Split(StripQuote(pieces[i + 1]),
|
||||
',', str_util::SkipEmpty());
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[8]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[9]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
opts->start_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
|
||||
str_util::SkipEmpty());
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[9]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[10]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
opts->trim_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
|
||||
str_util::SkipEmpty());
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[10]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[11]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
opts->show_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
|
||||
str_util::SkipEmpty());
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[11]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[12]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
opts->hide_name_regexes = str_util::Split(StripQuote(pieces[i + 1]), ',',
|
||||
str_util::SkipEmpty());
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[12]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[13]) {
|
||||
if ((pieces.size() > i + 1 && pieces[i + 1].find("-") == 0) ||
|
||||
pieces.size() == i + 1) {
|
||||
opts->account_displayed_op_only = true;
|
||||
@ -232,7 +238,7 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
|
||||
} else {
|
||||
++i;
|
||||
}
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[13]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[14]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
@ -249,7 +255,7 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
|
||||
}
|
||||
opts->select = requested_set;
|
||||
++i;
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[14]) {
|
||||
} else if (pieces[i] == tensorflow::tfprof::kOptions[15]) {
|
||||
if (pieces.size() <= i + 1) {
|
||||
return ReturnError(pieces, i);
|
||||
}
|
||||
@ -291,6 +297,10 @@ void PrintHelp() {
|
||||
"float operations. Only available if an op has "
|
||||
"op.RegisterStatistics() defined and OpLog is "
|
||||
"provided\n\n"
|
||||
" -min_occurrence: Show the op types that are at least used this number "
|
||||
"of times. Only available in op view.\n\n"
|
||||
" -step: Show the stats of a step when multiple steps of "
|
||||
"RunMetadata were added. By default (-1), show the average of all steps."
|
||||
" -order_by: Order the results by [name|depth|bytes|micros|params|"
|
||||
"float_ops]\n\n"
|
||||
" -account_type_regexes: Account and display the ops whose types match "
|
||||
|
@ -75,6 +75,7 @@ int main(int argc, char** argv) {
|
||||
tensorflow::int64 FLAGS_min_params = 0;
|
||||
tensorflow::int64 FLAGS_min_float_ops = 0;
|
||||
tensorflow::int64 FLAGS_min_occurrence = 0;
|
||||
tensorflow::int64 FLAGS_step = -1;
|
||||
tensorflow::string FLAGS_order_by = "name";
|
||||
tensorflow::string FLAGS_account_type_regexes = ".*";
|
||||
tensorflow::string FLAGS_start_name_regexes = ".*";
|
||||
@ -92,7 +93,8 @@ int main(int argc, char** argv) {
|
||||
tensorflow::Flag("graph_path", &FLAGS_graph_path,
|
||||
"GraphDef proto text file name"),
|
||||
tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path,
|
||||
"RunMetadata proto binary file name"),
|
||||
"Comma-separated list of RunMetadata proto binary "
|
||||
"files. Each file is given step number 0,1,2,etc"),
|
||||
tensorflow::Flag("op_log_path", &FLAGS_op_log_path,
|
||||
"tensorflow::tfprof::OpLog proto binary file name"),
|
||||
tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path,
|
||||
@ -104,6 +106,8 @@ int main(int argc, char** argv) {
|
||||
tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
|
||||
tensorflow::Flag("min_occurrence", &FLAGS_min_occurrence,
|
||||
"min occurrence"),
|
||||
tensorflow::Flag("step", &FLAGS_step,
|
||||
"The stats of which step to use. By default average"),
|
||||
tensorflow::Flag("order_by", &FLAGS_order_by, "order by"),
|
||||
tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes,
|
||||
"start name regexes"),
|
||||
@ -181,18 +185,6 @@ int main(int argc, char** argv) {
|
||||
TF_CHECK_OK(tensorflow::tfprof::ReadGraphDef(tensorflow::Env::Default(),
|
||||
FLAGS_graph_path, graph.get()));
|
||||
|
||||
std::unique_ptr<tensorflow::RunMetadata> run_meta(
|
||||
new tensorflow::RunMetadata());
|
||||
if (!FLAGS_run_meta_path.empty()) {
|
||||
s = ReadBinaryProto(tensorflow::Env::Default(), FLAGS_run_meta_path,
|
||||
run_meta.get());
|
||||
if (!s.ok()) {
|
||||
fprintf(stderr, "Failed to read run_meta_path: %s\n",
|
||||
s.ToString().c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::tfprof::OpLog> op_log(
|
||||
new tensorflow::tfprof::OpLog());
|
||||
if (!FLAGS_op_log_path.empty()) {
|
||||
@ -222,12 +214,27 @@ int main(int argc, char** argv) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
tensorflow::tfprof::TFStats tf_stat(std::move(graph), std::move(run_meta),
|
||||
std::move(op_log),
|
||||
std::move(ckpt_reader));
|
||||
tensorflow::tfprof::TFStats tf_stat(
|
||||
std::move(graph), nullptr, std::move(op_log), std::move(ckpt_reader));
|
||||
|
||||
std::vector<string> run_meta_files =
|
||||
Split(FLAGS_run_meta_path, ',', tensorflow::str_util::SkipEmpty());
|
||||
for (int i = 0; i < run_meta_files.size(); ++i) {
|
||||
std::unique_ptr<tensorflow::RunMetadata> run_meta(
|
||||
new tensorflow::RunMetadata());
|
||||
s = ReadBinaryProto(tensorflow::Env::Default(), run_meta_files[i],
|
||||
run_meta.get());
|
||||
if (!s.ok()) {
|
||||
fprintf(stderr, "Failed to read run_meta_path %s. Status: %s\n",
|
||||
run_meta_files[i].c_str(), s.ToString().c_str());
|
||||
return 1;
|
||||
}
|
||||
tf_stat.ParseRunMeta(i, std::move(run_meta));
|
||||
}
|
||||
|
||||
tensorflow::tfprof::Options opts(
|
||||
FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, FLAGS_min_params,
|
||||
FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_order_by,
|
||||
FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by,
|
||||
account_type_regexes, start_name_regexes, trim_name_regexes,
|
||||
show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only,
|
||||
select, output_type, output_options);
|
||||
|
@ -11,6 +11,7 @@ message OptionsProto {
|
||||
optional int64 min_params = 4;
|
||||
optional int64 min_float_ops = 5;
|
||||
optional int64 min_occurrence = 17;
|
||||
optional int64 step = 18 [default = -1];
|
||||
|
||||
optional string order_by = 7;
|
||||
repeated string account_type_regexes = 8;
|
||||
|
Loading…
Reference in New Issue
Block a user