Added a unit test to test benchmark_model tool to run w/ cmdline args.
PiperOrigin-RevId: 306385040 Change-Id: I7eb056476ca7c67b84c75ecfd540ac1fb02c5c9c
This commit is contained in:
parent
0e357fd45e
commit
005ad3dda3
@ -101,6 +101,7 @@ cc_test(
|
|||||||
":delegate_provider_hdr",
|
":delegate_provider_hdr",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/testing:util",
|
"//tensorflow/lite/testing:util",
|
||||||
"//tensorflow/lite/tools:command_line_flags",
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
"//tensorflow/lite/tools:logging",
|
"//tensorflow/lite/tools:logging",
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -22,6 +23,7 @@ limitations under the License.
|
|||||||
#include "absl/algorithm/algorithm.h"
|
#include "absl/algorithm/algorithm.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "tensorflow/lite/testing/util.h"
|
#include "tensorflow/lite/testing/util.h"
|
||||||
@ -344,6 +346,56 @@ TEST(BenchmarkTest, DoesntCrashWithExplicitInputValueFilesStringModel) {
|
|||||||
CheckInputTensorValue(input_tensor, 2, string_value_2);
|
CheckInputTensorValue(input_tensor, 2, string_value_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ScopedCommandlineArgs {
|
||||||
|
public:
|
||||||
|
explicit ScopedCommandlineArgs(const std::vector<std::string>& actual_args) {
|
||||||
|
argc_ = actual_args.size() + 1;
|
||||||
|
argv_ = new char*[argc_];
|
||||||
|
const std::string program_name = "benchmark_model";
|
||||||
|
int buffer_size = program_name.length() + 1;
|
||||||
|
for (const auto& arg : actual_args) buffer_size += arg.length() + 1;
|
||||||
|
buffer_ = new char[buffer_size];
|
||||||
|
auto next_start = program_name.copy(buffer_, program_name.length());
|
||||||
|
buffer_[next_start++] = '\0';
|
||||||
|
argv_[0] = buffer_;
|
||||||
|
for (int i = 0; i < actual_args.size(); ++i) {
|
||||||
|
const auto& arg = actual_args[i];
|
||||||
|
argv_[i + 1] = buffer_ + next_start;
|
||||||
|
next_start += arg.copy(argv_[i + 1], arg.length());
|
||||||
|
buffer_[next_start++] = '\0';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
~ScopedCommandlineArgs() {
|
||||||
|
delete[] argv_;
|
||||||
|
delete[] buffer_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int argc() const { return argc_; }
|
||||||
|
|
||||||
|
char** argv() const { return argv_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
char* buffer_; // the buffer for all arguments.
|
||||||
|
int argc_;
|
||||||
|
char** argv_; // Each char* element points to each argument.
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(BenchmarkTest, RunWithCorrectFlags) {
|
||||||
|
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
|
||||||
|
TestBenchmark benchmark(CreateFp32Params());
|
||||||
|
ScopedCommandlineArgs scoped_argv({"--num_threads=4"});
|
||||||
|
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
|
||||||
|
EXPECT_EQ(kTfLiteOk, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BenchmarkTest, RunWithWrongFlags) {
|
||||||
|
ASSERT_THAT(g_fp32_model_path, testing::NotNull());
|
||||||
|
TestBenchmark benchmark(CreateFp32Params());
|
||||||
|
ScopedCommandlineArgs scoped_argv({"--num_threads=str"});
|
||||||
|
auto status = benchmark.Run(scoped_argv.argc(), scoped_argv.argv());
|
||||||
|
EXPECT_EQ(kTfLiteError, status);
|
||||||
|
}
|
||||||
|
|
||||||
class MaxDurationWorksTestListener : public BenchmarkListener {
|
class MaxDurationWorksTestListener : public BenchmarkListener {
|
||||||
void OnBenchmarkEnd(const BenchmarkResults& results) override {
|
void OnBenchmarkEnd(const BenchmarkResults& results) override {
|
||||||
const int64_t num_actual_runs = results.inference_time_us().count();
|
const int64_t num_actual_runs = results.inference_time_us().count();
|
||||||
|
Loading…
Reference in New Issue
Block a user