Support reusing the same set of command-line flags that initialize TFLite delegates in the benchmark tool for evaluation-related tools via the delegate-provider registrar, and apply this support to the image_classification, object_detection and inference_diff eval tools.
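In practice an eval tool parses its own flags first and then hands whatever remains on the command line to the registrar-backed providers. A minimal sketch of the resulting flow (the surrounding main() is illustrative, not part of this change):

#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"

int main(int argc, char* argv[]) {
  tflite::evaluation::DelegateProviders delegate_providers;
  // Consumes recognized delegate flags (e.g. --use_gpu=true) from argv and
  // leaves unrecognized flags in place.
  delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
  // Evaluation stages later receive &delegate_providers via their Init()
  // overloads, as the tool changes below show.
  return 0;
}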

PiperOrigin-RevId: 305849016
Change-Id: Ie5f35b1bf4fa38e10f3ac7bf9825efacd41988d4
Chao Mei 2020-04-10 02:42:19 -07:00 committed by TensorFlower Gardener
parent 77c0bb4322
commit 58afbd9803
16 changed files with 209 additions and 30 deletions

View File

@@ -70,6 +70,10 @@ cc_library(
copts = tflite_copts(),
deps = [
":utils",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools:tool_params",
"//tensorflow/lite/tools/benchmark:delegate_provider_hdr",
"//tensorflow/lite/tools/benchmark:tflite_execution_providers",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
],
)

View File

@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/command_line_flags.h"
namespace tflite {
namespace evaluation {
namespace {
@@ -36,7 +38,6 @@ TfliteInferenceParams::Delegate ParseStringToDelegateType(
TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
std::string* error_msg) {
const auto type = params.delegate();
switch (type) {
case TfliteInferenceParams::NNAPI: {
auto p = CreateNNAPIDelegate();
@@ -76,5 +77,86 @@ TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
}
}
DelegateProviders::DelegateProviders()
: delegates_list_(benchmark::GetRegisteredDelegateProviders()),
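// The name->index map is built once via an immediately-invoked lambda so
// that the const member can be initialized in place.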
delegates_map_([=]() -> std::unordered_map<std::string, int> {
std::unordered_map<std::string, int> delegates_map;
for (int i = 0; i < delegates_list_.size(); ++i) {
delegates_map[delegates_list_[i]->GetName()] = i;
}
return delegates_map;
}()) {
for (const auto& one : delegates_list_) {
params_.Merge(one->DefaultParams());
}
}
bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
std::vector<Flag> flags;
for (const auto& one : delegates_list_) {
auto one_flags = one->CreateFlags(&params_);
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
}
return Flags::Parse(argc, argv, flags);
}
TfLiteDelegatePtr DelegateProviders::CreateDelegate(
const std::string& name) const {
const auto it = delegates_map_.find(name);
if (it == delegates_map_.end()) {
return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}
return delegates_list_[it->second]->CreateTfLiteDelegate(params_);
}
std::vector<TfLiteDelegatePtr> DelegateProviders::CreateAllDelegates(
const tools::ToolParams& params) const {
std::vector<TfLiteDelegatePtr> delegates;
for (const auto& one : delegates_list_) {
auto ptr = one->CreateTfLiteDelegate(params);
// It's possible that a delegate of a certain type won't be created because
// the user-specified benchmark params say not to.
if (ptr == nullptr) continue;
delegates.emplace_back(std::move(ptr));
}
return delegates;
}
std::vector<TfLiteDelegatePtr> DelegateProviders::CreateAllDelegates(
const TfliteInferenceParams& params) const {
tools::ToolParams merged_params;
merged_params.Merge(params_);
const auto type = params.delegate();
switch (type) {
case TfliteInferenceParams::NNAPI:
if (merged_params.HasParam("use_nnapi")) {
merged_params.Set<bool>("use_nnapi", true);
}
break;
case TfliteInferenceParams::GPU:
if (merged_params.HasParam("use_gpu")) {
merged_params.Set<bool>("use_gpu", true);
}
break;
case TfliteInferenceParams::HEXAGON:
if (merged_params.HasParam("use_hexagon")) {
merged_params.Set<bool>("use_hexagon", true);
}
break;
case TfliteInferenceParams::XNNPACK:
if (merged_params.HasParam("use_xnnpack")) {
merged_params.Set<bool>("use_xnnpack", true);
}
if (params.has_num_threads()) {
merged_params.Set<int32_t>("num_threads", params.num_threads());
}
break;
default:
break;
}
return CreateAllDelegates(merged_params);
}
} // namespace evaluation
} // namespace tflite
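With the enum-to-flag mapping above, a caller only has to set the delegate type on the proto. A hypothetical call site (set_delegate() is the generated setter from evaluation_stages.proto):

TfliteInferenceParams inference_params;
inference_params.set_delegate(TfliteInferenceParams::GPU);
// Flips "use_gpu" in a copy of the initialized params_, then asks every
// registered provider for its delegate; providers whose params say "off"
// return nullptr and are skipped.
auto delegates = delegate_providers.CreateAllDelegates(inference_params);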

View File

@@ -16,12 +16,60 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/lite/tools/benchmark/delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
#include "tensorflow/lite/tools/tool_params.h"
namespace tflite {
namespace evaluation {
class DelegateProviders {
public:
DelegateProviders();
// Initialize delegate-related parameters from command-line arguments and
// returns true if successful.
bool InitFromCmdlineArgs(int* argc, const char** argv);
// Get all parameters from all registered delegate providers.
const tools::ToolParams& GetAllParams() const { return params_; }
// Create a TfLite delegate instance based on the provided delegate
// 'name'. If the specified one isn't found, an empty TfLiteDelegatePtr is
// returned.
TfLiteDelegatePtr CreateDelegate(const std::string& name) const;
// Create a list of TfLite delegates based on what has been initialized (i.e.
// 'params_').
std::vector<TfLiteDelegatePtr> CreateAllDelegates() const {
return CreateAllDelegates(params_);
}
// Create a list of TfLite delegates based on the given TfliteInferenceParams
// 'params', but taking into account what has been initialized (i.e. 'params_').
std::vector<TfLiteDelegatePtr> CreateAllDelegates(
const TfliteInferenceParams& params) const;
private:
// Create a list of TfLite delegates based on the provided 'params'.
std::vector<TfLiteDelegatePtr> CreateAllDelegates(
const tools::ToolParams& params) const;
// Contains delegate-related parameters that are initialized from command-line
// flags.
tools::ToolParams params_;
const benchmark::DelegateProviderList& delegates_list_;
// Key is the delegate name, and the value is the index into
// 'delegates_list_'.
const std::unordered_map<std::string, int> delegates_map_;
};
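Looking up a single delegate goes through delegates_map_. A small usage sketch ("GPU" is assumed here to be the GPU provider's registered name; the key must match a provider's GetName()):

DelegateProviders providers;
TfLiteDelegatePtr delegate = providers.CreateDelegate("GPU");
if (delegate == nullptr) {
  // Unknown name: CreateDelegate returned an empty TfLiteDelegatePtr.
}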
// Parse a string 'val' to the corresponding delegate type defined by
// TfliteInferenceParams::Delegate.
TfliteInferenceParams::Delegate ParseStringToDelegateType(

View File

@@ -39,6 +39,21 @@ TEST(EvaluationDelegateProviderTest, CreateTfLiteDelegate) {
EXPECT_TRUE(!CreateTfLiteDelegate(params));
}
TEST(EvaluationDelegateProviderTest, DelegateProvidersParams) {
DelegateProviders providers;
const auto& params = providers.GetAllParams();
EXPECT_TRUE(params.HasParam("use_nnapi"));
EXPECT_TRUE(params.HasParam("use_gpu"));
int argc = 3;
const char* argv[] = {"program_name", "--use_gpu=true",
"--other_undefined_flag=1"};
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
EXPECT_TRUE(params.Get<bool>("use_gpu"));
EXPECT_EQ(2, argc);
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
}
} // namespace
} // namespace evaluation
} // namespace tflite

View File

@@ -147,6 +147,7 @@ cc_library(
":tflite_inference_stage",
":topk_accuracy_eval_stage",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
"//tensorflow/lite/tools/evaluation:evaluation_stage",
"//tensorflow/lite/tools/evaluation:utils",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
@@ -163,6 +164,7 @@ cc_library(
":tflite_inference_stage",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/core/util:stats_calculator_portable",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
"//tensorflow/lite/tools/evaluation:evaluation_stage",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
@@ -224,6 +226,7 @@ cc_library(
":tflite_inference_stage",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/c:common",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
"//tensorflow/lite/tools/evaluation:evaluation_stage",
"//tensorflow/lite/tools/evaluation:utils",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",

View File

@@ -29,7 +29,8 @@ namespace {
const float kCroppingFraction = 0.875;
} // namespace
TfLiteStatus ImageClassificationStage::Init() {
TfLiteStatus ImageClassificationStage::Init(
const DelegateProviders* delegate_providers) {
// Ensure inference params are provided.
if (!config_.specification().has_image_classification_params()) {
LOG(ERROR) << "ImageClassificationParams not provided";
@@ -47,7 +48,8 @@ TfLiteStatus ImageClassificationStage::Init() {
*tflite_inference_config.mutable_specification()
->mutable_tflite_inference_params() = params.inference_params();
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
if (inference_stage_->Init() != kTfLiteOk) return kTfLiteError;
if (inference_stage_->Init(delegate_providers) != kTfLiteOk)
return kTfLiteError;
// Validate model inputs.
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"
@@ -36,7 +37,8 @@ class ImageClassificationStage : public EvaluationStage {
explicit ImageClassificationStage(const EvaluationStageConfig& config)
: EvaluationStage(config) {}
TfLiteStatus Init() override;
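// The parameterless override below preserves the EvaluationStage interface
// and simply forwards a null DelegateProviders pointer.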
TfLiteStatus Init() override { return Init(nullptr); }
TfLiteStatus Init(const DelegateProviders* delegate_providers);
TfLiteStatus Run() override;

View File

@@ -68,11 +68,12 @@ float CalculateAverageError(T* reference, T* test, int64_t num_elements) {
} // namespace
TfLiteStatus InferenceProfilerStage::Init() {
TfLiteStatus InferenceProfilerStage::Init(
const DelegateProviders* delegate_providers) {
// Initialize TfliteInferenceStage with the user-provided
// TfliteInferenceParams.
test_stage_.reset(new TfliteInferenceStage(config_));
if (test_stage_->Init() != kTfLiteOk) return kTfLiteError;
if (test_stage_->Init(delegate_providers) != kTfLiteOk) return kTfLiteError;
LOG(INFO) << "Test interpreter has been initialized.";
// Initialize a reference TfliteInferenceStage that uses the given model &

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
@@ -39,7 +40,8 @@ class InferenceProfilerStage : public EvaluationStage {
explicit InferenceProfilerStage(const EvaluationStageConfig& config)
: EvaluationStage(config) {}
TfLiteStatus Init() override;
TfLiteStatus Init() override { return Init(nullptr); }
TfLiteStatus Init(const DelegateProviders* delegate_providers);
// New Gaussian random data is used as input for each Run.
TfLiteStatus Run() override;

View File

@@ -26,7 +26,8 @@ limitations under the License.
namespace tflite {
namespace evaluation {
TfLiteStatus ObjectDetectionStage::Init() {
TfLiteStatus ObjectDetectionStage::Init(
const DelegateProviders* delegate_providers) {
// Ensure inference params are provided.
if (!config_.specification().has_object_detection_params()) {
LOG(ERROR) << "ObjectDetectionParams not provided";
@@ -48,7 +49,7 @@ TfLiteStatus ObjectDetectionStage::Init() {
*tflite_inference_config.mutable_specification()
->mutable_tflite_inference_params() = params.inference_params();
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
TF_LITE_ENSURE_STATUS(inference_stage_->Init());
TF_LITE_ENSURE_STATUS(inference_stage_->Init(delegate_providers));
// Validate model inputs.
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
@@ -43,7 +44,8 @@ class ObjectDetectionStage : public EvaluationStage {
explicit ObjectDetectionStage(const EvaluationStageConfig& config)
: EvaluationStage(config) {}
TfLiteStatus Init() override;
TfLiteStatus Init() override { return Init(nullptr); }
TfLiteStatus Init(const DelegateProviders* delegate_providers);
TfLiteStatus Run() override;

View File

@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/time.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
@@ -71,7 +70,8 @@ TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
return kTfLiteOk;
}
TfLiteStatus TfliteInferenceStage::Init() {
TfLiteStatus TfliteInferenceStage::Init(
const DelegateProviders* delegate_providers) {
if (!config_.specification().has_tflite_inference_params()) {
LOG(ERROR) << "TfliteInferenceParams not provided";
return kTfLiteError;
@@ -96,14 +96,19 @@ TfLiteStatus TfliteInferenceStage::Init() {
}
interpreter_->SetNumThreads(params.num_threads());
std::string error_message;
auto delegate = CreateTfLiteDelegate(params, &error_message);
if (delegate) {
delegates_.push_back(std::move(delegate));
LOG(INFO) << "Successfully created "
<< params.Delegate_Name(params.delegate()) << " delegate.";
if (!delegate_providers) {
std::string error_message;
auto delegate = CreateTfLiteDelegate(params, &error_message);
if (delegate) {
delegates_.push_back(std::move(delegate));
LOG(INFO) << "Successfully created "
<< params.Delegate_Name(params.delegate()) << " delegate.";
} else {
LOG(WARNING) << error_message;
}
} else {
LOG(WARNING) << error_message;
auto delegates = delegate_providers->CreateAllDelegates(params);
for (auto& one : delegates) delegates_.push_back(std::move(one));
}
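  // At this point delegates_ holds every delegate to apply, whether it came
  // from the legacy single-delegate path or from the providers.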
for (int i = 0; i < delegates_.size(); ++i) {

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
@@ -41,14 +42,15 @@ class TfliteInferenceStage : public EvaluationStage {
explicit TfliteInferenceStage(const EvaluationStageConfig& config)
: EvaluationStage(config) {}
TfLiteStatus Init() override;
TfLiteStatus Init() override { return Init(nullptr); }
TfLiteStatus Init(const DelegateProviders* delegate_providers);
TfLiteStatus Run() override;
// EvaluationStageMetrics.num_runs denotes the number of inferences run.
EvaluationStageMetrics LatestMetrics() override;
~TfliteInferenceStage() {}
~TfliteInferenceStage() override {}
// Call before Run().
// This class does not take ownership of raw_input_ptrs.

View File

@@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
const std::vector<std::string>& image_paths,
const std::string& ground_truth_proto_file,
std::string delegate, std::string output_file_path,
int num_interpreter_threads, bool debug_mode) {
int num_interpreter_threads, bool debug_mode,
const DelegateProviders& delegate_providers) {
EvaluationStageConfig eval_config;
eval_config.set_name("object_detection");
auto* detection_params =
@@ -74,7 +75,7 @@ bool EvaluateModel(const std::string& model_file_path,
ObjectDetectionStage eval(eval_config);
eval.SetAllLabels(model_labels);
if (eval.Init() != kTfLiteOk) return false;
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
// Open output file for writing.
std::ofstream ofile;
@@ -156,6 +157,8 @@ int Main(int argc, char* argv[]) {
"Must be one of {'nnapi', 'gpu'}"),
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
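  // Delegate-provider flags (e.g. --use_gpu=true) are consumed from whatever
  // remains in argv after the tool-specific flags above have been parsed.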
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
// Process images in filename-sorted order.
std::vector<std::string> image_paths;
@@ -170,7 +173,7 @@ int Main(int argc, char* argv[]) {
if (!EvaluateModel(model_file_path, model_labels, image_paths,
ground_truth_proto_file, delegate, output_file_path,
num_interpreter_threads, debug_mode)) {
num_interpreter_threads, debug_mode, delegate_providers)) {
LOG(ERROR) << "Could not evaluate model";
return EXIT_FAILURE;
}

View File

@@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
const std::vector<ImageLabel>& image_labels,
const std::vector<std::string>& model_labels,
std::string delegate, std::string output_file_path,
int num_interpreter_threads) {
int num_interpreter_threads,
const DelegateProviders& delegate_providers) {
EvaluationStageConfig eval_config;
eval_config.set_name("image_classification");
auto* classification_params = eval_config.mutable_specification()
@@ -69,7 +70,7 @@ bool EvaluateModel(const std::string& model_file_path,
ImageClassificationStage eval(eval_config);
eval.SetAllLabels(model_labels);
if (eval.Init() != kTfLiteOk) return false;
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
const int step = image_labels.size() / 100;
for (int i = 0; i < image_labels.size(); ++i) {
@@ -135,6 +136,8 @@ int Main(int argc, char* argv[]) {
"Must be one of {'nnapi', 'gpu'}"),
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
// Process images in filename-sorted order.
std::vector<std::string> image_files, ground_truth_image_labels;
@@ -168,7 +171,8 @@ int Main(int argc, char* argv[]) {
}
if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
output_file_path, num_interpreter_threads)) {
output_file_path, num_interpreter_threads,
delegate_providers)) {
LOG(ERROR) << "Could not evaluate model";
return EXIT_FAILURE;
}

View File

@@ -36,7 +36,8 @@ constexpr char kDelegateFlag[] = "delegate";
bool EvaluateModel(const std::string& model_file_path,
const std::string& delegate, int num_runs,
const std::string& output_file_path,
int num_interpreter_threads) {
int num_interpreter_threads,
const DelegateProviders& delegate_providers) {
// Initialize evaluation stage.
EvaluationStageConfig eval_config;
eval_config.set_name("inference_profiling");
@@ -54,7 +55,7 @@ bool EvaluateModel(const std::string& model_file_path,
return false;
}
InferenceProfilerStage eval(eval_config);
if (eval.Init() != kTfLiteOk) return false;
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
// Run inference & check diff for specified number of runs.
for (int i = 0; i < num_runs; ++i) {
@@ -94,8 +95,10 @@ int Main(int argc, char* argv[]) {
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
num_interpreter_threads)) {
num_interpreter_threads, delegate_providers)) {
LOG(ERROR) << "Could not evaluate model!";
return EXIT_FAILURE;
}