Add a converter from overview_page.proto to GViz DataTable format.
PiperOrigin-RevId: 290288974 Change-Id: I1390c52fddf20f0778ae69d4fb54af92dc947262
This commit is contained in:
parent
bd6b68397f
commit
ee24d4b059
142
tensorflow/python/profiler/overview_page_proto_to_gviz.py
Normal file
142
tensorflow/python/profiler/overview_page_proto_to_gviz.py
Normal file
@ -0,0 +1,142 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""For conversion of TF Overview Page protos to GViz DataTables.
|
||||
|
||||
Usage:
|
||||
gviz_data_table = generate_chart_table(overview_page)
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
import gviz_api
|
||||
|
||||
|
||||
def get_run_environment_table_args(run_environment):
|
||||
"""Creates a gviz DataTable object from a RunEnvironment proto.
|
||||
|
||||
Args:
|
||||
run_environment: An op_stats_pb2.RunEnvironment.
|
||||
|
||||
Returns:
|
||||
Returns a gviz_api.DataTable
|
||||
"""
|
||||
|
||||
table_description = [
|
||||
("host_id", "string", "host_id"),
|
||||
("command_line", "string", "command_line"),
|
||||
("start_time", "string", "start_time"),
|
||||
("bns_address", "string", "bns_address"),
|
||||
]
|
||||
|
||||
data = []
|
||||
for job in run_environment.host_dependent_job_info:
|
||||
row = [
|
||||
str(job.host_id),
|
||||
str(job.command_line),
|
||||
str(datetime.datetime.utcfromtimestamp(job.start_time)),
|
||||
str(job.bns_address),
|
||||
]
|
||||
data.append(row)
|
||||
|
||||
return (table_description, data, [])
|
||||
|
||||
|
||||
def generate_run_environment_table(run_environment):
|
||||
(table_description, data,
|
||||
custom_properties) = get_run_environment_table_args(run_environment)
|
||||
return gviz_api.DataTable(table_description, data, custom_properties)
|
||||
|
||||
|
||||
def get_overview_page_analysis_table_args(overview_page_analysis):
|
||||
"""Creates a gviz DataTable object from an OverviewPageAnalysis proto.
|
||||
|
||||
Args:
|
||||
overview_page_analysis: An overview_page_pb2.OverviewPageAnalysis.
|
||||
|
||||
Returns:
|
||||
Returns a gviz_api.DataTable
|
||||
"""
|
||||
|
||||
table_description = [
|
||||
("selfTimePercent", "number", "Time (%)"),
|
||||
("cumulativeTimePercent", "number", "Cumulative time (%)"),
|
||||
("category", "string", "Category"),
|
||||
("operation", "string", "Operation"),
|
||||
("flopRate", "number", "GFLOPs/Sec"),
|
||||
]
|
||||
|
||||
data = []
|
||||
for op in overview_page_analysis.top_device_ops:
|
||||
row = [
|
||||
op.self_time_fraction,
|
||||
op.cumulative_time_fraction,
|
||||
str(op.category),
|
||||
str(op.name),
|
||||
op.flop_rate,
|
||||
]
|
||||
data.append(row)
|
||||
|
||||
return (table_description, data, [])
|
||||
|
||||
|
||||
def generate_overview_page_analysis_table(overview_page_analysis):
|
||||
(table_description, data, custom_properties) = \
|
||||
get_overview_page_analysis_table_args(overview_page_analysis)
|
||||
return gviz_api.DataTable(table_description, data, custom_properties)
|
||||
|
||||
|
||||
def get_recommendation_table_args(overview_page_recommendation):
|
||||
"""Creates a gviz DataTable object from an OverviewPageRecommendation proto.
|
||||
|
||||
Args:
|
||||
overview_page_recommendation: An
|
||||
overview_page_pb2.OverviewPageRecommendation.
|
||||
|
||||
Returns:
|
||||
Returns a gviz_api.DataTable
|
||||
"""
|
||||
|
||||
table_description = [
|
||||
("tip_type", "string", "tip_type"),
|
||||
("link", "string", "link"),
|
||||
]
|
||||
|
||||
data = []
|
||||
for faq_tip in overview_page_recommendation.faq_tips:
|
||||
data.append(["faq", faq_tip.link])
|
||||
|
||||
for host_tip in overview_page_recommendation.host_tips:
|
||||
data.append(["host", host_tip.link])
|
||||
|
||||
for device_tip in overview_page_recommendation.device_tips:
|
||||
data.append(["device", device_tip.link])
|
||||
|
||||
for doc_tip in overview_page_recommendation.documentation_tips:
|
||||
data.append(["doc", doc_tip.link])
|
||||
|
||||
for inference_tip in overview_page_recommendation.inference_tips:
|
||||
data.append(["inference", inference_tip.link])
|
||||
|
||||
return (table_description, data, [])
|
||||
|
||||
|
||||
def generate_recommendation_table(overview_page_recommendation):
|
||||
(table_description, data, custom_properties) = \
|
||||
get_recommendation_table_args(overview_page_recommendation)
|
||||
return gviz_api.DataTable(table_description, data, custom_properties)
|
285
tensorflow/python/profiler/overview_page_proto_to_gviz_test.py
Normal file
285
tensorflow/python/profiler/overview_page_proto_to_gviz_test.py
Normal file
@ -0,0 +1,285 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# Lint as: python3
|
||||
"""Tests for overview_page_proto_to_gviz."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import csv
|
||||
import io
|
||||
|
||||
import gviz_api
|
||||
|
||||
# pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow.core.profiler.protobuf import op_stats_pb2
|
||||
from tensorflow.core.profiler.protobuf import overview_page_pb2
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.profiler import overview_page_proto_to_gviz
|
||||
# pylint: enable=g-direct-tensorflow-import
|
||||
|
||||
|
||||
class ProtoToGvizTest(test.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ProtoToGvizTest, cls).setUpClass()
|
||||
MockRunEnvironment = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"MockRunEnvironment",
|
||||
["host_id", "command_line", "start_time", "bns_address"])
|
||||
|
||||
ProtoToGvizTest.mock_run_env = MockRunEnvironment(
|
||||
host_id="1",
|
||||
command_line="2",
|
||||
start_time=1582202096,
|
||||
bns_address="4",
|
||||
)
|
||||
|
||||
MockOverviewTfOp = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"MockOverviewTfOp", [
|
||||
"self_time_fraction",
|
||||
"cumulative_time_fraction",
|
||||
"category",
|
||||
"name",
|
||||
"flop_rate",
|
||||
])
|
||||
|
||||
ProtoToGvizTest.mock_tf_op = MockOverviewTfOp(
|
||||
self_time_fraction=3.0,
|
||||
cumulative_time_fraction=4.0,
|
||||
category="2",
|
||||
name="1",
|
||||
flop_rate=5.0,
|
||||
)
|
||||
|
||||
MockTip = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"MockTip", [
|
||||
"tip_type",
|
||||
"link",
|
||||
])
|
||||
|
||||
ProtoToGvizTest.mock_tips = []
|
||||
for tip in ["faq", "host", "device", "doc"]:
|
||||
for idx in range(0, 3):
|
||||
ProtoToGvizTest.mock_tips.append(MockTip(tip, tip + "_link" + str(idx)))
|
||||
|
||||
# Checks that DataTable columns match schema defined in table_description.
|
||||
def check_header_row(self, data, table_description, row_values):
|
||||
for (cc, column_header) in enumerate(row_values):
|
||||
self.assertEqual(table_description[cc][2], column_header)
|
||||
|
||||
# Checks that DataTable row value representation matches number or string.
|
||||
def check_row_types(self, data, table_description, row_values, row_idx):
|
||||
for (cc, cell_str) in enumerate(row_values):
|
||||
raw_value = data[row_idx - 1][cc]
|
||||
value_type = table_description[cc][1]
|
||||
|
||||
# Only number and strings are used in our DataTable schema.
|
||||
self.assertIn(value_type, ["number", "string"])
|
||||
|
||||
# Encode in similar fashion as DataTable.ToCsv().
|
||||
expected_value = gviz_api.DataTable.CoerceValue(raw_value, value_type)
|
||||
self.assertNotIsInstance(expected_value, tuple)
|
||||
self.assertEqual(expected_value, raw_value)
|
||||
self.assertEqual(str(expected_value), cell_str)
|
||||
|
||||
def create_empty_run_environment(self):
|
||||
return op_stats_pb2.RunEnvironment()
|
||||
|
||||
def create_empty_overview_page_analysis(self):
|
||||
return overview_page_pb2.OverviewPageAnalysis()
|
||||
|
||||
def create_empty_recommendation(self):
|
||||
return overview_page_pb2.OverviewPageRecommendation()
|
||||
|
||||
def create_mock_run_environment(self):
|
||||
run_env = op_stats_pb2.RunEnvironment()
|
||||
|
||||
# Add 3 rows
|
||||
for _ in range(0, 3):
|
||||
job = op_stats_pb2.HostDependentJobInfoResult()
|
||||
job.host_id = self.mock_run_env.host_id
|
||||
job.command_line = self.mock_run_env.command_line
|
||||
job.start_time = self.mock_run_env.start_time
|
||||
job.bns_address = self.mock_run_env.bns_address
|
||||
run_env.host_dependent_job_info.append(job)
|
||||
|
||||
return run_env
|
||||
|
||||
def test_run_environment_empty(self):
|
||||
run_env = self.create_empty_run_environment()
|
||||
data_table = overview_page_proto_to_gviz.generate_run_environment_table(
|
||||
run_env)
|
||||
|
||||
self.assertEqual(0, data_table.NumberOfRows(),
|
||||
"Empty table should have 0 rows.")
|
||||
# Check the number of columns in Run environment data table.
|
||||
self.assertLen(data_table.columns, len(list(self.mock_run_env)))
|
||||
|
||||
def test_run_environment_simple(self):
|
||||
run_env = self.create_mock_run_environment()
|
||||
(table_description, data, custom_properties) = \
|
||||
overview_page_proto_to_gviz.get_run_environment_table_args(run_env)
|
||||
data_table = gviz_api.DataTable(table_description, data, custom_properties)
|
||||
|
||||
# Data is a list of 3 rows.
|
||||
self.assertLen(data, 3)
|
||||
self.assertEqual(3, data_table.NumberOfRows(), "Simple table has 3 rows.")
|
||||
# Check the number of columns in table descriptor and data table.
|
||||
self.assertLen(table_description, len(list(self.mock_run_env)))
|
||||
self.assertLen(data_table.columns, len(list(self.mock_run_env)))
|
||||
|
||||
# Prepare expectation to check against.
|
||||
# get_run_environment_table_args() formats ns to RFC3339_full format.
|
||||
mock_data_run_env = self.mock_run_env._replace(
|
||||
start_time="2020-02-20 12:34:56")
|
||||
# Check data against mock values.
|
||||
for row in data:
|
||||
self.assertEqual(list(mock_data_run_env), row)
|
||||
|
||||
# Check DataTable against mock values.
|
||||
# Only way to access DataTable contents is by CSV
|
||||
csv_file = io.StringIO(data_table.ToCsv())
|
||||
reader = csv.reader(csv_file)
|
||||
|
||||
for (rr, row_values) in enumerate(reader):
|
||||
if rr == 0:
|
||||
self.check_header_row(data, table_description, row_values)
|
||||
else:
|
||||
self.check_row_types(data, table_description, row_values, rr)
|
||||
|
||||
self.assertEqual(list(mock_data_run_env), row_values)
|
||||
|
||||
def create_mock_overview_page_analysis(self):
|
||||
analysis = overview_page_pb2.OverviewPageAnalysis()
|
||||
|
||||
# Add 3 rows
|
||||
for _ in range(0, 3):
|
||||
op = overview_page_pb2.OverviewTfOp()
|
||||
op.self_time_fraction = self.mock_tf_op.self_time_fraction
|
||||
op.cumulative_time_fraction = self.mock_tf_op.cumulative_time_fraction
|
||||
op.category = self.mock_tf_op.category
|
||||
op.name = self.mock_tf_op.name
|
||||
op.flop_rate = self.mock_tf_op.flop_rate
|
||||
analysis.top_device_ops.append(op)
|
||||
|
||||
return analysis
|
||||
|
||||
def test_overview_page_analysis_empty(self):
|
||||
analysis = self.create_empty_overview_page_analysis()
|
||||
data_table = \
|
||||
overview_page_proto_to_gviz.generate_overview_page_analysis_table(
|
||||
analysis)
|
||||
|
||||
self.assertEqual(0, data_table.NumberOfRows(),
|
||||
"Empty table should have 0 rows.")
|
||||
# Check the number of Overview Page Analysis data table columns.
|
||||
self.assertLen(data_table.columns, len(list(self.mock_tf_op)))
|
||||
|
||||
def test_overview_page_analysis_simple(self):
|
||||
analysis = self.create_mock_overview_page_analysis()
|
||||
(table_description, data, custom_properties) = \
|
||||
overview_page_proto_to_gviz.get_overview_page_analysis_table_args(
|
||||
analysis)
|
||||
data_table = gviz_api.DataTable(table_description, data, custom_properties)
|
||||
|
||||
# Data is a list of 3 rows.
|
||||
self.assertLen(data, 3)
|
||||
self.assertEqual(3, data_table.NumberOfRows(), "Simple table has 3 rows.")
|
||||
# Check the number of columns in table descriptor and data table.
|
||||
self.assertLen(table_description, len(list(self.mock_tf_op)))
|
||||
self.assertLen(data_table.columns, len(list(self.mock_tf_op)))
|
||||
|
||||
# Prepare expectation to check against.
|
||||
mock_csv_tf_op = [str(x) for x in list(self.mock_tf_op)]
|
||||
|
||||
# Check data against mock values.
|
||||
for row in data:
|
||||
self.assertEqual(list(self.mock_tf_op), row)
|
||||
|
||||
# Check DataTable against mock values.
|
||||
# Only way to access DataTable contents is by CSV
|
||||
csv_file = io.StringIO(data_table.ToCsv())
|
||||
reader = csv.reader(csv_file)
|
||||
|
||||
for (rr, row_values) in enumerate(reader):
|
||||
if rr == 0:
|
||||
self.check_header_row(data, table_description, row_values)
|
||||
else:
|
||||
self.check_row_types(data, table_description, row_values, rr)
|
||||
|
||||
self.assertEqual(mock_csv_tf_op, row_values)
|
||||
|
||||
def create_mock_recommendation(self):
|
||||
recommendation = overview_page_pb2.OverviewPageRecommendation()
|
||||
|
||||
for idx in range(0, 3):
|
||||
recommendation.faq_tips.add().link = "faq_link" + str(idx)
|
||||
recommendation.host_tips.add().link = "host_link" + str(idx)
|
||||
recommendation.device_tips.add().link = "device_link" + str(idx)
|
||||
recommendation.documentation_tips.add().link = "doc_link" + str(idx)
|
||||
|
||||
return recommendation
|
||||
|
||||
def test_recommendation_empty(self):
|
||||
recommendation = self.create_empty_recommendation()
|
||||
data_table = overview_page_proto_to_gviz.generate_recommendation_table(
|
||||
recommendation)
|
||||
|
||||
self.assertEqual(0, data_table.NumberOfRows(),
|
||||
"Empty table should have 0 rows.")
|
||||
# Check the number of Overview Page Recommendation data table columns.
|
||||
# One for tip_type, and one for link
|
||||
self.assertLen(data_table.columns, 2)
|
||||
|
||||
def test_recommendation_simple(self):
|
||||
recommendation = self.create_mock_recommendation()
|
||||
(table_description, data, custom_properties) = \
|
||||
overview_page_proto_to_gviz.get_recommendation_table_args(
|
||||
recommendation)
|
||||
data_table = gviz_api.DataTable(table_description, data, custom_properties)
|
||||
|
||||
# Data is a list of 12 rows: 3 rows for each tip type.
|
||||
self.assertLen(data, len(list(self.mock_tips)))
|
||||
self.assertLen(
|
||||
list(self.mock_tips), data_table.NumberOfRows(),
|
||||
"Simple table has 12 rows.")
|
||||
# Check the number of columns in table descriptor and data table.
|
||||
self.assertLen(table_description, 2)
|
||||
self.assertLen(data_table.columns, 2)
|
||||
|
||||
# Check data against mock values.
|
||||
for idx, row in enumerate(data):
|
||||
self.assertEqual(list(self.mock_tips[idx]), row)
|
||||
|
||||
# Check DataTable against mock values.
|
||||
# Only way to access DataTable contents is by CSV
|
||||
csv_file = io.StringIO(data_table.ToCsv())
|
||||
reader = csv.reader(csv_file)
|
||||
|
||||
for (rr, row_values) in enumerate(reader):
|
||||
if rr == 0:
|
||||
self.check_header_row(data, table_description, row_values)
|
||||
else:
|
||||
self.check_row_types(data, table_description, row_values, rr)
|
||||
|
||||
self.assertEqual(list(self.mock_tips[rr - 1]), row_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
x
Reference in New Issue
Block a user