Add a converter from overview_page.proto to GViz DataTable format.

PiperOrigin-RevId: 290288974
Change-Id: I1390c52fddf20f0778ae69d4fb54af92dc947262
This commit is contained in:
A. Unique TensorFlower 2020-01-17 09:53:13 -08:00 committed by TensorFlower Gardener
parent bd6b68397f
commit ee24d4b059
2 changed files with 427 additions and 0 deletions

View File

@ -0,0 +1,142 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""For conversion of TF Overview Page protos to GViz DataTables.

Usage:
  gviz_data_table = generate_run_environment_table(run_environment)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import datetime
import gviz_api
def get_run_environment_table_args(run_environment):
  """Builds gviz DataTable constructor arguments from a RunEnvironment proto.

  Args:
    run_environment: An op_stats_pb2.RunEnvironment.

  Returns:
    A (table_description, data, custom_properties) tuple. table_description
    lists one (id, type, label) triple per column, data holds one list of
    strings per entry of run_environment.host_dependent_job_info, and
    custom_properties is an empty list.
  """
  # Every column is a string column whose label equals its id.
  columns = [
      (column_id, "string", column_id)
      for column_id in ("host_id", "command_line", "start_time", "bns_address")
  ]
  rows = [
      [
          str(info.host_id),
          str(info.command_line),
          # start_time is passed to utcfromtimestamp unchanged, so it is
          # interpreted as seconds since the Unix epoch.
          # NOTE(review): confirm the proto field really carries seconds.
          str(datetime.datetime.utcfromtimestamp(info.start_time)),
          str(info.bns_address),
      ]
      for info in run_environment.host_dependent_job_info
  ]
  return (columns, rows, [])
def generate_run_environment_table(run_environment):
  """Creates a gviz DataTable from a RunEnvironment proto.

  Args:
    run_environment: An op_stats_pb2.RunEnvironment.

  Returns:
    A gviz_api.DataTable built from the arguments produced by
    get_run_environment_table_args().
  """
  table_args = get_run_environment_table_args(run_environment)
  # table_args is (table_description, data, custom_properties), in the order
  # the DataTable constructor takes them.
  return gviz_api.DataTable(*table_args)
def get_overview_page_analysis_table_args(overview_page_analysis):
  """Builds gviz DataTable constructor arguments from an OverviewPageAnalysis.

  Args:
    overview_page_analysis: An overview_page_pb2.OverviewPageAnalysis.

  Returns:
    A (table_description, data, custom_properties) tuple. table_description
    lists one (id, type, label) triple per column, data holds one row per
    entry of overview_page_analysis.top_device_ops, and custom_properties is
    an empty list.
  """
  columns = [
      ("selfTimePercent", "number", "Time (%)"),
      ("cumulativeTimePercent", "number", "Cumulative time (%)"),
      ("category", "string", "Category"),
      ("operation", "string", "Operation"),
      ("flopRate", "number", "GFLOPs/Sec"),
  ]
  # Numeric fields are forwarded untouched; only the two string columns are
  # passed through str().
  rows = [
      [
          device_op.self_time_fraction,
          device_op.cumulative_time_fraction,
          str(device_op.category),
          str(device_op.name),
          device_op.flop_rate,
      ]
      for device_op in overview_page_analysis.top_device_ops
  ]
  return (columns, rows, [])
def generate_overview_page_analysis_table(overview_page_analysis):
  """Creates a gviz DataTable from an OverviewPageAnalysis proto.

  Args:
    overview_page_analysis: An overview_page_pb2.OverviewPageAnalysis.

  Returns:
    A gviz_api.DataTable built from the arguments produced by
    get_overview_page_analysis_table_args().
  """
  table_args = get_overview_page_analysis_table_args(overview_page_analysis)
  # table_args is (table_description, data, custom_properties), in the order
  # the DataTable constructor takes them.
  return gviz_api.DataTable(*table_args)
