parent
3bf4628d75
commit
bc1700a6f0
@ -2,7 +2,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/any.proto";
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "TestLogProtos";
|
||||
@ -18,20 +17,6 @@ message EntryValue {
|
||||
}
|
||||
};
|
||||
|
||||
message MetricEntry {
|
||||
// Metric name
|
||||
string name = 1;
|
||||
|
||||
// Metric value
|
||||
double value = 2;
|
||||
|
||||
// The minimum acceptable value for the metric if specified
|
||||
google.protobuf.DoubleValue min_value = 3;
|
||||
|
||||
// The maximum acceptable value for the metric if specified
|
||||
google.protobuf.DoubleValue max_value = 4;
|
||||
}
|
||||
|
||||
// Each unit test or benchmark in a test or benchmark run provides
|
||||
// some set of information. Here we provide some reasonable keys
|
||||
// one would expect to see, with optional key/value pairs for things
|
||||
@ -58,10 +43,6 @@ message BenchmarkEntry {
|
||||
|
||||
// Generic map from result key to value.
|
||||
map<string, EntryValue> extras = 6;
|
||||
|
||||
// Metric name, value and expected range. This can include accuracy metrics
|
||||
// typically used to determine whether the accuracy test has passed
|
||||
repeated MetricEntry metrics = 7;
|
||||
};
|
||||
|
||||
message BenchmarkEntries {
|
||||
|
@ -1367,25 +1367,6 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "benchmark_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"platform/benchmark.py",
|
||||
"platform/benchmark_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":client_testlib",
|
||||
":platform",
|
||||
],
|
||||
main = "platform/benchmark_test.py",
|
||||
tags = [
|
||||
"manual",
|
||||
"no_pip",
|
||||
"notap",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "proto_test",
|
||||
size = "small",
|
||||
|
@ -53,7 +53,7 @@ OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"
|
||||
|
||||
def _global_report_benchmark(
|
||||
name, iters=None, cpu_time=None, wall_time=None,
|
||||
throughput=None, extras=None, metrics=None):
|
||||
throughput=None, extras=None):
|
||||
"""Method for recording a benchmark directly.
|
||||
|
||||
Args:
|
||||
@ -63,22 +63,20 @@ def _global_report_benchmark(
|
||||
wall_time: (optional) Total wall time in seconds
|
||||
throughput: (optional) Throughput (in MB/s)
|
||||
extras: (optional) Dict mapping string keys to additional benchmark info.
|
||||
metrics: (optional) A list of dict representing metrics generated by the
|
||||
benchmark. Each dict should contain keys 'name' and'value'. A dict
|
||||
can optionally contain keys 'min_value' and 'max_value'.
|
||||
|
||||
Raises:
|
||||
TypeError: if extras is not a dict.
|
||||
IOError: if the benchmark output file already exists.
|
||||
"""
|
||||
if extras is not None:
|
||||
if not isinstance(extras, dict):
|
||||
raise TypeError("extras must be a dict")
|
||||
|
||||
logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
|
||||
"throughput: %g, extras: %s, metrics: %s", name,
|
||||
iters if iters is not None else -1,
|
||||
wall_time if wall_time is not None else -1,
|
||||
cpu_time if cpu_time is not None else -1,
|
||||
throughput if throughput is not None else -1,
|
||||
str(extras) if extras else "None",
|
||||
str(metrics) if metrics else "None")
|
||||
"throughput: %g %s", name, iters if iters is not None else -1,
|
||||
wall_time if wall_time is not None else -1, cpu_time if
|
||||
cpu_time is not None else -1, throughput if
|
||||
throughput is not None else -1, str(extras) if extras else "")
|
||||
|
||||
entries = test_log_pb2.BenchmarkEntries()
|
||||
entry = entries.entry.add()
|
||||
@ -92,29 +90,11 @@ def _global_report_benchmark(
|
||||
if throughput is not None:
|
||||
entry.throughput = throughput
|
||||
if extras is not None:
|
||||
if not isinstance(extras, dict):
|
||||
raise TypeError("extras must be a dict")
|
||||
for (k, v) in extras.items():
|
||||
if isinstance(v, numbers.Number):
|
||||
entry.extras[k].double_value = v
|
||||
else:
|
||||
entry.extras[k].string_value = str(v)
|
||||
if metrics is not None:
|
||||
if not isinstance(metrics, list):
|
||||
raise TypeError("metrics must be a list")
|
||||
for metric in metrics:
|
||||
if "name" not in metric:
|
||||
raise TypeError("metric must has a 'name' field")
|
||||
if "value" not in metric:
|
||||
raise TypeError("metric must has a 'value' field")
|
||||
|
||||
metric_entry = entry.metrics.add()
|
||||
metric_entry.name = metric["name"]
|
||||
metric_entry.value = metric["value"]
|
||||
if "min_value" in metric:
|
||||
metric_entry.min_value.value = metric["min_value"]
|
||||
if "max_value" in metric:
|
||||
metric_entry.max_value.value = metric["max_value"]
|
||||
|
||||
test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
|
||||
if test_env is None:
|
||||
@ -189,29 +169,23 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
|
||||
wall_time=None,
|
||||
throughput=None,
|
||||
extras=None,
|
||||
name=None,
|
||||
metrics=None):
|
||||
name=None):
|
||||
"""Report a benchmark.
|
||||
|
||||
Args:
|
||||
iters: (optional) How many iterations were run
|
||||
cpu_time: (optional) Median or mean cpu time in seconds.
|
||||
wall_time: (optional) Median or mean wall time in seconds.
|
||||
cpu_time: (optional) median or mean cpu time in seconds.
|
||||
wall_time: (optional) median or mean wall time in seconds.
|
||||
throughput: (optional) Throughput (in MB/s)
|
||||
extras: (optional) Dict mapping string keys to additional benchmark info.
|
||||
Values may be either floats or values that are convertible to strings.
|
||||
name: (optional) Override the BenchmarkEntry name with `name`.
|
||||
Otherwise it is inferred from the top-level method name.
|
||||
metrics: (optional) A list of dict, where each dict has the keys below
|
||||
name (required), string, metric name
|
||||
value (required), double, metric value
|
||||
min_value (optional), double, minimum acceptable metric value
|
||||
max_value (optional), double, maximum acceptable metric value
|
||||
"""
|
||||
name = self._get_name(overwrite_name=name)
|
||||
_global_report_benchmark(
|
||||
name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
|
||||
throughput=throughput, extras=extras, metrics=metrics)
|
||||
throughput=throughput, extras=extras)
|
||||
|
||||
|
||||
@tf_export("test.benchmark_config")
|
||||
|
@ -1,70 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test for the tf.test.benchmark."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from google.protobuf import json_format
|
||||
from tensorflow.core.util import test_log_pb2
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class BenchmarkTest(test.TestCase, benchmark.TensorFlowBenchmark):
|
||||
|
||||
def testReportBenchmark(self):
|
||||
output_dir = '/tmp/'
|
||||
os.environ['TEST_REPORT_FILE_PREFIX'] = output_dir
|
||||
proto_file_path = os.path.join(output_dir,
|
||||
'BenchmarkTest.testReportBenchmark')
|
||||
if os.path.exists(proto_file_path):
|
||||
os.remove(proto_file_path)
|
||||
|
||||
self.report_benchmark(
|
||||
iters=2000,
|
||||
wall_time=1000,
|
||||
name='testReportBenchmark',
|
||||
metrics=[{'name': 'metric_name', 'value': 99, 'min_value': 1}])
|
||||
|
||||
with open(proto_file_path, 'rb') as f:
|
||||
benchmark_entries = test_log_pb2.BenchmarkEntries()
|
||||
benchmark_entries.ParseFromString(f.read())
|
||||
|
||||
actual_result = json_format.MessageToDict(
|
||||
benchmark_entries, preserving_proto_field_name=True)['entry'][0]
|
||||
os.remove(proto_file_path)
|
||||
|
||||
expected_result = {
|
||||
'name': 'BenchmarkTest.testReportBenchmark',
|
||||
# google.protobuf.json_format.MessageToDict() will convert
|
||||
# int64 field to string.
|
||||
'iters': '2000',
|
||||
'wall_time': 1000,
|
||||
'metrics': [{
|
||||
'name': 'metric_name',
|
||||
'value': 99,
|
||||
'min_value': 1
|
||||
}]
|
||||
}
|
||||
|
||||
self.assertEqual(2000, benchmark_entries.entry[0].iters)
|
||||
self.assertDictEqual(expected_result, actual_result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "report_benchmark"
|
||||
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\', \'metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "run_op_benchmark"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "report_benchmark"
|
||||
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\', \'metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "run_op_benchmark"
|
||||
|
@ -38,8 +38,8 @@ import tensorflow as tf
|
||||
from google.protobuf import message
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -89,13 +89,6 @@ def _KeyToFilePath(key, api_version):
|
||||
"""From a given key, construct a filepath.
|
||||
|
||||
Filepath will be inside golden folder for api_version.
|
||||
|
||||
Args:
|
||||
key: a string used to determine the file path
|
||||
api_version: a number indicating the tensorflow API version, e.g. 1 or 2.
|
||||
|
||||
Returns:
|
||||
A string of file path to the pbtxt file which describes the public API
|
||||
"""
|
||||
|
||||
def _ReplaceCapsWithDash(matchobj):
|
||||
|
Loading…
Reference in New Issue
Block a user