Add a flag (--max_secs) to limit the benchmark duration, with a default value of 150 seconds.
PiperOrigin-RevId: 253420291
commit 6a96e865fa
parent 8211365f9e
@@ -43,6 +43,7 @@ BenchmarkParams BenchmarkModel::DefaultParams() {
   BenchmarkParams params;
   params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50));
   params.AddParam("min_secs", BenchmarkParam::Create<float>(1.0f));
+  params.AddParam("max_secs", BenchmarkParam::Create<float>(150.0f));
   params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
   params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
   params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
@@ -66,12 +67,19 @@ void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
 
 std::vector<Flag> BenchmarkModel::GetFlags() {
   return {
-      CreateFlag<int32_t>("num_runs", &params_,
-                          "minimum number of runs, see also min_secs"),
+      CreateFlag<int32_t>(
+          "num_runs", &params_,
+          "expected number of runs, see also min_secs, max_secs"),
       CreateFlag<float>(
           "min_secs", &params_,
           "minimum number of seconds to rerun for, potentially making the "
           "actual number of runs to be greater than num_runs"),
+      CreateFlag<float>(
+          "max_secs", &params_,
+          "maximum number of seconds to rerun for, potentially making the "
+          "actual number of runs to be less than num_runs. Note if --max_secs "
+          "is exceeded in the middle of a run, the benchmark will continue to "
+          "the end of the run but will not start the next run."),
       CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"),
       CreateFlag<int32_t>("num_threads", &params_, "number of threads"),
       CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"),
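With these three flags together, num_runs is no longer a strict minimum: min_secs can push the actual run count above it, and max_secs can cut it below, which is why the help text now calls it the "expected" number of runs. For example, with --num_runs=50 --min_secs=1 --max_secs=10, the benchmark performs at least 50 runs if they fit within 10 seconds, keeps running past 50 if a full second has not yet elapsed, and stops starting new runs once 10 seconds have passed.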
@@ -94,6 +102,8 @@ void BenchmarkModel::LogParams() {
                    << "]";
   TFLITE_LOG(INFO) << "Min runs duration (seconds): ["
                    << params_.Get<float>("min_secs") << "]";
+  TFLITE_LOG(INFO) << "Max runs duration (seconds): ["
+                   << params_.Get<float>("max_secs") << "]";
   TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
                    << params_.Get<float>("run_delay") << "]";
   TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
@@ -113,14 +123,17 @@ void BenchmarkModel::PrepareInputData() {}
 void BenchmarkModel::ResetInputsAndOutputs() {}
 
 Stat<int64_t> BenchmarkModel::Run(int min_num_times, float min_secs,
-                                  RunType run_type) {
+                                  float max_secs, RunType run_type) {
   Stat<int64_t> run_stats;
   TFLITE_LOG(INFO) << "Running benchmark for at least " << min_num_times
-                   << " iterations and at least " << min_secs << " seconds";
-  int64_t min_finish_us =
-      profiling::time::NowMicros() + static_cast<int64_t>(min_secs * 1.e6f);
-  for (int run = 0;
-       run < min_num_times || profiling::time::NowMicros() < min_finish_us;
+                   << " iterations and at least " << min_secs << " seconds but"
+                   << " terminate if exceeding " << max_secs << " seconds.";
+  int64_t now_us = profiling::time::NowMicros();
+  int64_t min_finish_us = now_us + static_cast<int64_t>(min_secs * 1.e6f);
+  int64_t max_finish_us = now_us + static_cast<int64_t>(max_secs * 1.e6f);
+
+  for (int run = 0; (run < min_num_times || now_us < min_finish_us) &&
+                    now_us <= max_finish_us;
        run++) {
     ResetInputsAndOutputs();
     listeners_.OnSingleRunStart(run_type);
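To see the termination policy in isolation, here is a minimal, self-contained sketch of the loop above, with a stub workload standing in for a real inference run. NowMicros, RunOnce, and BenchmarkLoop are illustrative names for this sketch only, not part of the TFLite benchmark API:

// Minimal sketch (not TFLite code): mirrors the loop's deadline logic.
#include <chrono>
#include <cstdint>
#include <iostream>
#include <thread>

namespace {

int64_t NowMicros() {
  return std::chrono::duration_cast<std::chrono::microseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

// Stand-in for a single inference run.
void RunOnce() { std::this_thread::sleep_for(std::chrono::milliseconds(10)); }

int BenchmarkLoop(int min_num_times, float min_secs, float max_secs) {
  int64_t now_us = NowMicros();
  const int64_t min_finish_us = now_us + static_cast<int64_t>(min_secs * 1.e6f);
  const int64_t max_finish_us = now_us + static_cast<int64_t>(max_secs * 1.e6f);
  int run = 0;
  // Run until both the count and the minimum duration are satisfied, but
  // never start a new run once the maximum deadline has passed.
  for (; (run < min_num_times || now_us < min_finish_us) &&
         now_us <= max_finish_us;
       ++run) {
    RunOnce();
    now_us = NowMicros();  // Refreshed only between runs, as in the diff.
  }
  return run;
}

}  // namespace

int main() {
  // A tiny max_secs cuts the loop short of the requested 1000 runs.
  std::cout << "completed runs: " << BenchmarkLoop(1000, 10.0f, 0.05f) << "\n";
}

As in the real loop, the max-duration deadline is only consulted between runs, so a run already in flight when max_secs expires still completes; that matches the --max_secs help text above.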
@@ -131,6 +144,7 @@ Stat<int64_t> BenchmarkModel::Run(int min_num_times, float min_secs,
 
     run_stats.UpdateStat(end_us - start_us);
     SleepForSeconds(params_.Get<float>("run_delay"));
+    now_us = profiling::time::NowMicros();
   }
 
   std::stringstream stream;
@@ -163,12 +177,12 @@ void BenchmarkModel::Run() {
   PrepareInputData();
   uint64_t input_bytes = ComputeInputBytes();
   listeners_.OnBenchmarkStart(params_);
-  Stat<int64_t> warmup_time_us =
-      Run(params_.Get<int32_t>("warmup_runs"),
-          params_.Get<float>("warmup_min_secs"), WARMUP);
+  Stat<int64_t> warmup_time_us = Run(params_.Get<int32_t>("warmup_runs"),
+                                     params_.Get<float>("warmup_min_secs"),
+                                     params_.Get<float>("max_secs"), WARMUP);
   Stat<int64_t> inference_time_us =
       Run(params_.Get<int32_t>("num_runs"), params_.Get<float>("min_secs"),
-          REGULAR);
+          params_.Get<float>("max_secs"), REGULAR);
   listeners_.OnBenchmarkEnd(
       {startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
 }
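Note that max_secs is applied per call to Run: the warmup phase and the measured phase each compute their own deadline on entry, so a single benchmark session can take up to roughly twice max_secs of wall time before accounting for per-run overhead.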
@@ -157,7 +157,7 @@ class BenchmarkModel {
   virtual std::vector<Flag> GetFlags();
   virtual uint64_t ComputeInputBytes() = 0;
   virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
-                                        RunType run_type);
+                                        float max_secs, RunType run_type);
   // Prepares input data for benchmark. This can be used to initialize input
   // data that has non-trivial cost.
   virtual void PrepareInputData();
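Because Run is virtual, this signature change is source-incompatible: any subclass overriding the three-argument Run must adopt the new parameter list. An override declared with the old signature no longer matches and, without the override specifier, would silently become a separate, never-called function.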
@@ -31,10 +31,11 @@ namespace tflite {
 namespace benchmark {
 namespace {
 
-BenchmarkParams CreateParams() {
+BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs) {
   BenchmarkParams params;
-  params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(2));
-  params.AddParam("min_secs", BenchmarkParam::Create<float>(1.0f));
+  params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(num_runs));
+  params.AddParam("min_secs", BenchmarkParam::Create<float>(min_secs));
+  params.AddParam("max_secs", BenchmarkParam::Create<float>(max_secs));
   params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
   params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
   params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
@@ -52,6 +53,8 @@ BenchmarkParams CreateParams() {
   return params;
 }
 
+BenchmarkParams CreateParams() { return CreateParams(2, 1.0f, 150.0f); }
+
 class TestBenchmark : public BenchmarkTfLiteModel {
  public:
   explicit TestBenchmark(BenchmarkParams params)
@@ -71,6 +74,25 @@ TEST(BenchmarkTest, DoesntCrash) {
   benchmark.Run();
 }
 
+class MaxDurationWorksTestListener : public BenchmarkListener {
+  void OnBenchmarkEnd(const BenchmarkResults& results) override {
+    const int64_t num_actual_runs = results.inference_time_us().count();
+    TFLITE_LOG(INFO) << "number of actual runs: " << num_actual_runs;
+    EXPECT_GE(num_actual_runs, 1);
+    EXPECT_LT(num_actual_runs, 100000000);
+  }
+};
+
+TEST(BenchmarkTest, MaxDurationWorks) {
+  ASSERT_THAT(g_model_path, testing::NotNull());
+  BenchmarkTfLiteModel benchmark(CreateParams(100000000 /* num_runs */,
+                                              1000000.0f /* min_secs */,
+                                              0.001f /* max_secs */));
+  MaxDurationWorksTestListener listener;
+  benchmark.AddListener(&listener);
+  benchmark.Run();
+}
+
 TEST(BenchmarkTest, ParametersArePopulatedWhenInputShapeIsNotSpecified) {
   ASSERT_THAT(g_model_path, testing::NotNull());
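The MaxDurationWorks test drives the limit from the opposite direction: num_runs (100000000) and min_secs (1000000) are set far beyond reach while max_secs is only 1 ms, so only the max-duration check can end the loop. EXPECT_GE(num_actual_runs, 1) also pins down the boundary behavior: the deadline is checked before each run starts, so at least one run always completes even when max_secs is tiny.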