Added a unit test to test benchmark_model tool to run w/ cmdline args.

PiperOrigin-RevId: 306385040
Change-Id: I7eb056476ca7c67b84c75ecfd540ac1fb02c5c9c
This commit is contained in:
Chao Mei 2020-04-14 00:01:37 -07:00 committed by TensorFlower Gardener
parent 0e357fd45e
commit 005ad3dda3
2 changed files with 53 additions and 0 deletions

View File

@ -101,6 +101,7 @@ cc_test(
":delegate_provider_hdr",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/testing:util",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools:logging",

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
@ -22,6 +23,7 @@ limitations under the License.
#include "absl/algorithm/algorithm.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/util.h"
@ -344,6 +346,56 @@ TEST(BenchmarkTest, DoesntCrashWithExplicitInputValueFilesStringModel) {
CheckInputTensorValue(input_tensor, 2, string_value_2);
}
class ScopedCommandlineArgs {
public:
explicit ScopedCommandlineArgs(const std::vector<std::string>& actual_args) {
argc_ = actual_args.size() + 1;
argv_ = new char*[argc_];
const std::string program_name = "benchmark_model";
int buffer_size = program_name.length() + 1;
for (const auto& arg : actual_args) buffer_size += arg.length() + 1;
buffer_ = new char[buffer_size];
auto next_start = program_name.copy(buffer_, program_name.length());
buffer_[next_start++] = '\0';
argv_[0] = buffer_;
for (int i = 0; i < actual_args.size(); ++i) {
const auto& arg = actual_args[i];
argv_[i + 1] = buffer_ + next_start;
next_start += arg.copy(argv_[i + 1], arg.length());
buffer_[next_start++] = '\0';
}
}
~ScopedCommandlineArgs() {
delete[] argv_;
delete[] buffer_;
}
int argc() const { return argc_; }
char** argv() const { return argv_; }
private:
char* buffer_; // the buffer for all arguments.
int argc_;
char** argv_; // Each char* element points to each argument.
};
TEST(BenchmarkTest, RunWithCorrectFlags) {
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
TestBenchmark benchmark(CreateFp32Params());
ScopedCommandlineArgs scoped_argv({"--num_threads=4"});
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
EXPECT_EQ(kTfLiteOk, status);
}
TEST(BenchmarkTest, RunWithWrongFlags) {
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
TestBenchmark benchmark(CreateFp32Params());
ScopedCommandlineArgs scoped_argv({"--num_threads=str"});
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
EXPECT_EQ(kTfLiteError, status);
}
class MaxDurationWorksTestListener : public BenchmarkListener {
void OnBenchmarkEnd(const BenchmarkResults& results) override {
const int64_t num_actual_runs = results.inference_time_us().count();