From 58afbd980364e158e463e4b1f51a2031b6131449 Mon Sep 17 00:00:00 2001
From: Chao Mei
Date: Fri, 10 Apr 2020 02:42:19 -0700
Subject: [PATCH] Support reusing the same set of command-line flags that
 initialize TFLite delegates in the benchmark tool for evaluation-related
 tools via the delegate-provider registrar, and apply this support to the
 image_classification, object_detection and inference_diff eval tools.

PiperOrigin-RevId: 305849016
Change-Id: Ie5f35b1bf4fa38e10f3ac7bf9825efacd41988d4
---
 tensorflow/lite/tools/evaluation/BUILD        |  4 +
 .../evaluation_delegate_provider.cc           | 84 ++++++++++++++++++-
 .../evaluation/evaluation_delegate_provider.h | 48 +++++++++++
 .../evaluation_delegate_provider_test.cc      | 15 ++++
 tensorflow/lite/tools/evaluation/stages/BUILD |  3 +
 .../stages/image_classification_stage.cc      |  6 +-
 .../stages/image_classification_stage.h       |  4 +-
 .../stages/inference_profiler_stage.cc        |  5 +-
 .../stages/inference_profiler_stage.h         |  4 +-
 .../stages/object_detection_stage.cc          |  5 +-
 .../stages/object_detection_stage.h           |  4 +-
 .../stages/tflite_inference_stage.cc          | 23 +++--
 .../stages/tflite_inference_stage.h           |  6 +-
 .../tasks/coco_object_detection/run_eval.cc   |  9 +-
 .../imagenet_image_classification/run_eval.cc | 10 ++-
 .../tasks/inference_diff/run_eval.cc          |  9 +-
 16 files changed, 209 insertions(+), 30 deletions(-)

diff --git a/tensorflow/lite/tools/evaluation/BUILD b/tensorflow/lite/tools/evaluation/BUILD
index caa4a637766..bf21e553b1f 100644
--- a/tensorflow/lite/tools/evaluation/BUILD
+++ b/tensorflow/lite/tools/evaluation/BUILD
@@ -70,6 +70,10 @@ cc_library(
     copts = tflite_copts(),
     deps = [
         ":utils",
+        "//tensorflow/lite/tools:command_line_flags",
+        "//tensorflow/lite/tools:tool_params",
+        "//tensorflow/lite/tools/benchmark:delegate_provider_hdr",
+        "//tensorflow/lite/tools/benchmark:tflite_execution_providers",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
     ],
 )
diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc
index 925cae8d140..91a0aea4711 100644
--- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc
+++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" +#include "tensorflow/lite/tools/command_line_flags.h" + namespace tflite { namespace evaluation { namespace { @@ -36,7 +38,6 @@ TfliteInferenceParams::Delegate ParseStringToDelegateType( TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params, std::string* error_msg) { const auto type = params.delegate(); - switch (type) { case TfliteInferenceParams::NNAPI: { auto p = CreateNNAPIDelegate(); @@ -76,5 +77,86 @@ TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params, } } +DelegateProviders::DelegateProviders() + : delegates_list_(benchmark::GetRegisteredDelegateProviders()), + delegates_map_([=]() -> std::unordered_map { + std::unordered_map delegates_map; + for (int i = 0; i < delegates_list_.size(); ++i) { + delegates_map[delegates_list_[i]->GetName()] = i; + } + return delegates_map; + }()) { + for (const auto& one : delegates_list_) { + params_.Merge(one->DefaultParams()); + } +} + +bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) { + std::vector flags; + for (const auto& one : delegates_list_) { + auto one_flags = one->CreateFlags(¶ms_); + flags.insert(flags.end(), one_flags.begin(), one_flags.end()); + } + return Flags::Parse(argc, argv, flags); +} + +TfLiteDelegatePtr DelegateProviders::CreateDelegate( + const std::string& name) const { + const auto it = delegates_map_.find(name); + if (it == delegates_map_.end()) { + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); + } + return delegates_list_[it->second]->CreateTfLiteDelegate(params_); +} + +std::vector DelegateProviders::CreateAllDelegates( + const tools::ToolParams& params) const { + std::vector delegates; + for (const auto& one : delegates_list_) { + auto ptr = one->CreateTfLiteDelegate(params); + // It's possible that a delegate of certain type won't be created as + // user-specified benchmark params tells not to. + if (ptr == nullptr) continue; + delegates.emplace_back(std::move(ptr)); + } + return delegates; +} + +std::vector DelegateProviders::CreateAllDelegates( + const TfliteInferenceParams& params) const { + tools::ToolParams merged_params; + merged_params.Merge(params_); + + const auto type = params.delegate(); + switch (type) { + case TfliteInferenceParams::NNAPI: + if (merged_params.HasParam("use_nnapi")) { + merged_params.Set("use_nnapi", true); + } + break; + case TfliteInferenceParams::GPU: + if (merged_params.HasParam("use_gpu")) { + merged_params.Set("use_gpu", true); + } + break; + case TfliteInferenceParams::HEXAGON: + if (merged_params.HasParam("use_hexagon")) { + merged_params.Set("use_hexagon", true); + } + break; + case TfliteInferenceParams::XNNPACK: + if (merged_params.HasParam("use_xnnpack")) { + merged_params.Set("use_xnnpack", true); + } + if (params.has_num_threads()) { + merged_params.Set("num_threads", params.num_threads()); + } + break; + default: + break; + } + return CreateAllDelegates(merged_params); +} + } // namespace evaluation } // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h index 7f093295be2..5c5c4bb1021 100644 --- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h +++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h @@ -16,12 +16,60 @@ limitations under the License. 
 #ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
 #define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
 
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/lite/tools/benchmark/delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 #include "tensorflow/lite/tools/evaluation/utils.h"
+#include "tensorflow/lite/tools/tool_params.h"
 
 namespace tflite {
 namespace evaluation {
 
+class DelegateProviders {
+ public:
+  DelegateProviders();
+
+  // Initializes delegate-related parameters from the command-line arguments
+  // and returns true if successful.
+  bool InitFromCmdlineArgs(int* argc, const char** argv);
+
+  // Gets all parameters from all registered delegate providers.
+  const tools::ToolParams& GetAllParams() const { return params_; }
+
+  // Creates a TfLite delegate instance based on the provided delegate
+  // 'name'. If the specified one isn't found, an empty TfLiteDelegatePtr is
+  // returned.
+  TfLiteDelegatePtr CreateDelegate(const std::string& name) const;
+
+  // Creates a list of TfLite delegates based on what has been initialized
+  // (i.e. 'params_').
+  std::vector<TfLiteDelegatePtr> CreateAllDelegates() const {
+    return CreateAllDelegates(params_);
+  }
+
+  // Creates a list of TfLite delegates based on the given
+  // TfliteInferenceParams 'params' while also considering what has been
+  // initialized (i.e. 'params_').
+  std::vector<TfLiteDelegatePtr> CreateAllDelegates(
+      const TfliteInferenceParams& params) const;
+
+ private:
+  // Creates a list of TfLite delegates based on the provided 'params'.
+  std::vector<TfLiteDelegatePtr> CreateAllDelegates(
+      const tools::ToolParams& params) const;
+
+  // Contains delegate-related parameters that are initialized from
+  // command-line flags.
+  tools::ToolParams params_;
+
+  const benchmark::DelegateProviderList& delegates_list_;
+  // Key is the delegate name, and the value is the index into
+  // 'delegates_list_'.
+  const std::unordered_map<std::string, int> delegates_map_;
+};
+
 // Parse a string 'val' to the corresponding delegate type defined by
 // TfliteInferenceParams::Delegate.
 TfliteInferenceParams::Delegate ParseStringToDelegateType(
diff --git a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc
index 1b984206eb6..1d7870eaed0 100644
--- a/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc
+++ b/tensorflow/lite/tools/evaluation/evaluation_delegate_provider_test.cc
@@ -39,6 +39,21 @@ TEST(EvaluationDelegateProviderTest, CreateTfLiteDelegate) {
   EXPECT_TRUE(!CreateTfLiteDelegate(params));
 }
 
+TEST(EvaluationDelegateProviderTest, DelegateProvidersParams) {
+  DelegateProviders providers;
+  const auto& params = providers.GetAllParams();
+  EXPECT_TRUE(params.HasParam("use_nnapi"));
+  EXPECT_TRUE(params.HasParam("use_gpu"));
+
+  int argc = 3;
+  const char* argv[] = {"program_name", "--use_gpu=true",
+                        "--other_undefined_flag=1"};
+  EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
+  EXPECT_TRUE(params.Get<bool>("use_gpu"));
+  EXPECT_EQ(2, argc);
+  EXPECT_EQ("--other_undefined_flag=1", argv[1]);
+}
+
 }  // namespace
 }  // namespace evaluation
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD
index ea3341f4e75..1650151bfa7 100644
--- a/tensorflow/lite/tools/evaluation/stages/BUILD
+++ b/tensorflow/lite/tools/evaluation/stages/BUILD
@@ -147,6 +147,7 @@ cc_library(
         ":tflite_inference_stage",
         ":topk_accuracy_eval_stage",
         "//tensorflow/core:tflite_portable_logging",
+        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
         "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation:utils",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
@@ -163,6 +164,7 @@ cc_library(
         ":tflite_inference_stage",
         "//tensorflow/core:tflite_portable_logging",
         "//tensorflow/core/util:stats_calculator_portable",
+        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
        "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
@@ -224,6 +226,7 @@ cc_library(
         ":tflite_inference_stage",
         "//tensorflow/core:tflite_portable_logging",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
         "//tensorflow/lite/tools/evaluation:evaluation_stage",
         "//tensorflow/lite/tools/evaluation:utils",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc
index c9f8f832441..212e148cbc7 100644
--- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc
+++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc
@@ -29,7 +29,8 @@ namespace {
 const float kCroppingFraction = 0.875;
 }  // namespace
 
-TfLiteStatus ImageClassificationStage::Init() {
+TfLiteStatus ImageClassificationStage::Init(
+    const DelegateProviders* delegate_providers) {
   // Ensure inference params are provided.
   if (!config_.specification().has_image_classification_params()) {
     LOG(ERROR) << "ImageClassificationParams not provided";
@@ -47,7 +48,8 @@ TfLiteStatus ImageClassificationStage::Init() {
   *tflite_inference_config.mutable_specification()
        ->mutable_tflite_inference_params() = params.inference_params();
   inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
-  if (inference_stage_->Init() != kTfLiteOk) return kTfLiteError;
+  if (inference_stage_->Init(delegate_providers) != kTfLiteOk)
+    return kTfLiteError;
 
   // Validate model inputs.
   const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h
index a74a5979f35..c3f8eb8f900 100644
--- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h
+++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 #include <string>
 
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 #include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"
@@ -36,7 +37,8 @@ class ImageClassificationStage : public EvaluationStage {
   explicit ImageClassificationStage(const EvaluationStageConfig& config)
       : EvaluationStage(config) {}
 
-  TfLiteStatus Init() override;
+  TfLiteStatus Init() override { return Init(nullptr); }
+  TfLiteStatus Init(const DelegateProviders* delegate_providers);
 
   TfLiteStatus Run() override;
diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc
index cfafc1e9214..8a3759a17c2 100644
--- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc
+++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc
@@ -68,11 +68,12 @@ float CalculateAverageError(T* reference, T* test, int64_t num_elements) {
 
 }  // namespace
 
-TfLiteStatus InferenceProfilerStage::Init() {
+TfLiteStatus InferenceProfilerStage::Init(
+    const DelegateProviders* delegate_providers) {
   // Initialize TfliteInferenceStage with the user-provided
   // TfliteInferenceParams.
   test_stage_.reset(new TfliteInferenceStage(config_));
-  if (test_stage_->Init() != kTfLiteOk) return kTfLiteError;
+  if (test_stage_->Init(delegate_providers) != kTfLiteOk) return kTfLiteError;
   LOG(INFO) << "Test interpreter has been initialized.";
 
   // Initialize a reference TfliteInferenceStage that uses the given model &
diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h
index e5fd37943e9..d10c7beb088 100644
--- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h
+++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h
@@ -22,6 +22,7 @@ limitations under the License.
#include #include "tensorflow/core/util/stats_calculator.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h" @@ -39,7 +40,8 @@ class InferenceProfilerStage : public EvaluationStage { explicit InferenceProfilerStage(const EvaluationStageConfig& config) : EvaluationStage(config) {} - TfLiteStatus Init() override; + TfLiteStatus Init() override { return Init(nullptr); } + TfLiteStatus Init(const DelegateProviders* delegate_providers); // New Gaussian random data is used as input for each Run. TfLiteStatus Run() override; diff --git a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc index f7821d81894..1ed8db2076c 100644 --- a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc @@ -26,7 +26,8 @@ limitations under the License. namespace tflite { namespace evaluation { -TfLiteStatus ObjectDetectionStage::Init() { +TfLiteStatus ObjectDetectionStage::Init( + const DelegateProviders* delegate_providers) { // Ensure inference params are provided. if (!config_.specification().has_object_detection_params()) { LOG(ERROR) << "ObjectDetectionParams not provided"; @@ -48,7 +49,7 @@ TfLiteStatus ObjectDetectionStage::Init() { *tflite_inference_config.mutable_specification() ->mutable_tflite_inference_params() = params.inference_params(); inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config)); - TF_LITE_ENSURE_STATUS(inference_stage_->Init()); + TF_LITE_ENSURE_STATUS(inference_stage_->Init(delegate_providers)); // Validate model inputs. const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo(); diff --git a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h index cc0c935bba9..1489d853c34 100644 --- a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" @@ -43,7 +44,8 @@ class ObjectDetectionStage : public EvaluationStage { explicit ObjectDetectionStage(const EvaluationStageConfig& config) : EvaluationStage(config) {} - TfLiteStatus Init() override; + TfLiteStatus Init() override { return Init(nullptr); } + TfLiteStatus Init(const DelegateProviders* delegate_providers); TfLiteStatus Run() override; diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc index cbf41de0e03..bb01ca2cc4d 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/profiling/time.h" -#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/utils.h" @@ -71,7 +70,8 @@ TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate( return kTfLiteOk; } -TfLiteStatus TfliteInferenceStage::Init() { +TfLiteStatus TfliteInferenceStage::Init( + const DelegateProviders* delegate_providers) { if (!config_.specification().has_tflite_inference_params()) { LOG(ERROR) << "TfliteInferenceParams not provided"; return kTfLiteError; @@ -96,14 +96,19 @@ TfLiteStatus TfliteInferenceStage::Init() { } interpreter_->SetNumThreads(params.num_threads()); - std::string error_message; - auto delegate = CreateTfLiteDelegate(params, &error_message); - if (delegate) { - delegates_.push_back(std::move(delegate)); - LOG(INFO) << "Successfully created " - << params.Delegate_Name(params.delegate()) << " delegate."; + if (!delegate_providers) { + std::string error_message; + auto delegate = CreateTfLiteDelegate(params, &error_message); + if (delegate) { + delegates_.push_back(std::move(delegate)); + LOG(INFO) << "Successfully created " + << params.Delegate_Name(params.delegate()) << " delegate."; + } else { + LOG(WARNING) << error_message; + } } else { - LOG(WARNING) << error_message; + auto delegates = delegate_providers->CreateAllDelegates(params); + for (auto& one : delegates) delegates_.push_back(std::move(one)); } for (int i = 0; i < delegates_.size(); ++i) { diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h index c27462fdcd6..a8a319fcd16 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" @@ -41,14 +42,15 @@ class TfliteInferenceStage : public EvaluationStage { explicit TfliteInferenceStage(const EvaluationStageConfig& config) : EvaluationStage(config) {} - TfLiteStatus Init() override; + TfLiteStatus Init() override { return Init(nullptr); } + TfLiteStatus Init(const DelegateProviders* delegate_providers); TfLiteStatus Run() override; // EvaluationStageMetrics.num_runs denotes the number of inferences run. EvaluationStageMetrics LatestMetrics() override; - ~TfliteInferenceStage() {} + ~TfliteInferenceStage() override {} // Call before Run(). // This class does not take ownership of raw_input_ptrs. 
diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc
index 39b5082accb..3b5fc08ab84 100644
--- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc
+++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc
@@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
                    const std::vector<std::string>& image_paths,
                    const std::string& ground_truth_proto_file,
                    std::string delegate, std::string output_file_path,
-                   int num_interpreter_threads, bool debug_mode) {
+                   int num_interpreter_threads, bool debug_mode,
+                   const DelegateProviders& delegate_providers) {
   EvaluationStageConfig eval_config;
   eval_config.set_name("object_detection");
   auto* detection_params =
@@ -74,7 +75,7 @@ bool EvaluateModel(const std::string& model_file_path,
 
   ObjectDetectionStage eval(eval_config);
   eval.SetAllLabels(model_labels);
-  if (eval.Init() != kTfLiteOk) return false;
+  if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
 
   // Open output file for writing.
   std::ofstream ofile;
@@ -156,6 +157,8 @@ int Main(int argc, char* argv[]) {
                                "Must be one of {'nnapi', 'gpu'}"),
   };
   tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
+  DelegateProviders delegate_providers;
+  delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
 
   // Process images in filename-sorted order.
   std::vector<std::string> image_paths;
@@ -170,7 +173,7 @@ int Main(int argc, char* argv[]) {
 
   if (!EvaluateModel(model_file_path, model_labels, image_paths,
                      ground_truth_proto_file, delegate, output_file_path,
-                     num_interpreter_threads, debug_mode)) {
+                     num_interpreter_threads, debug_mode, delegate_providers)) {
     LOG(ERROR) << "Could not evaluate model";
     return EXIT_FAILURE;
   }
diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc
index 5268039c500..cd6c6cfb3c4 100644
--- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc
+++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc
@@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
                    const std::vector<ImageLabel>& image_labels,
                    const std::vector<std::string>& model_labels,
                    std::string delegate, std::string output_file_path,
-                   int num_interpreter_threads) {
+                   int num_interpreter_threads,
+                   const DelegateProviders& delegate_providers) {
   EvaluationStageConfig eval_config;
   eval_config.set_name("image_classification");
   auto* classification_params = eval_config.mutable_specification()
@@ -69,7 +70,7 @@ bool EvaluateModel(const std::string& model_file_path,
 
   ImageClassificationStage eval(eval_config);
   eval.SetAllLabels(model_labels);
-  if (eval.Init() != kTfLiteOk) return false;
+  if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
 
   const int step = image_labels.size() / 100;
   for (int i = 0; i < image_labels.size(); ++i) {
@@ -135,6 +136,8 @@ int Main(int argc, char* argv[]) {
                                "Must be one of {'nnapi', 'gpu'}"),
   };
   tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
+  DelegateProviders delegate_providers;
+  delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
 
   // Process images in filename-sorted order.
   std::vector<std::string> image_files, ground_truth_image_labels;
@@ -168,7 +171,8 @@ int Main(int argc, char* argv[]) {
   }
 
   if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
-                     output_file_path, num_interpreter_threads)) {
+                     output_file_path, num_interpreter_threads,
+                     delegate_providers)) {
     LOG(ERROR) << "Could not evaluate model";
     return EXIT_FAILURE;
   }
diff --git a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc
index cdd83d52d6f..6a7c6e8fc42 100644
--- a/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc
+++ b/tensorflow/lite/tools/evaluation/tasks/inference_diff/run_eval.cc
@@ -36,7 +36,8 @@ constexpr char kDelegateFlag[] = "delegate";
 bool EvaluateModel(const std::string& model_file_path,
                    const std::string& delegate, int num_runs,
                    const std::string& output_file_path,
-                   int num_interpreter_threads) {
+                   int num_interpreter_threads,
+                   const DelegateProviders& delegate_providers) {
   // Initialize evaluation stage.
   EvaluationStageConfig eval_config;
   eval_config.set_name("inference_profiling");
@@ -54,7 +55,7 @@ bool EvaluateModel(const std::string& model_file_path,
     return false;
   }
   InferenceProfilerStage eval(eval_config);
-  if (eval.Init() != kTfLiteOk) return false;
+  if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
 
   // Run inference & check diff for specified number of runs.
   for (int i = 0; i < num_runs; ++i) {
@@ -94,8 +95,10 @@ int Main(int argc, char* argv[]) {
   };
   tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
 
+  DelegateProviders delegate_providers;
+  delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
   if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
                     num_interpreter_threads, delegate_providers)) {
     LOG(ERROR) << "Could not evaluate model!";
     return EXIT_FAILURE;
   }
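
With this change in place, an evaluation task binary parses its own flags first and then hands the remaining command-line arguments to DelegateProviders, so every delegate provider registered in the benchmark tool can consume its own flags (e.g. --use_gpu=true); flags no provider recognizes stay in argv, as the DelegateProvidersParams test above demonstrates. Below is a minimal sketch of that wiring for a hypothetical new task binary: the --model_file flag and the stage config are illustrative placeholders, and only the DelegateProviders calls and the Init(const DelegateProviders*) overload come from this patch.

#include <string>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"

int main(int argc, char* argv[]) {
  // Tool-specific flags are parsed first; --model_file is a placeholder.
  std::string model_file_path;
  std::vector<tflite::Flag> flag_list = {
      tflite::Flag::CreateFlag("model_file", &model_file_path,
                               "Path to the .tflite model."),
  };
  tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);

  // The registered delegate providers then consume whatever delegate flags
  // remain on the command line (e.g. --use_gpu=true, --use_nnapi=true).
  tflite::evaluation::DelegateProviders delegate_providers;
  delegate_providers.InitFromCmdlineArgs(&argc,
                                         const_cast<const char**>(argv));

  // A stage Init() that receives the providers creates its delegates from
  // the parsed params instead of the legacy CreateTfLiteDelegate() path.
  tflite::evaluation::EvaluationStageConfig config;  // placeholder config
  config.set_name("tflite_inference");
  config.mutable_specification()
      ->mutable_tflite_inference_params()
      ->set_model_file_path(model_file_path);
  tflite::evaluation::TfliteInferenceStage stage(config);
  return stage.Init(&delegate_providers) == kTfLiteOk ? 0 : 1;
}

The same two-step parse is what the three task binaries above do: a null DelegateProviders pointer preserves the old single-delegate behavior, while passing the providers routes delegate creation through the shared benchmark-tool flags.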