parent
3bf4628d75
commit
bc1700a6f0
@ -2,7 +2,6 @@
|
|||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
import "google/protobuf/any.proto";
|
import "google/protobuf/any.proto";
|
||||||
import "google/protobuf/wrappers.proto";
|
|
||||||
|
|
||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
option java_outer_classname = "TestLogProtos";
|
option java_outer_classname = "TestLogProtos";
|
||||||
@ -18,20 +17,6 @@ message EntryValue {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
message MetricEntry {
|
|
||||||
// Metric name
|
|
||||||
string name = 1;
|
|
||||||
|
|
||||||
// Metric value
|
|
||||||
double value = 2;
|
|
||||||
|
|
||||||
// The minimum acceptable value for the metric if specified
|
|
||||||
google.protobuf.DoubleValue min_value = 3;
|
|
||||||
|
|
||||||
// The maximum acceptable value for the metric if specified
|
|
||||||
google.protobuf.DoubleValue max_value = 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each unit test or benchmark in a test or benchmark run provides
|
// Each unit test or benchmark in a test or benchmark run provides
|
||||||
// some set of information. Here we provide some reasonable keys
|
// some set of information. Here we provide some reasonable keys
|
||||||
// one would expect to see, with optional key/value pairs for things
|
// one would expect to see, with optional key/value pairs for things
|
||||||
@ -58,10 +43,6 @@ message BenchmarkEntry {
|
|||||||
|
|
||||||
// Generic map from result key to value.
|
// Generic map from result key to value.
|
||||||
map<string, EntryValue> extras = 6;
|
map<string, EntryValue> extras = 6;
|
||||||
|
|
||||||
// Metric name, value and expected range. This can include accuracy metrics
|
|
||||||
// typically used to determine whether the accuracy test has passed
|
|
||||||
repeated MetricEntry metrics = 7;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
message BenchmarkEntries {
|
message BenchmarkEntries {
|
||||||
|
@ -1367,25 +1367,6 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
|
||||||
name = "benchmark_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = [
|
|
||||||
"platform/benchmark.py",
|
|
||||||
"platform/benchmark_test.py",
|
|
||||||
],
|
|
||||||
additional_deps = [
|
|
||||||
":client_testlib",
|
|
||||||
":platform",
|
|
||||||
],
|
|
||||||
main = "platform/benchmark_test.py",
|
|
||||||
tags = [
|
|
||||||
"manual",
|
|
||||||
"no_pip",
|
|
||||||
"notap",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "proto_test",
|
name = "proto_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -53,7 +53,7 @@ OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"
|
|||||||
|
|
||||||
def _global_report_benchmark(
|
def _global_report_benchmark(
|
||||||
name, iters=None, cpu_time=None, wall_time=None,
|
name, iters=None, cpu_time=None, wall_time=None,
|
||||||
throughput=None, extras=None, metrics=None):
|
throughput=None, extras=None):
|
||||||
"""Method for recording a benchmark directly.
|
"""Method for recording a benchmark directly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -63,22 +63,20 @@ def _global_report_benchmark(
|
|||||||
wall_time: (optional) Total wall time in seconds
|
wall_time: (optional) Total wall time in seconds
|
||||||
throughput: (optional) Throughput (in MB/s)
|
throughput: (optional) Throughput (in MB/s)
|
||||||
extras: (optional) Dict mapping string keys to additional benchmark info.
|
extras: (optional) Dict mapping string keys to additional benchmark info.
|
||||||
metrics: (optional) A list of dict representing metrics generated by the
|
|
||||||
benchmark. Each dict should contain keys 'name' and'value'. A dict
|
|
||||||
can optionally contain keys 'min_value' and 'max_value'.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if extras is not a dict.
|
TypeError: if extras is not a dict.
|
||||||
IOError: if the benchmark output file already exists.
|
IOError: if the benchmark output file already exists.
|
||||||
"""
|
"""
|
||||||
|
if extras is not None:
|
||||||
|
if not isinstance(extras, dict):
|
||||||
|
raise TypeError("extras must be a dict")
|
||||||
|
|
||||||
logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
|
logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
|
||||||
"throughput: %g, extras: %s, metrics: %s", name,
|
"throughput: %g %s", name, iters if iters is not None else -1,
|
||||||
iters if iters is not None else -1,
|
wall_time if wall_time is not None else -1, cpu_time if
|
||||||
wall_time if wall_time is not None else -1,
|
cpu_time is not None else -1, throughput if
|
||||||
cpu_time if cpu_time is not None else -1,
|
throughput is not None else -1, str(extras) if extras else "")
|
||||||
throughput if throughput is not None else -1,
|
|
||||||
str(extras) if extras else "None",
|
|
||||||
str(metrics) if metrics else "None")
|
|
||||||
|
|
||||||
entries = test_log_pb2.BenchmarkEntries()
|
entries = test_log_pb2.BenchmarkEntries()
|
||||||
entry = entries.entry.add()
|
entry = entries.entry.add()
|
||||||
@ -92,29 +90,11 @@ def _global_report_benchmark(
|
|||||||
if throughput is not None:
|
if throughput is not None:
|
||||||
entry.throughput = throughput
|
entry.throughput = throughput
|
||||||
if extras is not None:
|
if extras is not None:
|
||||||
if not isinstance(extras, dict):
|
|
||||||
raise TypeError("extras must be a dict")
|
|
||||||
for (k, v) in extras.items():
|
for (k, v) in extras.items():
|
||||||
if isinstance(v, numbers.Number):
|
if isinstance(v, numbers.Number):
|
||||||
entry.extras[k].double_value = v
|
entry.extras[k].double_value = v
|
||||||
else:
|
else:
|
||||||
entry.extras[k].string_value = str(v)
|
entry.extras[k].string_value = str(v)
|
||||||
if metrics is not None:
|
|
||||||
if not isinstance(metrics, list):
|
|
||||||
raise TypeError("metrics must be a list")
|
|
||||||
for metric in metrics:
|
|
||||||
if "name" not in metric:
|
|
||||||
raise TypeError("metric must has a 'name' field")
|
|
||||||
if "value" not in metric:
|
|
||||||
raise TypeError("metric must has a 'value' field")
|
|
||||||
|
|
||||||
metric_entry = entry.metrics.add()
|
|
||||||
metric_entry.name = metric["name"]
|
|
||||||
metric_entry.value = metric["value"]
|
|
||||||
if "min_value" in metric:
|
|
||||||
metric_entry.min_value.value = metric["min_value"]
|
|
||||||
if "max_value" in metric:
|
|
||||||
metric_entry.max_value.value = metric["max_value"]
|
|
||||||
|
|
||||||
test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
|
test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
|
||||||
if test_env is None:
|
if test_env is None:
|
||||||
@ -189,29 +169,23 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
|
|||||||
wall_time=None,
|
wall_time=None,
|
||||||
throughput=None,
|
throughput=None,
|
||||||
extras=None,
|
extras=None,
|
||||||
name=None,
|
name=None):
|
||||||
metrics=None):
|
|
||||||
"""Report a benchmark.
|
"""Report a benchmark.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
iters: (optional) How many iterations were run
|
iters: (optional) How many iterations were run
|
||||||
cpu_time: (optional) Median or mean cpu time in seconds.
|
cpu_time: (optional) median or mean cpu time in seconds.
|
||||||
wall_time: (optional) Median or mean wall time in seconds.
|
wall_time: (optional) median or mean wall time in seconds.
|
||||||
throughput: (optional) Throughput (in MB/s)
|
throughput: (optional) Throughput (in MB/s)
|
||||||
extras: (optional) Dict mapping string keys to additional benchmark info.
|
extras: (optional) Dict mapping string keys to additional benchmark info.
|
||||||
Values may be either floats or values that are convertible to strings.
|
Values may be either floats or values that are convertible to strings.
|
||||||
name: (optional) Override the BenchmarkEntry name with `name`.
|
name: (optional) Override the BenchmarkEntry name with `name`.
|
||||||
Otherwise it is inferred from the top-level method name.
|
Otherwise it is inferred from the top-level method name.
|
||||||
metrics: (optional) A list of dict, where each dict has the keys below
|
|
||||||
name (required), string, metric name
|
|
||||||
value (required), double, metric value
|
|
||||||
min_value (optional), double, minimum acceptable metric value
|
|
||||||
max_value (optional), double, maximum acceptable metric value
|
|
||||||
"""
|
"""
|
||||||
name = self._get_name(overwrite_name=name)
|
name = self._get_name(overwrite_name=name)
|
||||||
_global_report_benchmark(
|
_global_report_benchmark(
|
||||||
name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
|
name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
|
||||||
throughput=throughput, extras=extras, metrics=metrics)
|
throughput=throughput, extras=extras)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("test.benchmark_config")
|
@tf_export("test.benchmark_config")
|
||||||
|
@ -1,70 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Test for the tf.test.benchmark."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
from google.protobuf import json_format
|
|
||||||
from tensorflow.core.util import test_log_pb2
|
|
||||||
from tensorflow.python.platform import benchmark
|
|
||||||
from tensorflow.python.platform import test
|
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkTest(test.TestCase, benchmark.TensorFlowBenchmark):
|
|
||||||
|
|
||||||
def testReportBenchmark(self):
|
|
||||||
output_dir = '/tmp/'
|
|
||||||
os.environ['TEST_REPORT_FILE_PREFIX'] = output_dir
|
|
||||||
proto_file_path = os.path.join(output_dir,
|
|
||||||
'BenchmarkTest.testReportBenchmark')
|
|
||||||
if os.path.exists(proto_file_path):
|
|
||||||
os.remove(proto_file_path)
|
|
||||||
|
|
||||||
self.report_benchmark(
|
|
||||||
iters=2000,
|
|
||||||
wall_time=1000,
|
|
||||||
name='testReportBenchmark',
|
|
||||||
metrics=[{'name': 'metric_name', 'value': 99, 'min_value': 1}])
|
|
||||||
|
|
||||||
with open(proto_file_path, 'rb') as f:
|
|
||||||
benchmark_entries = test_log_pb2.BenchmarkEntries()
|
|
||||||
benchmark_entries.ParseFromString(f.read())
|
|
||||||
|
|
||||||
actual_result = json_format.MessageToDict(
|
|
||||||
benchmark_entries, preserving_proto_field_name=True)['entry'][0]
|
|
||||||
os.remove(proto_file_path)
|
|
||||||
|
|
||||||
expected_result = {
|
|
||||||
'name': 'BenchmarkTest.testReportBenchmark',
|
|
||||||
# google.protobuf.json_format.MessageToDict() will convert
|
|
||||||
# int64 field to string.
|
|
||||||
'iters': '2000',
|
|
||||||
'wall_time': 1000,
|
|
||||||
'metrics': [{
|
|
||||||
'name': 'metric_name',
|
|
||||||
'value': 99,
|
|
||||||
'min_value': 1
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
|
|
||||||
self.assertEqual(2000, benchmark_entries.entry[0].iters)
|
|
||||||
self.assertDictEqual(expected_result, actual_result)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test.main()
|
|
||||||
|
|
@ -17,7 +17,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "report_benchmark"
|
name: "report_benchmark"
|
||||||
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\', \'metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "run_op_benchmark"
|
name: "run_op_benchmark"
|
||||||
|
@ -17,7 +17,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "report_benchmark"
|
name: "report_benchmark"
|
||||||
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\', \'metrics\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "run_op_benchmark"
|
name: "run_op_benchmark"
|
||||||
|
@ -38,8 +38,8 @@ import tensorflow as tf
|
|||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
|
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.python.platform import resource_loader
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
@ -89,13 +89,6 @@ def _KeyToFilePath(key, api_version):
|
|||||||
"""From a given key, construct a filepath.
|
"""From a given key, construct a filepath.
|
||||||
|
|
||||||
Filepath will be inside golden folder for api_version.
|
Filepath will be inside golden folder for api_version.
|
||||||
|
|
||||||
Args:
|
|
||||||
key: a string used to determine the file path
|
|
||||||
api_version: a number indicating the tensorflow API version, e.g. 1 or 2.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string of file path to the pbtxt file which describes the public API
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _ReplaceCapsWithDash(matchobj):
|
def _ReplaceCapsWithDash(matchobj):
|
||||||
|
Loading…
Reference in New Issue
Block a user