Add passed and description fields to BenchmarkEntry in test_log.proto
PiperOrigin-RevId: 242054418
commit 82735fd7a6
parent 0c4d4090c9
@@ -2,6 +2,7 @@
 syntax = "proto3";
 
 import "google/protobuf/any.proto";
+import "google/protobuf/wrappers.proto";
 
 option cc_enable_arenas = true;
 option java_outer_classname = "TestLogProtos";
@@ -17,6 +18,20 @@ message EntryValue {
   }
 };
 
+message MetricEntry {
+  // Metric name
+  string name = 1;
+
+  // Metric value
+  double value = 2;
+
+  // The minimum acceptable value for the metric if specified
+  google.protobuf.DoubleValue min_value = 3;
+
+  // The maximum acceptable value for the metric if specified
+  google.protobuf.DoubleValue max_value = 4;
+}
+
 // Each unit test or benchmark in a test or benchmark run provides
 // some set of information. Here we provide some reasonable keys
 // one would expect to see, with optional key/value pairs for things
@@ -43,6 +58,10 @@ message BenchmarkEntry {
 
   // Generic map from result key to value.
   map<string, EntryValue> extras = 6;
+
+  // Metric name, value and expected range. This can include accuracy metrics
+  // typically used to determine whether the accuracy test has passed
+  repeated MetricEntry metrics = 7;
 };
 
 message BenchmarkEntries {
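Note: the two proto hunks above add a MetricEntry message and a repeated metrics field (tag 7) on BenchmarkEntry. A minimal sketch of populating these fields through the generated Python bindings is shown below; the benchmark and metric names are invented for illustration, and the import path is the one used by the new test later in this commit.

# Illustrative sketch, not part of the commit; names below are made up.
from tensorflow.core.util import test_log_pb2

entry = test_log_pb2.BenchmarkEntry(name="MyBenchmark.accuracy")
metric = entry.metrics.add()  # repeated MetricEntry metrics = 7;
metric.name = "top_1_accuracy"
metric.value = 0.92
# min_value and max_value are google.protobuf.DoubleValue wrappers, so the
# numeric payload is assigned to their .value field.
metric.min_value.value = 0.90
metric.max_value.value = 1.0
print(entry)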
@@ -1367,6 +1367,25 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "benchmark_test",
+    size = "small",
+    srcs = [
+        "platform/benchmark.py",
+        "platform/benchmark_test.py",
+    ],
+    additional_deps = [
+        ":client_testlib",
+        ":platform",
+    ],
+    main = "platform/benchmark_test.py",
+    tags = [
+        "manual",
+        "no_pip",
+        "notap",
+    ],
+)
+
 tf_py_test(
     name = "proto_test",
     size = "small",
@@ -53,7 +53,7 @@ OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"
 
 def _global_report_benchmark(
     name, iters=None, cpu_time=None, wall_time=None,
-    throughput=None, extras=None):
+    throughput=None, extras=None, metrics=None):
   """Method for recording a benchmark directly.
 
   Args:
@@ -63,20 +63,22 @@ def _global_report_benchmark(
     wall_time: (optional) Total wall time in seconds
     throughput: (optional) Throughput (in MB/s)
     extras: (optional) Dict mapping string keys to additional benchmark info.
+    metrics: (optional) A list of dict representing metrics generated by the
+      benchmark. Each dict should contain keys 'name' and 'value'. A dict
+      can optionally contain keys 'min_value' and 'max_value'.
 
   Raises:
     TypeError: if extras is not a dict.
     IOError: if the benchmark output file already exists.
   """
-  if extras is not None:
-    if not isinstance(extras, dict):
-      raise TypeError("extras must be a dict")
 
   logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
-               "throughput: %g %s", name, iters if iters is not None else -1,
-               wall_time if wall_time is not None else -1, cpu_time if
-               cpu_time is not None else -1, throughput if
-               throughput is not None else -1, str(extras) if extras else "")
+               "throughput: %g, extras: %s, metrics: %s", name,
+               iters if iters is not None else -1,
+               wall_time if wall_time is not None else -1,
+               cpu_time if cpu_time is not None else -1,
+               throughput if throughput is not None else -1,
+               str(extras) if extras else "None",
+               str(metrics) if metrics else "None")
 
   entries = test_log_pb2.BenchmarkEntries()
   entry = entries.entry.add()
@@ -90,11 +92,29 @@ def _global_report_benchmark(
   if throughput is not None:
     entry.throughput = throughput
   if extras is not None:
+    if not isinstance(extras, dict):
+      raise TypeError("extras must be a dict")
     for (k, v) in extras.items():
       if isinstance(v, numbers.Number):
         entry.extras[k].double_value = v
       else:
         entry.extras[k].string_value = str(v)
+  if metrics is not None:
+    if not isinstance(metrics, list):
+      raise TypeError("metrics must be a list")
+    for metric in metrics:
+      if "name" not in metric:
+        raise TypeError("metric must have a 'name' field")
+      if "value" not in metric:
+        raise TypeError("metric must have a 'value' field")
+
+      metric_entry = entry.metrics.add()
+      metric_entry.name = metric["name"]
+      metric_entry.value = metric["value"]
+      if "min_value" in metric:
+        metric_entry.min_value.value = metric["min_value"]
+      if "max_value" in metric:
+        metric_entry.max_value.value = metric["max_value"]
 
   test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
   if test_env is None:
@@ -169,23 +189,29 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
       wall_time=None,
       throughput=None,
       extras=None,
-      name=None):
+      name=None,
+      metrics=None):
     """Report a benchmark.
 
     Args:
      iters: (optional) How many iterations were run
-      cpu_time: (optional) median or mean cpu time in seconds.
-      wall_time: (optional) median or mean wall time in seconds.
+      cpu_time: (optional) Median or mean cpu time in seconds.
+      wall_time: (optional) Median or mean wall time in seconds.
       throughput: (optional) Throughput (in MB/s)
       extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
       name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
+      metrics: (optional) A list of dict, where each dict has the keys below
+        name (required), string, metric name
+        value (required), double, metric value
+        min_value (optional), double, minimum acceptable metric value
+        max_value (optional), double, maximum acceptable metric value
     """
     name = self._get_name(overwrite_name=name)
     _global_report_benchmark(
         name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
-        throughput=throughput, extras=extras)
+        throughput=throughput, extras=extras, metrics=metrics)
 
 
 @tf_export("test.benchmark_config")
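Note: the report_benchmark hunk above only threads the new metrics argument through to _global_report_benchmark. A hedged usage sketch is shown below; the subclass, method, and metric names are invented for illustration and are not part of this commit.

# Illustrative sketch only; the class, method, and metric names are made up.
from tensorflow.python.platform import benchmark


class MyAccuracyBenchmark(benchmark.TensorFlowBenchmark):

  def benchmark_top_1(self):
    # ... run the workload and compute its accuracy here ...
    self.report_benchmark(
        iters=100,
        wall_time=12.5,
        metrics=[{"name": "top_1_accuracy", "value": 0.97,
                  "min_value": 0.95, "max_value": 1.0}])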
tensorflow/python/platform/benchmark_test.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for the tf.test.benchmark."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from google.protobuf import json_format
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import test
+
+
+class BenchmarkTest(test.TestCase, benchmark.TensorFlowBenchmark):
+
+  def testReportBenchmark(self):
+    output_dir = '/tmp/'
+    os.environ['TEST_REPORT_FILE_PREFIX'] = output_dir
+    proto_file_path = os.path.join(output_dir,
+                                   'BenchmarkTest.testReportBenchmark')
+    if os.path.exists(proto_file_path):
+      os.remove(proto_file_path)
+
+    self.report_benchmark(
+        iters=2000,
+        wall_time=1000,
+        name='testReportBenchmark',
+        metrics=[{'name': 'metric_name', 'value': 99, 'min_value': 1}])
+
+    with open(proto_file_path, 'rb') as f:
+      benchmark_entries = test_log_pb2.BenchmarkEntries()
+      benchmark_entries.ParseFromString(f.read())
+
+      actual_result = json_format.MessageToDict(
+          benchmark_entries, preserving_proto_field_name=True)['entry'][0]
+      os.remove(proto_file_path)
+
+      expected_result = {
+          'name': 'BenchmarkTest.testReportBenchmark',
+          # google.protobuf.json_format.MessageToDict() will convert
+          # int64 field to string.
+          'iters': '2000',
+          'wall_time': 1000,
+          'metrics': [{
+              'name': 'metric_name',
+              'value': 99,
+              'min_value': 1
+          }]
+      }
+
+      self.assertEqual(2000, benchmark_entries.entry[0].iters)
+      self.assertDictEqual(expected_result, actual_result)
+
+
+if __name__ == '__main__':
+  test.main()
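Note: the test relies on the existing reporting behavior, in which _global_report_benchmark serializes a BenchmarkEntries proto to a file named by the TEST_REPORT_FILE_PREFIX environment variable plus the benchmark name. A hedged sketch of reading such a report back, assuming the same /tmp prefix and file name as the test above, follows.

# Sketch assuming the /tmp prefix and file name used by the test above.
import os

from tensorflow.core.util import test_log_pb2

report_path = os.path.join('/tmp/', 'BenchmarkTest.testReportBenchmark')
entries = test_log_pb2.BenchmarkEntries()
with open(report_path, 'rb') as f:
  entries.ParseFromString(f.read())

for entry in entries.entry:
  for metric in entry.metrics:
    # min_value/max_value are DoubleValue wrapper messages, so HasField
    # reports whether a bound was actually set.
    lower = metric.min_value.value if metric.HasField('min_value') else None
    print(metric.name, metric.value, lower)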
@@ -17,7 +17,7 @@ tf_class {
   }
   member_method {
     name: "report_benchmark"
-    argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\', \'metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "run_op_benchmark"
@@ -17,7 +17,7 @@ tf_class {
   }
   member_method {
     name: "report_benchmark"
-    argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\', \'metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "run_op_benchmark"
@@ -38,8 +38,8 @@ import tensorflow as tf
 from google.protobuf import message
 from google.protobuf import text_format
 
-from tensorflow.python.lib.io import file_io
 from tensorflow.python.framework import test_util
+from tensorflow.python.lib.io import file_io
 from tensorflow.python.platform import resource_loader
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
@@ -89,6 +89,13 @@ def _KeyToFilePath(key, api_version):
   """From a given key, construct a filepath.
 
   Filepath will be inside golden folder for api_version.
+
+  Args:
+    key: a string used to determine the file path
+    api_version: a number indicating the tensorflow API version, e.g. 1 or 2.
+
+  Returns:
+    A string of file path to the pbtxt file which describes the public API
   """
 
   def _ReplaceCapsWithDash(matchobj):