def get_recommendation_table_args(overview_page_recommendation):
  """Builds gviz DataTable constructor arguments from a recommendation proto.

  Args:
    overview_page_recommendation: An
      overview_page_pb2.OverviewPageRecommendation.

  Returns:
    A (table_description, data, custom_properties) tuple. data holds one
    [tip_type, link] row per tip, grouped by tip type in the order faq, host,
    device, doc, inference; custom_properties is an empty list.
  """
  columns = [
      ("tip_type", "string", "tip_type"),
      ("link", "string", "link"),
  ]
  # Pairs each tip_type label with the repeated field that supplies its tips;
  # the order here is the row order of the resulting table.
  tips_by_type = (
      ("faq", overview_page_recommendation.faq_tips),
      ("host", overview_page_recommendation.host_tips),
      ("device", overview_page_recommendation.device_tips),
      ("doc", overview_page_recommendation.documentation_tips),
      ("inference", overview_page_recommendation.inference_tips),
  )
  rows = []
  for tip_type, tips in tips_by_type:
    rows.extend([tip_type, tip.link] for tip in tips)
  return (columns, rows, [])
def generate_recommendation_table(overview_page_recommendation):
  """Creates a gviz DataTable from an OverviewPageRecommendation proto.

  Args:
    overview_page_recommendation: An
      overview_page_pb2.OverviewPageRecommendation.

  Returns:
    A gviz_api.DataTable built from the arguments produced by
    get_recommendation_table_args().
  """
  table_args = get_recommendation_table_args(overview_page_recommendation)
  # table_args is (table_description, data, custom_properties), in the order
  # the DataTable constructor takes them.
  return gviz_api.DataTable(*table_args)

View File

