STT-tensorflow/tensorflow/python/platform/benchmark_test.py
Gunhan Gulsoy 1dfd8dc2c4 Avoid using hardcoded path in benchmark_test
/tmp does not exist on Windows.

PiperOrigin-RevId: 295988695
Change-Id: Ie45e6311df617462f4ba10354fabbeaf2eb05127
2020-02-19 09:42:08 -08:00

83 lines
2.8 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for the tf.test.benchmark."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from google.protobuf import json_format
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
class BenchmarkTest(test.TestCase, benchmark.TensorFlowBenchmark):
  """Checks the proto file produced by `report_benchmark`."""

  def testReportBenchmark(self):
    """Reports one benchmark and compares the written proto to expectations.

    `TEST_REPORT_FILE_PREFIX` is pointed at a per-test temporary directory
    (rather than a hardcoded path such as /tmp, which does not exist on
    Windows). After calling `report_benchmark`, the test expects a serialized
    `BenchmarkEntries` proto at `<prefix>BenchmarkTest.testReportBenchmark`,
    parses it, and compares its JSON-dict form with the expected entry.
    """
    output_dir = self.get_temp_dir() + os.path.sep

    # Remember the caller's value so the prefix is restored (or removed) once
    # this test finishes; otherwise it would leak into later tests running in
    # the same process and redirect their benchmark reports here.
    previous_prefix = os.environ.get('TEST_REPORT_FILE_PREFIX')

    def _restore_prefix():
      if previous_prefix is None:
        os.environ.pop('TEST_REPORT_FILE_PREFIX', None)
      else:
        os.environ['TEST_REPORT_FILE_PREFIX'] = previous_prefix

    self.addCleanup(_restore_prefix)
    os.environ['TEST_REPORT_FILE_PREFIX'] = output_dir

    proto_file_path = os.path.join(output_dir,
                                   'BenchmarkTest.testReportBenchmark')
    # Start from a clean slate in case a previous run left the file behind.
    if os.path.exists(proto_file_path):
      os.remove(proto_file_path)

    self.report_benchmark(
        iters=2000,
        wall_time=1000,
        name='testReportBenchmark',
        metrics=[{'name': 'metric_name_1', 'value': 0, 'min_value': 1},
                 {'name': 'metric_name_2', 'value': 90, 'min_value': 0,
                  'max_value': 95}])

    with open(proto_file_path, 'rb') as f:
      benchmark_entries = test_log_pb2.BenchmarkEntries()
      benchmark_entries.ParseFromString(f.read())

    # Exactly one report was made, so exactly one entry must have been written.
    self.assertEqual(1, len(benchmark_entries.entry))
    actual_result = json_format.MessageToDict(
        benchmark_entries, preserving_proto_field_name=True,
        including_default_value_fields=True)['entry'][0]
    # Remove the file only after it has been closed (required on Windows).
    os.remove(proto_file_path)

    expected_result = {
        'name': 'BenchmarkTest.testReportBenchmark',
        # google.protobuf.json_format.MessageToDict() will convert
        # int64 field to string.
        'iters': '2000',
        'wall_time': 1000,
        'cpu_time': 0,
        'throughput': 0,
        'extras': {},
        'metrics': [
            {
                'name': 'metric_name_1',
                'value': 0,
                'min_value': 1
            },
            {
                'name': 'metric_name_2',
                'value': 90,
                'min_value': 0,
                'max_value': 95
            }
        ]
    }

    self.assertEqual(2000, benchmark_entries.entry[0].iters)
    self.assertDictEqual(expected_result, actual_result)
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  test.main()