Added a unit test to test benchmark_model tool to run w/ cmdline args.
PiperOrigin-RevId: 306385040 Change-Id: I7eb056476ca7c67b84c75ecfd540ac1fb02c5c9c
This commit is contained in:
parent
0e357fd45e
commit
005ad3dda3
@ -101,6 +101,7 @@ cc_test(
|
||||
":delegate_provider_hdr",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -22,6 +23,7 @@ limitations under the License.
|
||||
#include "absl/algorithm/algorithm.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
@ -344,6 +346,56 @@ TEST(BenchmarkTest, DoesntCrashWithExplicitInputValueFilesStringModel) {
|
||||
CheckInputTensorValue(input_tensor, 2, string_value_2);
|
||||
}
|
||||
|
||||
class ScopedCommandlineArgs {
|
||||
public:
|
||||
explicit ScopedCommandlineArgs(const std::vector<std::string>& actual_args) {
|
||||
argc_ = actual_args.size() + 1;
|
||||
argv_ = new char*[argc_];
|
||||
const std::string program_name = "benchmark_model";
|
||||
int buffer_size = program_name.length() + 1;
|
||||
for (const auto& arg : actual_args) buffer_size += arg.length() + 1;
|
||||
buffer_ = new char[buffer_size];
|
||||
auto next_start = program_name.copy(buffer_, program_name.length());
|
||||
buffer_[next_start++] = '\0';
|
||||
argv_[0] = buffer_;
|
||||
for (int i = 0; i < actual_args.size(); ++i) {
|
||||
const auto& arg = actual_args[i];
|
||||
argv_[i + 1] = buffer_ + next_start;
|
||||
next_start += arg.copy(argv_[i + 1], arg.length());
|
||||
buffer_[next_start++] = '\0';
|
||||
}
|
||||
}
|
||||
~ScopedCommandlineArgs() {
|
||||
delete[] argv_;
|
||||
delete[] buffer_;
|
||||
}
|
||||
|
||||
int argc() const { return argc_; }
|
||||
|
||||
char** argv() const { return argv_; }
|
||||
|
||||
private:
|
||||
char* buffer_; // the buffer for all arguments.
|
||||
int argc_;
|
||||
char** argv_; // Each char* element points to each argument.
|
||||
};
|
||||
|
||||
TEST(BenchmarkTest, RunWithCorrectFlags) {
|
||||
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
|
||||
TestBenchmark benchmark(CreateFp32Params());
|
||||
ScopedCommandlineArgs scoped_argv({"--num_threads=4"});
|
||||
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
|
||||
EXPECT_EQ(kTfLiteOk, status);
|
||||
}
|
||||
|
||||
TEST(BenchmarkTest, RunWithWrongFlags) {
|
||||
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
|
||||
TestBenchmark benchmark(CreateFp32Params());
|
||||
ScopedCommandlineArgs scoped_argv({"--num_threads=str"});
|
||||
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
|
||||
EXPECT_EQ(kTfLiteError, status);
|
||||
}
|
||||
|
||||
class MaxDurationWorksTestListener : public BenchmarkListener {
|
||||
void OnBenchmarkEnd(const BenchmarkResults& results) override {
|
||||
const int64_t num_actual_runs = results.inference_time_us().count();
|
||||
|
Loading…
Reference in New Issue
Block a user