Move benchmarking code to a new directory and add some documentation.
PiperOrigin-RevId: 199200246
This commit is contained in:
parent
142ccf3666
commit
e2d300823f
@ -45,9 +45,6 @@ class ProfileSummarizer {
|
||||
return stats_calculator_->GetShortSummary();
|
||||
}
|
||||
|
||||
// Prints the string returned by GetOutputString().
|
||||
void PrintStepStats() const { stats_calculator_->PrintStepStats(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<tensorflow::StatsCalculator> stats_calculator_;
|
||||
};
|
||||
|
@ -30,87 +30,6 @@ tf_cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "benchmark_model",
|
||||
srcs = [
|
||||
"benchmark_main.cc",
|
||||
"logging.h",
|
||||
],
|
||||
copts = common_copts,
|
||||
linkopts = select({
|
||||
"//tensorflow:android": [
|
||||
"-pie",
|
||||
"-landroid",
|
||||
"-lm",
|
||||
"-z defs",
|
||||
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":benchmark_tflite_model_lib",
|
||||
"//tensorflow/core:stats_calculator_portable",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "command_line_flags",
|
||||
srcs = ["command_line_flags.cc"],
|
||||
hdrs = ["command_line_flags.h"],
|
||||
copts = common_copts,
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "command_line_flags_test",
|
||||
srcs = ["command_line_flags_test.cc"],
|
||||
copts = common_copts,
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":command_line_flags",
|
||||
"//tensorflow/contrib/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_tflite_model_lib",
|
||||
srcs = [
|
||||
"benchmark_tflite_model.cc",
|
||||
"logging.h",
|
||||
],
|
||||
hdrs = ["benchmark_tflite_model.h"],
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
":benchmark_model_lib",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/profiling:profile_summarizer",
|
||||
"//tensorflow/contrib/lite/profiling:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_model_lib",
|
||||
srcs = [
|
||||
"benchmark_model.cc",
|
||||
"logging.h",
|
||||
],
|
||||
hdrs = ["benchmark_model.h"],
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
":command_line_flags",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/profiling:profile_summarizer",
|
||||
"//tensorflow/contrib/lite/profiling:profiler",
|
||||
"//tensorflow/contrib/lite/profiling:time",
|
||||
"//tensorflow/core:stats_calculator_portable",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gen_op_registration",
|
||||
srcs = ["gen_op_registration.cc"],
|
||||
|
91
tensorflow/contrib/lite/tools/benchmark/BUILD
Normal file
91
tensorflow/contrib/lite/tools/benchmark/BUILD
Normal file
@ -0,0 +1,91 @@
|
||||
package(default_visibility = [
|
||||
"//visibility:public",
|
||||
])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||
|
||||
common_copts = ["-Wall"]
|
||||
|
||||
# Command-line benchmark executable. It is built from benchmark_main.cc and
# links against :benchmark_tflite_model_lib.
cc_binary(
|
||||
name = "benchmark_model",
|
||||
srcs = [
|
||||
"benchmark_main.cc",
|
||||
"logging.h",
|
||||
],
|
||||
copts = common_copts,
|
||||
# Extra linker options are applied only when the //tensorflow:android
# condition matches; every other configuration links with no extra options.
linkopts = select({
|
||||
"//tensorflow:android": [
|
||||
"-pie",
|
||||
"-landroid",
|
||||
"-lm",
|
||||
"-z defs",
|
||||
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":benchmark_tflite_model_lib",
|
||||
],
|
||||
)
|
||||
|
||||
# Command-line flag library (command_line_flags.cc / command_line_flags.h).
# Visibility is private, so only targets in this package may depend on it.
cc_library(
|
||||
name = "command_line_flags",
|
||||
srcs = ["command_line_flags.cc"],
|
||||
hdrs = ["command_line_flags.h"],
|
||||
copts = common_copts,
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
# Unit test for :command_line_flags, built from command_line_flags_test.cc
# and linked against googletest.
cc_test(
|
||||
name = "command_line_flags_test",
|
||||
srcs = ["command_line_flags_test.cc"],
|
||||
copts = common_copts,
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":command_line_flags",
|
||||
"//tensorflow/contrib/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
# TFLite-specific benchmark library (benchmark_tflite_model.cc/.h). It builds
# on :benchmark_model_lib and pulls in the TFLite framework, builtin ops and
# profiling targets listed in deps.
cc_library(
|
||||
name = "benchmark_tflite_model_lib",
|
||||
srcs = [
|
||||
"benchmark_tflite_model.cc",
|
||||
"logging.h",
|
||||
],
|
||||
hdrs = ["benchmark_tflite_model.h"],
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
":benchmark_model_lib",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/profiling:profile_summarizer",
|
||||
"//tensorflow/contrib/lite/profiling:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
# Base benchmark library (benchmark_model.cc/.h). It depends on the private
# :command_line_flags target, the TFLite profiling targets and the portable
# stats calculator from //tensorflow/core.
cc_library(
|
||||
name = "benchmark_model_lib",
|
||||
srcs = [
|
||||
"benchmark_model.cc",
|
||||
"logging.h",
|
||||
],
|
||||
hdrs = ["benchmark_model.h"],
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
":command_line_flags",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/profiling:profile_summarizer",
|
||||
"//tensorflow/contrib/lite/profiling:profiler",
|
||||
"//tensorflow/contrib/lite/profiling:time",
|
||||
"//tensorflow/core:stats_calculator_portable",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite()
|
172
tensorflow/contrib/lite/tools/benchmark/README.md
Normal file
172
tensorflow/contrib/lite/tools/benchmark/README.md
Normal file
@ -0,0 +1,172 @@
|
||||
# TFLite Model Benchmark Tool
|
||||
|
||||
## Description
|
||||
|
||||
A simple C++ binary to benchmark a TFLite model and its individual operators,
|
||||
both on desktop machines and on Android.
|
||||
|
||||
## To build/install/run
|
||||
|
||||
### On Android:
|
||||
|
||||
(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` file and configure the Android NDK/SDK.
|
||||
|
||||
(1) Build for your specific platform, e.g.:
|
||||
|
||||
```
|
||||
bazel build -c opt \
|
||||
--config=android_arm \
|
||||
--cxxopt='--std=c++11' \
|
||||
tensorflow/contrib/lite/tools/benchmark:benchmark_model
|
||||
```
|
||||
|
||||
(2) Connect your phone. Push the binary to your phone with adb push
|
||||
(make the directory if required):
|
||||
|
||||
```
|
||||
adb push bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model /data/local/tmp
|
||||
```
|
||||
|
||||
(3) Make the binary executable.
|
||||
|
||||
```
|
||||
adb shell chmod +x /data/local/tmp/benchmark_model
|
||||
```
|
||||
|
||||
(4) Push the compute graph that you need to test. For example:
|
||||
|
||||
```
|
||||
adb push mobilenet_quant_v1_224.tflite /data/local/tmp
|
||||
```
|
||||
|
||||
(5) Run the benchmark. For example:
|
||||
|
||||
```
|
||||
adb shell /data/local/tmp/benchmark_model \
|
||||
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
|
||||
--input_layer="Placeholder" \
|
||||
--input_layer_shape="1,224,224,3" \
|
||||
--input_layer_type="uint8" \
|
||||
--output_layer="MobilenetV1/Predictions/Reshape_1" \
|
||||
--num_threads=4
|
||||
```
|
||||
|
||||
### On desktop:
|
||||
(1) Build the binary:
|
||||
|
||||
```
|
||||
bazel build -c opt tensorflow/contrib/lite/tools/benchmark:benchmark_model
|
||||
```
|
||||
|
||||
(2) Run on your compute graph, similar to the Android case but without the need for `adb shell`.
|
||||
For example:
|
||||
|
||||
```
|
||||
bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
|
||||
--graph=mobilenet_quant_v1_224.tflite \
|
||||
--input_layer="Placeholder" \
|
||||
--input_layer_shape="1,224,224,3" \
|
||||
--input_layer_type="uint8" \
|
||||
--output_layer="MobilenetV1/Predictions/Reshape_1" \
|
||||
--num_threads=4
|
||||
```
|
||||
|
||||
The MobileNet graph used as an example here may be downloaded from
|
||||
https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
|
||||
|
||||
## Profiling model operators
|
||||
The benchmark model binary also allows you to profile operators and report the execution time of each operator. To do this,
|
||||
compile the binary with a compiler flag that enables profiling. Pass **--copt=-DTFLITE_PROFILING_ENABLED**
|
||||
to compile the benchmark with profiling support.
|
||||
For example, to compile with profiling support on Android, add this flag to the previous command:
|
||||
|
||||
```
|
||||
bazel build -c opt \
|
||||
--config=android_arm \
|
||||
--cxxopt='--std=c++11' \
|
||||
--copt=-DTFLITE_PROFILING_ENABLED \
|
||||
tensorflow/contrib/lite/tools/benchmark:benchmark_model
|
||||
```
|
||||
This compiles TFLite with profiling enabled; you can now run the benchmark binary as before. The binary will produce detailed statistics for each operation similar to those shown below:
|
||||
|
||||
```
|
||||
|
||||
============================== Run Order ==============================
|
||||
[node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
|
||||
CONV_2D 0.000 9.132 9.132 0.121% 0.121% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
|
||||
DEPTHWISE_CONV_2D 9.135 3.280 3.280 0.043% 0.165% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6]
|
||||
CONV_2D 12.419 6.877 6.877 0.091% 0.256% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 19.299 1.708 1.708 0.023% 0.278% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6]
|
||||
CONV_2D 21.012 4.162 4.162 0.055% 0.334% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 25.177 3.520 3.520 0.047% 0.380% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6]
|
||||
CONV_2D 28.701 10.218 10.218 0.136% 0.516% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 38.922 0.827 0.827 0.011% 0.527% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6]
|
||||
CONV_2D 39.752 1.401 1.401 0.019% 0.545% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 41.156 1.290 1.290 0.017% 0.563% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6]
|
||||
CONV_2D 42.448 5.995 5.995 0.080% 0.642% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.647% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6]
|
||||
CONV_2D 48.856 6.167 6.167 0.082% 0.729% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.738% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6]
|
||||
CONV_2D 55.656 6.464 6.464 0.086% 0.823% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.832% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6]
|
||||
CONV_2D 62.774 14.666 14.666 0.195% 1.026% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 1.035% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6]
|
||||
CONV_2D 78.081 7.186 7.186 0.095% 1.130% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 85.270 0.646 0.646 0.009% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6]
|
||||
CONV_2D 85.918 9.529 9.529 0.126% 1.265% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 95.451 0.628 0.628 0.008% 1.273% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6]
|
||||
CONV_2D 96.081 2.077 2.077 0.028% 1.301% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 98.162 0.168 0.168 0.002% 1.303% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6]
|
||||
CONV_2D 98.332 1.007 1.007 0.013% 1.317% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 99.342 0.288 0.288 0.004% 1.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6]
|
||||
CONV_2D 99.632 8.197 8.197 0.109% 1.429% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
|
||||
AVERAGE_POOL_2D 107.832 0.045 0.045 0.001% 1.430% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool]
|
||||
CONV_2D 107.878 0.325 0.325 0.004% 1.434% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd]
|
||||
RESHAPE 108.206 0.003 0.003 0.000% 1.434% 0.000 0 [MobilenetV1/Predictions/Reshape]
|
||||
SOFTMAX 108.211 0.038 0.038 0.001% 1.434% 0.000 0 [MobilenetV1/Predictions/Softmax]
|
||||
|
||||
============================== Top by Computation Time ==============================
|
||||
[node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
|
||||
CONV_2D 62.774 14.666 14.666 0.195% 0.195% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
|
||||
CONV_2D 28.701 10.218 10.218 0.136% 0.330% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
|
||||
CONV_2D 85.918 9.529 9.529 0.126% 0.456% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
|
||||
CONV_2D 0.000 9.132 9.132 0.121% 0.578% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
|
||||
CONV_2D 99.632 8.197 8.197 0.109% 0.686% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
|
||||
CONV_2D 78.081 7.186 7.186 0.095% 0.782% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
|
||||
CONV_2D 12.419 6.877 6.877 0.091% 0.873% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
|
||||
CONV_2D 55.656 6.464 6.464 0.086% 0.958% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
|
||||
CONV_2D 48.856 6.167 6.167 0.082% 1.040% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
|
||||
CONV_2D 42.448 5.995 5.995 0.080% 1.120% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
|
||||
|
||||
============================== Top by Memory Use ==============================
|
||||
[node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
|
||||
SOFTMAX 108.211 0.038 0.038 0.001% 0.001% 0.000 0 [MobilenetV1/Predictions/Softmax]
|
||||
RESHAPE 108.206 0.003 0.003 0.000% 0.001% 0.000 0 [MobilenetV1/Predictions/Reshape]
|
||||
CONV_2D 78.081 7.186 7.186 0.095% 0.096% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 0.104% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6]
|
||||
CONV_2D 62.774 14.666 14.666 0.195% 0.299% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.307% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6]
|
||||
CONV_2D 55.656 6.464 6.464 0.086% 0.393% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.401% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6]
|
||||
CONV_2D 48.856 6.167 6.167 0.082% 0.483% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
|
||||
DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.489% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6]
|
||||
|
||||
Number of nodes executed: 31
|
||||
============================== Summary by node type ==============================
|
||||
[Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called]
|
||||
CONV_2D 15 1.861 86.679% 86.679% 0.000 0
|
||||
DEPTHWISE_CONV_2D 13 0.286 13.321% 100.000% 0.000 0
|
||||
SOFTMAX 1 0.000 0.000% 100.000% 0.000 0
|
||||
RESHAPE 1 0.000 0.000% 100.000% 0.000 0
|
||||
AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0
|
||||
|
||||
Timings (microseconds): count=50 first=108164 curr=128308 min=102850 max=197072 avg=150805 std=24368
|
||||
Memory (bytes): count=0
|
||||
31 nodes observed
|
||||
|
||||
|
||||
Average inference timings in us: Warmup: 135310, Init: 12123, no stats: 150988
|
||||
|
||||
```
|
||||
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h"
|
||||
#include "tensorflow/contrib/lite/tools/logging.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace benchmark {
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/lite/tools/benchmark_model.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
|
||||
|
||||
#include <time.h>
|
||||
|
||||
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorflow/contrib/lite/profiling/time.h"
|
||||
#include "tensorflow/contrib/lite/tools/logging.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
|
||||
|
||||
namespace {
|
||||
void SleepForSeconds(double sleep_seconds) {
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/lite/tools//command_line_flags.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h"
|
||||
#include "tensorflow/core/util/stats_calculator.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -158,4 +158,4 @@ class BenchmarkModel {
|
||||
} // namespace benchmark
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h"
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/op_resolver.h"
|
||||
#include "tensorflow/contrib/lite/string_util.h"
|
||||
#include "tensorflow/contrib/lite/tools/logging.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
|
||||
|
||||
#ifdef TFLITE_CUSTOM_OPS_HEADER
|
||||
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark_model.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace benchmark {
|
||||
@ -87,4 +87,4 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
|
||||
} // namespace benchmark
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_
|
@ -10,8 +10,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -19,6 +20,13 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
// Returns the textual form of `val`, produced by inserting it into a
// std::ostringstream with operator<< and returning the accumulated string.
template <typename T>
|
||||
std::string ToString(T val) {
|
||||
std::ostringstream stream;
|
||||
stream << val;
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
bool ParseFlag(const std::string& arg, const std::string& flag,
|
||||
const std::function<bool(const std::string&)>& parse_func,
|
||||
bool* value_parsing_ok) {
|
||||
@ -35,14 +43,16 @@ bool ParseFlag(const std::string& arg, const std::string& flag,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseInt32Flag(const std::string& flag_value, int32_t* value) {
|
||||
char extra;
|
||||
return sscanf(flag_value.data(), "%d%c", value, &extra) == 1;
|
||||
}
|
||||
|
||||
bool ParseInt64Flag(const std::string& flag_value, int64_t* value) {
|
||||
char extra;
|
||||
return sscanf(flag_value.data(), "%ld%c", value, &extra) == 1;
|
||||
// Parses the whole of `flag_value` as a value of type T using stream
// extraction (operator>>).
//
// On success the parsed value is stored in `*value` and true is returned.
// Returns false, leaving `*value` untouched, when:
//   * the extraction itself fails (empty input, non-numeric text, or a value
//     that does not fit in T), or
//   * characters remain after the parsed value (e.g. "12abc").
template <typename T>
bool ParseFlag(const std::string& flag_value, T* value) {
  std::istringstream stream(flag_value);
  T read_value{};
  stream >> read_value;
  // failbit is the only reliable failure signal: empty and out-of-range
  // inputs set eofbit as well, so testing eof()/good() alone lets them pass.
  if (stream.fail()) {
    return false;
  }
  // A successful extraction may stop early; require that nothing is left so
  // that inputs with trailing garbage are rejected instead of truncated.
  if (stream.peek() != std::istringstream::traits_type::eof()) {
    return false;
  }
  *value = read_value;
  return true;
}
|
||||
|
||||
bool ParseBoolFlag(const std::string& flag_value, bool* value) {
|
||||
@ -54,11 +64,6 @@ bool ParseBoolFlag(const std::string& flag_value, bool* value) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseFloatFlag(const std::string& flag_value, float* value) {
|
||||
char extra;
|
||||
return sscanf(flag_value.data(), "%f%c", value, &extra) == 1;
|
||||
}
|
||||
|
||||
bool ParseStringFlag(const std::string& flag_value, std::string* value) {
|
||||
*value = flag_value;
|
||||
return true;
|
||||
@ -70,27 +75,27 @@ Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text)
|
||||
: name_(name),
|
||||
type_(TYPE_INT32),
|
||||
value_hook_([dst](const std::string& flag_value) {
|
||||
return ParseInt32Flag(flag_value, dst);
|
||||
return ParseFlag<int32_t>(flag_value, dst);
|
||||
}),
|
||||
default_for_display_(std::to_string(*dst)),
|
||||
default_for_display_(ToString(*dst)),
|
||||
usage_text_(usage_text) {}
|
||||
|
||||
Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text)
|
||||
: name_(name),
|
||||
type_(TYPE_INT64),
|
||||
value_hook_([dst](const std::string& flag_value) {
|
||||
return ParseInt64Flag(flag_value, dst);
|
||||
return ParseFlag<int64_t>(flag_value, dst);
|
||||
}),
|
||||
default_for_display_(std::to_string(*dst)),
|
||||
default_for_display_(ToString(*dst)),
|
||||
usage_text_(usage_text) {}
|
||||
|
||||
Flag::Flag(const char* name, float* dst, const std::string& usage_text)
|
||||
: name_(name),
|
||||
type_(TYPE_FLOAT),
|
||||
value_hook_([dst](const std::string& flag_value) {
|
||||
return ParseFloatFlag(flag_value, dst);
|
||||
return ParseFlag<float>(flag_value, dst);
|
||||
}),
|
||||
default_for_display_(std::to_string(*dst)),
|
||||
default_for_display_(ToString(*dst)),
|
||||
usage_text_(usage_text) {}
|
||||
|
||||
Flag::Flag(const char* name, bool* dst, const std::string& usage_text)
|
||||
@ -166,7 +171,7 @@ std::string Flag::GetTypeName() const {
|
||||
}
|
||||
argv[dst++] = nullptr;
|
||||
*argc = unknown_flags.size() + 1;
|
||||
return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
|
||||
return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
|
||||
}
|
||||
|
||||
/*static*/ std::string Flags::Usage(const std::string& cmdline,
|
@ -109,4 +109,4 @@ class Flags {
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h"
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/contrib/lite/testing/util.h"
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
// LOG and CHECK macros for benchmarks.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@ -72,4 +73,4 @@ class LoggingWrapper {
|
||||
|
||||
#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b)
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_
|
@ -876,7 +876,6 @@ cc_library(
|
||||
hdrs = [
|
||||
"util/stats_calculator.h",
|
||||
],
|
||||
deps = [":platform_base"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -78,6 +78,14 @@ void StatSummarizer::Validate(const std::vector<TensorDescription>* outputs,
|
||||
}
|
||||
}
|
||||
|
||||
// Writes the text returned by GetOutputString() to the INFO log, emitting one
// log statement per line of that text.
void StatSummarizer::PrintStepStats() const {
|
||||
string output = GetOutputString();
|
||||
// Wrap the text in a stream so std::getline can split it on newlines.
std::istringstream iss(output);
|
||||
for (std::string line; std::getline(iss, line);) {
|
||||
LOG(INFO) << line;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::string OpType(const DeviceStepStats& ds, const NodeExecStats& ns) {
|
||||
// There is no published specification of how DeviceStats and NodeStats
|
||||
|
@ -68,7 +68,7 @@ class StatSummarizer {
|
||||
}
|
||||
|
||||
// Prints the string returned by GetOutputString().
|
||||
void PrintStepStats() const { stats_calculator_->PrintStepStats(); }
|
||||
void PrintStepStats() const;
|
||||
|
||||
// Prints the output tensor sizes and types for each node.
|
||||
void PrintOutputs() const;
|
||||
|
@ -21,8 +21,6 @@ limitations under the License.
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
StatsCalculator::StatsCalculator(const StatSummarizerOptions& options)
|
||||
@ -93,7 +91,7 @@ std::string StatsCalculator::ColumnString(const Detail& detail,
|
||||
|
||||
void StatsCalculator::OrderNodesByMetric(
|
||||
SortingMetric metric, std::vector<const Detail*>* details) const {
|
||||
std::priority_queue<std::pair<string, const Detail*>> sorted_list;
|
||||
std::priority_queue<std::pair<std::string, const Detail*>> sorted_list;
|
||||
const int num_nodes = details_.size();
|
||||
|
||||
for (const auto& det : details_) {
|
||||
@ -142,7 +140,7 @@ void StatsCalculator::ComputeStatsByType(
|
||||
int64_t run_count = run_total_us_.count();
|
||||
|
||||
for (const auto& det : details_) {
|
||||
const string node_name = det.first;
|
||||
const std::string node_name = det.first;
|
||||
const Detail& detail = det.second;
|
||||
|
||||
int64_t curr_time_val =
|
||||
@ -151,7 +149,7 @@ void StatsCalculator::ComputeStatsByType(
|
||||
|
||||
int64_t curr_memory_val = detail.mem_used.newest();
|
||||
|
||||
const string& node_type = detail.type;
|
||||
const std::string& node_type = detail.type;
|
||||
|
||||
(*node_type_map_count)[node_type] += 1;
|
||||
(*node_type_map_time)[node_type] += curr_time_val;
|
||||
@ -163,12 +161,12 @@ void StatsCalculator::ComputeStatsByType(
|
||||
std::string StatsCalculator::GetStatsByNodeType() const {
|
||||
std::stringstream stream;
|
||||
|
||||
stream << "Number of nodes executed: " << details_.size() << std::endl;
|
||||
|
||||
stream << "============================== Summary by node type "
|
||||
"=============================="
|
||||
<< std::endl;
|
||||
|
||||
LOG(INFO) << "Number of nodes executed: " << details_.size();
|
||||
|
||||
std::map<std::string, int64_t> node_type_map_count;
|
||||
std::map<std::string, int64_t> node_type_map_time;
|
||||
std::map<std::string, int64_t> node_type_map_memory;
|
||||
@ -180,11 +178,12 @@ std::string StatsCalculator::GetStatsByNodeType() const {
|
||||
&accumulated_us);
|
||||
|
||||
// Sort them.
|
||||
std::priority_queue<std::pair<int64_t, std::pair<string, int64_t>>> timings;
|
||||
std::priority_queue<std::pair<int64_t, std::pair<std::string, int64_t>>>
|
||||
timings;
|
||||
for (const auto& node_type : node_type_map_time) {
|
||||
const int64_t mem_used = node_type_map_memory[node_type.first];
|
||||
timings.emplace(node_type.second,
|
||||
std::pair<string, int64_t>(node_type.first, mem_used));
|
||||
std::pair<std::string, int64_t>(node_type.first, mem_used));
|
||||
}
|
||||
|
||||
InitField(stream, 24) << "[Node type]";
|
||||
@ -201,7 +200,7 @@ std::string StatsCalculator::GetStatsByNodeType() const {
|
||||
auto entry = timings.top();
|
||||
timings.pop();
|
||||
|
||||
const string node_type = entry.second.first;
|
||||
const std::string node_type = entry.second.first;
|
||||
const float memory = entry.second.second / 1000.0f;
|
||||
|
||||
const int64_t node_type_total_us = entry.first;
|
||||
@ -273,14 +272,6 @@ std::string StatsCalculator::GetOutputString() const {
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
void StatsCalculator::PrintStepStats() const {
|
||||
string output = GetOutputString();
|
||||
std::istringstream iss(output);
|
||||
for (std::string line; std::getline(iss, line);) {
|
||||
LOG(INFO) << line;
|
||||
}
|
||||
}
|
||||
|
||||
void StatsCalculator::UpdateDetails(
|
||||
const std::map<std::string, Detail>& details) {
|
||||
details_.insert(details.begin(), details.end());
|
||||
|
@ -127,9 +127,6 @@ class StatsCalculator {
|
||||
|
||||
std::string GetShortSummary() const;
|
||||
|
||||
// Prints the string returned by GetOutputString().
|
||||
void PrintStepStats() const;
|
||||
|
||||
void ComputeStatsByType(
|
||||
std::map<std::string, int64_t>* node_type_map_count,
|
||||
std::map<std::string, int64_t>* node_type_map_time,
|
||||
|
Loading…
Reference in New Issue
Block a user