@ -0,0 +1,285 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for overview_page_proto_to_gviz."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import io
import gviz_api
# pylint: disable=g-direct-tensorflow-import
from tensorflow.core.profiler.protobuf import op_stats_pb2
from tensorflow.core.profiler.protobuf import overview_page_pb2
from tensorflow.python.platform import test
from tensorflow.python.profiler import overview_page_proto_to_gviz
# pylint: enable=g-direct-tensorflow-import
class ProtoToGvizTest(test.TestCase):
  """Tests for the overview page proto to gviz DataTable converters."""

  @classmethod
  def setUpClass(cls):
    """Creates the shared mock values that the test protos are filled from."""
    super(ProtoToGvizTest, cls).setUpClass()
    MockRunEnvironment = collections.namedtuple(  # pylint: disable=invalid-name
        "MockRunEnvironment",
        ["host_id", "command_line", "start_time", "bns_address"])
    # Fixtures are stored through `cls` rather than the hard-coded class name.
    cls.mock_run_env = MockRunEnvironment(
        host_id="1",
        command_line="2",
        start_time=1582202096,
        bns_address="4",
    )
    MockOverviewTfOp = collections.namedtuple(  # pylint: disable=invalid-name
        "MockOverviewTfOp", [
            "self_time_fraction",
            "cumulative_time_fraction",
            "category",
            "name",
            "flop_rate",
        ])
    cls.mock_tf_op = MockOverviewTfOp(
        self_time_fraction=3.0,
        cumulative_time_fraction=4.0,
        category="2",
        name="1",
        flop_rate=5.0,
    )
    MockTip = collections.namedtuple(  # pylint: disable=invalid-name
        "MockTip", [
            "tip_type",
            "link",
        ])
    # Three tips per type, grouped by type, which is the row order the
    # recommendation converter emits.
    cls.mock_tips = []
    for tip in ["faq", "host", "device", "doc"]:
      for idx in range(3):
        cls.mock_tips.append(MockTip(tip, tip + "_link" + str(idx)))

  def check_header_row(self, data, table_description, row_values):
    """Checks that the CSV header row matches the column labels.

    Args:
      data: Unused; kept so the signature mirrors check_row_types().
      table_description: List of (id, type, label) column triples.
      row_values: The header cells read back from the CSV output.
    """
    del data  # Unused.
    # Without this, a header with too few columns would pass silently because
    # the loop below only visits the cells that are present.
    self.assertLen(row_values, len(table_description))
    for (cc, column_header) in enumerate(row_values):
      self.assertEqual(table_description[cc][2], column_header)

  def check_row_types(self, data, table_description, row_values, row_idx):
    """Checks that a CSV data row matches the raw data and its column types.

    Args:
      data: The raw rows handed to the DataTable.
      table_description: List of (id, type, label) column triples.
      row_values: The cells of one CSV data row.
      row_idx: Index of the row in the CSV output; row 0 is the header, so the
        matching raw row is data[row_idx - 1].
    """
    for (cc, cell_str) in enumerate(row_values):
      raw_value = data[row_idx - 1][cc]
      value_type = table_description[cc][1]
      # Only number and strings are used in our DataTable schema.
      self.assertIn(value_type, ["number", "string"])
      # Encode in similar fashion as DataTable.ToCsv().
      expected_value = gviz_api.DataTable.CoerceValue(raw_value, value_type)
      self.assertNotIsInstance(expected_value, tuple)
      self.assertEqual(expected_value, raw_value)
      self.assertEqual(str(expected_value), cell_str)

  def _check_csv_rows(self, data_table, data, table_description,
                      expected_rows):
    """Checks the CSV rendering of a DataTable against expected cell strings.

    Args:
      data_table: The gviz_api.DataTable under test.
      data: The raw rows the DataTable was built from.
      table_description: List of (id, type, label) column triples.
      expected_rows: One list of expected cell strings per data row.
    """
    # Only way to access DataTable contents is by CSV
    csv_file = io.StringIO(data_table.ToCsv())
    reader = csv.reader(csv_file)
    for (rr, row_values) in enumerate(reader):
      if rr == 0:
        self.check_header_row(data, table_description, row_values)
      else:
        self.check_row_types(data, table_description, row_values, rr)
        self.assertEqual(expected_rows[rr - 1], row_values)

  def create_empty_run_environment(self):
    """Returns a RunEnvironment proto with no job info entries."""
    return op_stats_pb2.RunEnvironment()

  def create_empty_overview_page_analysis(self):
    """Returns an OverviewPageAnalysis proto with no ops."""
    return overview_page_pb2.OverviewPageAnalysis()

  def create_empty_recommendation(self):
    """Returns an OverviewPageRecommendation proto with no tips."""
    return overview_page_pb2.OverviewPageRecommendation()

  def create_mock_run_environment(self):
    """Returns a RunEnvironment with 3 identical job info entries."""
    run_env = op_stats_pb2.RunEnvironment()
    # Add 3 rows
    for _ in range(3):
      job = op_stats_pb2.HostDependentJobInfoResult()
      job.host_id = self.mock_run_env.host_id
      job.command_line = self.mock_run_env.command_line
      job.start_time = self.mock_run_env.start_time
      job.bns_address = self.mock_run_env.bns_address
      run_env.host_dependent_job_info.append(job)
    return run_env

  def test_run_environment_empty(self):
    run_env = self.create_empty_run_environment()
    data_table = overview_page_proto_to_gviz.generate_run_environment_table(
        run_env)

    self.assertEqual(0, data_table.NumberOfRows(),
                     "Empty table should have 0 rows.")
    # Check the number of columns in Run environment data table.
    self.assertLen(data_table.columns, len(list(self.mock_run_env)))

  def test_run_environment_simple(self):
    run_env = self.create_mock_run_environment()
    (table_description, data, custom_properties) = (
        overview_page_proto_to_gviz.get_run_environment_table_args(run_env))
    data_table = gviz_api.DataTable(table_description, data, custom_properties)

    # Data is a list of 3 rows.
    self.assertLen(data, 3)
    self.assertEqual(3, data_table.NumberOfRows(), "Simple table has 3 rows.")
    # Check the number of columns in table descriptor and data table.
    self.assertLen(table_description, len(list(self.mock_run_env)))
    self.assertLen(data_table.columns, len(list(self.mock_run_env)))

    # Prepare expectation to check against.
    # get_run_environment_table_args() renders the mock start_time (an epoch
    # timestamp in seconds) as a "YYYY-MM-DD HH:MM:SS" UTC string.
    mock_data_run_env = self.mock_run_env._replace(
        start_time="2020-02-20 12:34:56")
    expected_row = list(mock_data_run_env)

    # Check data against mock values.
    for row in data:
      self.assertEqual(expected_row, row)

    # Check DataTable against mock values.
    self._check_csv_rows(data_table, data, table_description,
                         [expected_row] * len(data))

  def create_mock_overview_page_analysis(self):
    """Returns an OverviewPageAnalysis with 3 identical top device ops."""
    analysis = overview_page_pb2.OverviewPageAnalysis()
    # Add 3 rows
    for _ in range(3):
      op = overview_page_pb2.OverviewTfOp()
      op.self_time_fraction = self.mock_tf_op.self_time_fraction
      op.cumulative_time_fraction = self.mock_tf_op.cumulative_time_fraction
      op.category = self.mock_tf_op.category
      op.name = self.mock_tf_op.name
      op.flop_rate = self.mock_tf_op.flop_rate
      analysis.top_device_ops.append(op)
    return analysis

  def test_overview_page_analysis_empty(self):
    analysis = self.create_empty_overview_page_analysis()
    data_table = (
        overview_page_proto_to_gviz.generate_overview_page_analysis_table(
            analysis))

    self.assertEqual(0, data_table.NumberOfRows(),
                     "Empty table should have 0 rows.")
    # Check the number of Overview Page Analysis data table columns.
    self.assertLen(data_table.columns, len(list(self.mock_tf_op)))

  def test_overview_page_analysis_simple(self):
    analysis = self.create_mock_overview_page_analysis()
    (table_description, data, custom_properties) = (
        overview_page_proto_to_gviz.get_overview_page_analysis_table_args(
            analysis))
    data_table = gviz_api.DataTable(table_description, data, custom_properties)

    # Data is a list of 3 rows.
    self.assertLen(data, 3)
    self.assertEqual(3, data_table.NumberOfRows(), "Simple table has 3 rows.")
    # Check the number of columns in table descriptor and data table.
    self.assertLen(table_description, len(list(self.mock_tf_op)))
    self.assertLen(data_table.columns, len(list(self.mock_tf_op)))

    # Prepare expectation to check against.
    mock_csv_tf_op = [str(x) for x in list(self.mock_tf_op)]

    # Check data against mock values.
    for row in data:
      self.assertEqual(list(self.mock_tf_op), row)

    # Check DataTable against mock values.
    self._check_csv_rows(data_table, data, table_description,
                         [mock_csv_tf_op] * len(data))

  def create_mock_recommendation(self):
    """Returns a recommendation with 3 faq, host, device and doc tips each."""
    recommendation = overview_page_pb2.OverviewPageRecommendation()
    for idx in range(3):
      recommendation.faq_tips.add().link = "faq_link" + str(idx)
      recommendation.host_tips.add().link = "host_link" + str(idx)
      recommendation.device_tips.add().link = "device_link" + str(idx)
      recommendation.documentation_tips.add().link = "doc_link" + str(idx)
    return recommendation

  def test_recommendation_empty(self):
    recommendation = self.create_empty_recommendation()
    data_table = overview_page_proto_to_gviz.generate_recommendation_table(
        recommendation)

    self.assertEqual(0, data_table.NumberOfRows(),
                     "Empty table should have 0 rows.")
    # Check the number of Overview Page Recommendation data table columns.
    # One for tip_type, and one for link
    self.assertLen(data_table.columns, 2)

  def test_recommendation_simple(self):
    recommendation = self.create_mock_recommendation()
    (table_description, data, custom_properties) = (
        overview_page_proto_to_gviz.get_recommendation_table_args(
            recommendation))
    data_table = gviz_api.DataTable(table_description, data, custom_properties)

    # Data is a list of 12 rows: 3 rows for each tip type.
    self.assertLen(data, len(self.mock_tips))
    # The row count is the value under test, so compare it for equality
    # instead of passing it as the expected length of the mock list.
    self.assertEqual(
        len(self.mock_tips), data_table.NumberOfRows(),
        "Simple table has 12 rows.")
    # Check the number of columns in table descriptor and data table.
    self.assertLen(table_description, 2)
    self.assertLen(data_table.columns, 2)

    # Check data against mock values.
    for idx, row in enumerate(data):
      self.assertEqual(list(self.mock_tips[idx]), row)

    # Check DataTable against mock values.
    self._check_csv_rows(data_table, data, table_description,
                         [list(tip) for tip in self.mock_tips])
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()