Support to reuse the same set of commandline flags that initialize tflite delegates in the benchmark tool for evaluation-related tools via the delegate-provider registerar, and apply this support to image_classification, objection_detection and inference_diff eval tools.
PiperOrigin-RevId: 305849016 Change-Id: Ie5f35b1bf4fa38e10f3ac7bf9825efacd41988d4
This commit is contained in:
parent
77c0bb4322
commit
58afbd9803
@ -70,6 +70,10 @@ cc_library(
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/lite/tools:command_line_flags",
|
||||
"//tensorflow/lite/tools:tool_params",
|
||||
"//tensorflow/lite/tools/benchmark:delegate_provider_hdr",
|
||||
"//tensorflow/lite/tools/benchmark:tflite_execution_providers",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
],
|
||||
)
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
namespace {
|
||||
@ -36,7 +38,6 @@ TfliteInferenceParams::Delegate ParseStringToDelegateType(
|
||||
TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
|
||||
std::string* error_msg) {
|
||||
const auto type = params.delegate();
|
||||
|
||||
switch (type) {
|
||||
case TfliteInferenceParams::NNAPI: {
|
||||
auto p = CreateNNAPIDelegate();
|
||||
@ -76,5 +77,86 @@ TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
|
||||
}
|
||||
}
|
||||
|
||||
DelegateProviders::DelegateProviders()
|
||||
: delegates_list_(benchmark::GetRegisteredDelegateProviders()),
|
||||
delegates_map_([=]() -> std::unordered_map<std::string, int> {
|
||||
std::unordered_map<std::string, int> delegates_map;
|
||||
for (int i = 0; i < delegates_list_.size(); ++i) {
|
||||
delegates_map[delegates_list_[i]->GetName()] = i;
|
||||
}
|
||||
return delegates_map;
|
||||
}()) {
|
||||
for (const auto& one : delegates_list_) {
|
||||
params_.Merge(one->DefaultParams());
|
||||
}
|
||||
}
|
||||
|
||||
bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
|
||||
std::vector<Flag> flags;
|
||||
for (const auto& one : delegates_list_) {
|
||||
auto one_flags = one->CreateFlags(¶ms_);
|
||||
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||
}
|
||||
return Flags::Parse(argc, argv, flags);
|
||||
}
|
||||
|
||||
TfLiteDelegatePtr DelegateProviders::CreateDelegate(
|
||||
const std::string& name) const {
|
||||
const auto it = delegates_map_.find(name);
|
||||
if (it == delegates_map_.end()) {
|
||||
return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
||||
}
|
||||
return delegates_list_[it->second]->CreateTfLiteDelegate(params_);
|
||||
}
|
||||
|
||||
std::vector<TfLiteDelegatePtr> DelegateProviders::CreateAllDelegates(
|
||||
const tools::ToolParams& params) const {
|
||||
std::vector<TfLiteDelegatePtr> delegates;
|
||||
for (const auto& one : delegates_list_) {
|
||||
auto ptr = one->CreateTfLiteDelegate(params);
|
||||
// It's possible that a delegate of certain type won't be created as
|
||||
// user-specified benchmark params tells not to.
|
||||
if (ptr == nullptr) continue;
|
||||
delegates.emplace_back(std::move(ptr));
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
|
||||
std::vector<TfLiteDelegatePtr> DelegateProviders::CreateAllDelegates(
|
||||
const TfliteInferenceParams& params) const {
|
||||
tools::ToolParams merged_params;
|
||||
merged_params.Merge(params_);
|
||||
|
||||
const auto type = params.delegate();
|
||||
switch (type) {
|
||||
case TfliteInferenceParams::NNAPI:
|
||||
if (merged_params.HasParam("use_nnapi")) {
|
||||
merged_params.Set<bool>("use_nnapi", true);
|
||||
}
|
||||
break;
|
||||
case TfliteInferenceParams::GPU:
|
||||
if (merged_params.HasParam("use_gpu")) {
|
||||
merged_params.Set<bool>("use_gpu", true);
|
||||
}
|
||||
break;
|
||||
case TfliteInferenceParams::HEXAGON:
|
||||
if (merged_params.HasParam("use_hexagon")) {
|
||||
merged_params.Set<bool>("use_hexagon", true);
|
||||
}
|
||||
break;
|
||||
case TfliteInferenceParams::XNNPACK:
|
||||
if (merged_params.HasParam("use_xnnpack")) {
|
||||
merged_params.Set<bool>("use_xnnpack", true);
|
||||
}
|
||||
if (params.has_num_threads()) {
|
||||
merged_params.Set<int32_t>("num_threads", params.num_threads());
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return CreateAllDelegates(merged_params);
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
@ -16,12 +16,60 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/benchmark/delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||
#include "tensorflow/lite/tools/tool_params.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
class DelegateProviders {
|
||||
public:
|
||||
DelegateProviders();
|
||||
|
||||
// Initialize delegate-related parameters from commandline arguments and
|
||||
// returns true if sucessful.
|
||||
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
||||
|
||||
// Get all parameters from all registered delegate providers.
|
||||
const tools::ToolParams& GetAllParams() const { return params_; }
|
||||
|
||||
// Create the a TfLite delegate instance based on the provided delegate
|
||||
// 'name'. If the specified one isn't found, an empty TfLiteDelegatePtr is
|
||||
// returned.
|
||||
TfLiteDelegatePtr CreateDelegate(const std::string& name) const;
|
||||
|
||||
// Create a list of TfLite delegates based on what have been initialized (i.e.
|
||||
// 'params_').
|
||||
std::vector<TfLiteDelegatePtr> CreateAllDelegates() const {
|
||||
return CreateAllDelegates(params_);
|
||||
}
|
||||
|
||||
// Create a list of TfLite delegates based on the given TfliteInferenceParams
|
||||
// 'params' but considering what have been initialized (i.e. 'params_').
|
||||
std::vector<TfLiteDelegatePtr> CreateAllDelegates(
|
||||
const TfliteInferenceParams& params) const;
|
||||
|
||||
private:
|
||||
// Create a list of TfLite delegates based on the provided 'params'.
|
||||
std::vector<TfLiteDelegatePtr> CreateAllDelegates(
|
||||
const tools::ToolParams& params) const;
|
||||
|
||||
// Contain delegate-related parameters that are initialized from command-line
|
||||
// flags.
|
||||
tools::ToolParams params_;
|
||||
|
||||
const benchmark::DelegateProviderList& delegates_list_;
|
||||
// Key is the delegate name, and the value is the index to the
|
||||
// 'delegates_list_'.
|
||||
const std::unordered_map<std::string, int> delegates_map_;
|
||||
};
|
||||
|
||||
// Parse a string 'val' to the corresponding delegate type defined by
|
||||
// TfliteInferenceParams::Delegate.
|
||||
TfliteInferenceParams::Delegate ParseStringToDelegateType(
|
||||
|
@ -39,6 +39,21 @@ TEST(EvaluationDelegateProviderTest, CreateTfLiteDelegate) {
|
||||
EXPECT_TRUE(!CreateTfLiteDelegate(params));
|
||||
}
|
||||
|
||||
TEST(EvaluationDelegateProviderTest, DelegateProvidersParams) {
|
||||
DelegateProviders providers;
|
||||
const auto& params = providers.GetAllParams();
|
||||
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
||||
EXPECT_TRUE(params.HasParam("use_gpu"));
|
||||
|
||||
int argc = 3;
|
||||
const char* argv[] = {"program_name", "--use_gpu=true",
|
||||
"--other_undefined_flag=1"};
|
||||
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
||||
EXPECT_TRUE(params.Get<bool>("use_gpu"));
|
||||
EXPECT_EQ(2, argc);
|
||||
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
@ -147,6 +147,7 @@ cc_library(
|
||||
":tflite_inference_stage",
|
||||
":topk_accuracy_eval_stage",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
@ -163,6 +164,7 @@ cc_library(
|
||||
":tflite_inference_stage",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/core/util:stats_calculator_portable",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
@ -224,6 +226,7 @@ cc_library(
|
||||
":tflite_inference_stage",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
|
@ -29,7 +29,8 @@ namespace {
|
||||
const float kCroppingFraction = 0.875;
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus ImageClassificationStage::Init() {
|
||||
TfLiteStatus ImageClassificationStage::Init(
|
||||
const DelegateProviders* delegate_providers) {
|
||||
// Ensure inference params are provided.
|
||||
if (!config_.specification().has_image_classification_params()) {
|
||||
LOG(ERROR) << "ImageClassificationParams not provided";
|
||||
@ -47,7 +48,8 @@ TfLiteStatus ImageClassificationStage::Init() {
|
||||
*tflite_inference_config.mutable_specification()
|
||||
->mutable_tflite_inference_params() = params.inference_params();
|
||||
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
|
||||
if (inference_stage_->Init() != kTfLiteOk) return kTfLiteError;
|
||||
if (inference_stage_->Init(delegate_providers) != kTfLiteOk)
|
||||
return kTfLiteError;
|
||||
|
||||
// Validate model inputs.
|
||||
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"
|
||||
@ -36,7 +37,8 @@ class ImageClassificationStage : public EvaluationStage {
|
||||
explicit ImageClassificationStage(const EvaluationStageConfig& config)
|
||||
: EvaluationStage(config) {}
|
||||
|
||||
TfLiteStatus Init() override;
|
||||
TfLiteStatus Init() override { return Init(nullptr); }
|
||||
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||
|
||||
TfLiteStatus Run() override;
|
||||
|
||||
|
@ -68,11 +68,12 @@ float CalculateAverageError(T* reference, T* test, int64_t num_elements) {
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus InferenceProfilerStage::Init() {
|
||||
TfLiteStatus InferenceProfilerStage::Init(
|
||||
const DelegateProviders* delegate_providers) {
|
||||
// Initialize TfliteInferenceStage with the user-provided
|
||||
// TfliteInferenceParams.
|
||||
test_stage_.reset(new TfliteInferenceStage(config_));
|
||||
if (test_stage_->Init() != kTfLiteOk) return kTfLiteError;
|
||||
if (test_stage_->Init(delegate_providers) != kTfLiteOk) return kTfLiteError;
|
||||
LOG(INFO) << "Test interpreter has been initialized.";
|
||||
|
||||
// Initialize a reference TfliteInferenceStage that uses the given model &
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/util/stats_calculator.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
|
||||
@ -39,7 +40,8 @@ class InferenceProfilerStage : public EvaluationStage {
|
||||
explicit InferenceProfilerStage(const EvaluationStageConfig& config)
|
||||
: EvaluationStage(config) {}
|
||||
|
||||
TfLiteStatus Init() override;
|
||||
TfLiteStatus Init() override { return Init(nullptr); }
|
||||
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||
|
||||
// New Gaussian random data is used as input for each Run.
|
||||
TfLiteStatus Run() override;
|
||||
|
@ -26,7 +26,8 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
TfLiteStatus ObjectDetectionStage::Init() {
|
||||
TfLiteStatus ObjectDetectionStage::Init(
|
||||
const DelegateProviders* delegate_providers) {
|
||||
// Ensure inference params are provided.
|
||||
if (!config_.specification().has_object_detection_params()) {
|
||||
LOG(ERROR) << "ObjectDetectionParams not provided";
|
||||
@ -48,7 +49,7 @@ TfLiteStatus ObjectDetectionStage::Init() {
|
||||
*tflite_inference_config.mutable_specification()
|
||||
->mutable_tflite_inference_params() = params.inference_params();
|
||||
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
|
||||
TF_LITE_ENSURE_STATUS(inference_stage_->Init());
|
||||
TF_LITE_ENSURE_STATUS(inference_stage_->Init(delegate_providers));
|
||||
|
||||
// Validate model inputs.
|
||||
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
@ -43,7 +44,8 @@ class ObjectDetectionStage : public EvaluationStage {
|
||||
explicit ObjectDetectionStage(const EvaluationStageConfig& config)
|
||||
: EvaluationStage(config) {}
|
||||
|
||||
TfLiteStatus Init() override;
|
||||
TfLiteStatus Init() override { return Init(nullptr); }
|
||||
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||
|
||||
TfLiteStatus Run() override;
|
||||
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/profiling/time.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||
|
||||
@ -71,7 +70,8 @@ TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus TfliteInferenceStage::Init() {
|
||||
TfLiteStatus TfliteInferenceStage::Init(
|
||||
const DelegateProviders* delegate_providers) {
|
||||
if (!config_.specification().has_tflite_inference_params()) {
|
||||
LOG(ERROR) << "TfliteInferenceParams not provided";
|
||||
return kTfLiteError;
|
||||
@ -96,14 +96,19 @@ TfLiteStatus TfliteInferenceStage::Init() {
|
||||
}
|
||||
interpreter_->SetNumThreads(params.num_threads());
|
||||
|
||||
std::string error_message;
|
||||
auto delegate = CreateTfLiteDelegate(params, &error_message);
|
||||
if (delegate) {
|
||||
delegates_.push_back(std::move(delegate));
|
||||
LOG(INFO) << "Successfully created "
|
||||
<< params.Delegate_Name(params.delegate()) << " delegate.";
|
||||
if (!delegate_providers) {
|
||||
std::string error_message;
|
||||
auto delegate = CreateTfLiteDelegate(params, &error_message);
|
||||
if (delegate) {
|
||||
delegates_.push_back(std::move(delegate));
|
||||
LOG(INFO) << "Successfully created "
|
||||
<< params.Delegate_Name(params.delegate()) << " delegate.";
|
||||
} else {
|
||||
LOG(WARNING) << error_message;
|
||||
}
|
||||
} else {
|
||||
LOG(WARNING) << error_message;
|
||||
auto delegates = delegate_providers->CreateAllDelegates(params);
|
||||
for (auto& one : delegates) delegates_.push_back(std::move(one));
|
||||
}
|
||||
|
||||
for (int i = 0; i < delegates_.size(); ++i) {
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
|
||||
@ -41,14 +42,15 @@ class TfliteInferenceStage : public EvaluationStage {
|
||||
explicit TfliteInferenceStage(const EvaluationStageConfig& config)
|
||||
: EvaluationStage(config) {}
|
||||
|
||||
TfLiteStatus Init() override;
|
||||
TfLiteStatus Init() override { return Init(nullptr); }
|
||||
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||
|
||||
TfLiteStatus Run() override;
|
||||
|
||||
// EvaluationStageMetrics.num_runs denotes the number of inferences run.
|
||||
EvaluationStageMetrics LatestMetrics() override;
|
||||
|
||||
~TfliteInferenceStage() {}
|
||||
~TfliteInferenceStage() override {}
|
||||
|
||||
// Call before Run().
|
||||
// This class does not take ownership of raw_input_ptrs.
|
||||
|
@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
const std::vector<std::string>& image_paths,
|
||||
const std::string& ground_truth_proto_file,
|
||||
std::string delegate, std::string output_file_path,
|
||||
int num_interpreter_threads, bool debug_mode) {
|
||||
int num_interpreter_threads, bool debug_mode,
|
||||
const DelegateProviders& delegate_providers) {
|
||||
EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("object_detection");
|
||||
auto* detection_params =
|
||||
@ -74,7 +75,7 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
ObjectDetectionStage eval(eval_config);
|
||||
|
||||
eval.SetAllLabels(model_labels);
|
||||
if (eval.Init() != kTfLiteOk) return false;
|
||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||
|
||||
// Open output file for writing.
|
||||
std::ofstream ofile;
|
||||
@ -156,6 +157,8 @@ int Main(int argc, char* argv[]) {
|
||||
"Must be one of {'nnapi', 'gpu'}"),
|
||||
};
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_paths;
|
||||
@ -170,7 +173,7 @@ int Main(int argc, char* argv[]) {
|
||||
|
||||
if (!EvaluateModel(model_file_path, model_labels, image_paths,
|
||||
ground_truth_proto_file, delegate, output_file_path,
|
||||
num_interpreter_threads, debug_mode)) {
|
||||
num_interpreter_threads, debug_mode, delegate_providers)) {
|
||||
LOG(ERROR) << "Could not evaluate model";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
const std::vector<ImageLabel>& image_labels,
|
||||
const std::vector<std::string>& model_labels,
|
||||
std::string delegate, std::string output_file_path,
|
||||
int num_interpreter_threads) {
|
||||
int num_interpreter_threads,
|
||||
const DelegateProviders& delegate_providers) {
|
||||
EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("image_classification");
|
||||
auto* classification_params = eval_config.mutable_specification()
|
||||
@ -69,7 +70,7 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
ImageClassificationStage eval(eval_config);
|
||||
|
||||
eval.SetAllLabels(model_labels);
|
||||
if (eval.Init() != kTfLiteOk) return false;
|
||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||
|
||||
const int step = image_labels.size() / 100;
|
||||
for (int i = 0; i < image_labels.size(); ++i) {
|
||||
@ -135,6 +136,8 @@ int Main(int argc, char* argv[]) {
|
||||
"Must be one of {'nnapi', 'gpu'}"),
|
||||
};
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||
|
||||
// Process images in filename-sorted order.
|
||||
std::vector<std::string> image_files, ground_truth_image_labels;
|
||||
@ -168,7 +171,8 @@ int Main(int argc, char* argv[]) {
|
||||
}
|
||||
|
||||
if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
|
||||
output_file_path, num_interpreter_threads)) {
|
||||
output_file_path, num_interpreter_threads,
|
||||
delegate_providers)) {
|
||||
LOG(ERROR) << "Could not evaluate model";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
@ -36,7 +36,8 @@ constexpr char kDelegateFlag[] = "delegate";
|
||||
bool EvaluateModel(const std::string& model_file_path,
|
||||
const std::string& delegate, int num_runs,
|
||||
const std::string& output_file_path,
|
||||
int num_interpreter_threads) {
|
||||
int num_interpreter_threads,
|
||||
const DelegateProviders& delegate_providers) {
|
||||
// Initialize evaluation stage.
|
||||
EvaluationStageConfig eval_config;
|
||||
eval_config.set_name("inference_profiling");
|
||||
@ -54,7 +55,7 @@ bool EvaluateModel(const std::string& model_file_path,
|
||||
return false;
|
||||
}
|
||||
InferenceProfilerStage eval(eval_config);
|
||||
if (eval.Init() != kTfLiteOk) return false;
|
||||
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||
|
||||
// Run inference & check diff for specified number of runs.
|
||||
for (int i = 0; i < num_runs; ++i) {
|
||||
@ -94,8 +95,10 @@ int Main(int argc, char* argv[]) {
|
||||
};
|
||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||
|
||||
DelegateProviders delegate_providers;
|
||||
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||
if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
|
||||
num_interpreter_threads)) {
|
||||
num_interpreter_threads, delegate_providers)) {
|
||||
LOG(ERROR) << "Could not evaluate model!";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user