Support to reuse the same set of commandline flags that initialize tflite delegates in the benchmark tool for evaluation-related tools via the delegate-provider registerar, and apply this support to image_classification, objection_detection and inference_diff eval tools.
PiperOrigin-RevId: 305849016 Change-Id: Ie5f35b1bf4fa38e10f3ac7bf9825efacd41988d4
This commit is contained in:
parent
77c0bb4322
commit
58afbd9803
@ -70,6 +70,10 @@ cc_library(
|
|||||||
copts = tflite_copts(),
|
copts = tflite_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
":utils",
|
":utils",
|
||||||
|
"//tensorflow/lite/tools:command_line_flags",
|
||||||
|
"//tensorflow/lite/tools:tool_params",
|
||||||
|
"//tensorflow/lite/tools/benchmark:delegate_provider_hdr",
|
||||||
|
"//tensorflow/lite/tools/benchmark:tflite_execution_providers",
|
||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace evaluation {
|
namespace evaluation {
|
||||||
namespace {
|
namespace {
|
||||||
@ -36,7 +38,6 @@ TfliteInferenceParams::Delegate ParseStringToDelegateType(
|
|||||||
TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
|
TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
|
||||||
std::string* error_msg) {
|
std::string* error_msg) {
|
||||||
const auto type = params.delegate();
|
const auto type = params.delegate();
|
||||||
|
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case TfliteInferenceParams::NNAPI: {
|
case TfliteInferenceParams::NNAPI: {
|
||||||
auto p = CreateNNAPIDelegate();
|
auto p = CreateNNAPIDelegate();
|
||||||
@ -76,5 +77,86 @@ TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DelegateProviders::DelegateProviders()
|
||||||
|
: delegates_list_(benchmark::GetRegisteredDelegateProviders()),
|
||||||
|
delegates_map_([=]() -> std::unordered_map<std::string, int> {
|
||||||
|
std::unordered_map<std::string, int> delegates_map;
|
||||||
|
for (int i = 0; i < delegates_list_.size(); ++i) {
|
||||||
|
delegates_map[delegates_list_[i]->GetName()] = i;
|
||||||
|
}
|
||||||
|
return delegates_map;
|
||||||
|
}()) {
|
||||||
|
for (const auto& one : delegates_list_) {
|
||||||
|
params_.Merge(one->DefaultParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
|
||||||
|
std::vector<Flag> flags;
|
||||||
|
for (const auto& one : delegates_list_) {
|
||||||
|
auto one_flags = one->CreateFlags(¶ms_);
|
||||||
|
flags.insert(flags.end(), one_flags.begin(), one_flags.end());
|
||||||
|
}
|
||||||
|
return Flags::Parse(argc, argv, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteDelegatePtr DelegateProviders::CreateDelegate(
|
||||||
|
const std::string& name) const {
|
||||||
|
const auto it = delegates_map_.find(name);
|
||||||
|
if (it == delegates_map_.end()) {
|
||||||
|
return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
||||||
|
}
|
||||||
|
return delegates_list_[it->second]->CreateTfLiteDelegate(params_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<TfLiteDelegatePtr> DelegateProviders::CreateAllDelegates(
|
||||||
|
const tools::ToolParams& params) const {
|
||||||
|
std::vector<TfLiteDelegatePtr> delegates;
|
||||||
|
for (const auto& one : delegates_list_) {
|
||||||
|
auto ptr = one->CreateTfLiteDelegate(params);
|
||||||
|
// It's possible that a delegate of certain type won't be created as
|
||||||
|
// user-specified benchmark params tells not to.
|
||||||
|
if (ptr == nullptr) continue;
|
||||||
|
delegates.emplace_back(std::move(ptr));
|
||||||
|
}
|
||||||
|
return delegates;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<TfLiteDelegatePtr> DelegateProviders::CreateAllDelegates(
|
||||||
|
const TfliteInferenceParams& params) const {
|
||||||
|
tools::ToolParams merged_params;
|
||||||
|
merged_params.Merge(params_);
|
||||||
|
|
||||||
|
const auto type = params.delegate();
|
||||||
|
switch (type) {
|
||||||
|
case TfliteInferenceParams::NNAPI:
|
||||||
|
if (merged_params.HasParam("use_nnapi")) {
|
||||||
|
merged_params.Set<bool>("use_nnapi", true);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case TfliteInferenceParams::GPU:
|
||||||
|
if (merged_params.HasParam("use_gpu")) {
|
||||||
|
merged_params.Set<bool>("use_gpu", true);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case TfliteInferenceParams::HEXAGON:
|
||||||
|
if (merged_params.HasParam("use_hexagon")) {
|
||||||
|
merged_params.Set<bool>("use_hexagon", true);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case TfliteInferenceParams::XNNPACK:
|
||||||
|
if (merged_params.HasParam("use_xnnpack")) {
|
||||||
|
merged_params.Set<bool>("use_xnnpack", true);
|
||||||
|
}
|
||||||
|
if (params.has_num_threads()) {
|
||||||
|
merged_params.Set<int32_t>("num_threads", params.num_threads());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return CreateAllDelegates(merged_params);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace evaluation
|
} // namespace evaluation
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -16,12 +16,60 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
|
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
|
||||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
|
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_DELEGATE_PROVIDER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/benchmark/delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||||
|
#include "tensorflow/lite/tools/tool_params.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace evaluation {
|
namespace evaluation {
|
||||||
|
|
||||||
|
class DelegateProviders {
|
||||||
|
public:
|
||||||
|
DelegateProviders();
|
||||||
|
|
||||||
|
// Initialize delegate-related parameters from commandline arguments and
|
||||||
|
// returns true if sucessful.
|
||||||
|
bool InitFromCmdlineArgs(int* argc, const char** argv);
|
||||||
|
|
||||||
|
// Get all parameters from all registered delegate providers.
|
||||||
|
const tools::ToolParams& GetAllParams() const { return params_; }
|
||||||
|
|
||||||
|
// Create the a TfLite delegate instance based on the provided delegate
|
||||||
|
// 'name'. If the specified one isn't found, an empty TfLiteDelegatePtr is
|
||||||
|
// returned.
|
||||||
|
TfLiteDelegatePtr CreateDelegate(const std::string& name) const;
|
||||||
|
|
||||||
|
// Create a list of TfLite delegates based on what have been initialized (i.e.
|
||||||
|
// 'params_').
|
||||||
|
std::vector<TfLiteDelegatePtr> CreateAllDelegates() const {
|
||||||
|
return CreateAllDelegates(params_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a list of TfLite delegates based on the given TfliteInferenceParams
|
||||||
|
// 'params' but considering what have been initialized (i.e. 'params_').
|
||||||
|
std::vector<TfLiteDelegatePtr> CreateAllDelegates(
|
||||||
|
const TfliteInferenceParams& params) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Create a list of TfLite delegates based on the provided 'params'.
|
||||||
|
std::vector<TfLiteDelegatePtr> CreateAllDelegates(
|
||||||
|
const tools::ToolParams& params) const;
|
||||||
|
|
||||||
|
// Contain delegate-related parameters that are initialized from command-line
|
||||||
|
// flags.
|
||||||
|
tools::ToolParams params_;
|
||||||
|
|
||||||
|
const benchmark::DelegateProviderList& delegates_list_;
|
||||||
|
// Key is the delegate name, and the value is the index to the
|
||||||
|
// 'delegates_list_'.
|
||||||
|
const std::unordered_map<std::string, int> delegates_map_;
|
||||||
|
};
|
||||||
|
|
||||||
// Parse a string 'val' to the corresponding delegate type defined by
|
// Parse a string 'val' to the corresponding delegate type defined by
|
||||||
// TfliteInferenceParams::Delegate.
|
// TfliteInferenceParams::Delegate.
|
||||||
TfliteInferenceParams::Delegate ParseStringToDelegateType(
|
TfliteInferenceParams::Delegate ParseStringToDelegateType(
|
||||||
|
@ -39,6 +39,21 @@ TEST(EvaluationDelegateProviderTest, CreateTfLiteDelegate) {
|
|||||||
EXPECT_TRUE(!CreateTfLiteDelegate(params));
|
EXPECT_TRUE(!CreateTfLiteDelegate(params));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(EvaluationDelegateProviderTest, DelegateProvidersParams) {
|
||||||
|
DelegateProviders providers;
|
||||||
|
const auto& params = providers.GetAllParams();
|
||||||
|
EXPECT_TRUE(params.HasParam("use_nnapi"));
|
||||||
|
EXPECT_TRUE(params.HasParam("use_gpu"));
|
||||||
|
|
||||||
|
int argc = 3;
|
||||||
|
const char* argv[] = {"program_name", "--use_gpu=true",
|
||||||
|
"--other_undefined_flag=1"};
|
||||||
|
EXPECT_TRUE(providers.InitFromCmdlineArgs(&argc, argv));
|
||||||
|
EXPECT_TRUE(params.Get<bool>("use_gpu"));
|
||||||
|
EXPECT_EQ(2, argc);
|
||||||
|
EXPECT_EQ("--other_undefined_flag=1", argv[1]);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace evaluation
|
} // namespace evaluation
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -147,6 +147,7 @@ cc_library(
|
|||||||
":tflite_inference_stage",
|
":tflite_inference_stage",
|
||||||
":topk_accuracy_eval_stage",
|
":topk_accuracy_eval_stage",
|
||||||
"//tensorflow/core:tflite_portable_logging",
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
|
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||||
"//tensorflow/lite/tools/evaluation:utils",
|
"//tensorflow/lite/tools/evaluation:utils",
|
||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
@ -163,6 +164,7 @@ cc_library(
|
|||||||
":tflite_inference_stage",
|
":tflite_inference_stage",
|
||||||
"//tensorflow/core:tflite_portable_logging",
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
"//tensorflow/core/util:stats_calculator_portable",
|
"//tensorflow/core/util:stats_calculator_portable",
|
||||||
|
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||||
@ -224,6 +226,7 @@ cc_library(
|
|||||||
":tflite_inference_stage",
|
":tflite_inference_stage",
|
||||||
"//tensorflow/core:tflite_portable_logging",
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
|
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
|
||||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||||
"//tensorflow/lite/tools/evaluation:utils",
|
"//tensorflow/lite/tools/evaluation:utils",
|
||||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||||
|
@ -29,7 +29,8 @@ namespace {
|
|||||||
const float kCroppingFraction = 0.875;
|
const float kCroppingFraction = 0.875;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TfLiteStatus ImageClassificationStage::Init() {
|
TfLiteStatus ImageClassificationStage::Init(
|
||||||
|
const DelegateProviders* delegate_providers) {
|
||||||
// Ensure inference params are provided.
|
// Ensure inference params are provided.
|
||||||
if (!config_.specification().has_image_classification_params()) {
|
if (!config_.specification().has_image_classification_params()) {
|
||||||
LOG(ERROR) << "ImageClassificationParams not provided";
|
LOG(ERROR) << "ImageClassificationParams not provided";
|
||||||
@ -47,7 +48,8 @@ TfLiteStatus ImageClassificationStage::Init() {
|
|||||||
*tflite_inference_config.mutable_specification()
|
*tflite_inference_config.mutable_specification()
|
||||||
->mutable_tflite_inference_params() = params.inference_params();
|
->mutable_tflite_inference_params() = params.inference_params();
|
||||||
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
|
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
|
||||||
if (inference_stage_->Init() != kTfLiteOk) return kTfLiteError;
|
if (inference_stage_->Init(delegate_providers) != kTfLiteOk)
|
||||||
|
return kTfLiteError;
|
||||||
|
|
||||||
// Validate model inputs.
|
// Validate model inputs.
|
||||||
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
|
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"
|
#include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"
|
||||||
@ -36,7 +37,8 @@ class ImageClassificationStage : public EvaluationStage {
|
|||||||
explicit ImageClassificationStage(const EvaluationStageConfig& config)
|
explicit ImageClassificationStage(const EvaluationStageConfig& config)
|
||||||
: EvaluationStage(config) {}
|
: EvaluationStage(config) {}
|
||||||
|
|
||||||
TfLiteStatus Init() override;
|
TfLiteStatus Init() override { return Init(nullptr); }
|
||||||
|
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||||
|
|
||||||
TfLiteStatus Run() override;
|
TfLiteStatus Run() override;
|
||||||
|
|
||||||
|
@ -68,11 +68,12 @@ float CalculateAverageError(T* reference, T* test, int64_t num_elements) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TfLiteStatus InferenceProfilerStage::Init() {
|
TfLiteStatus InferenceProfilerStage::Init(
|
||||||
|
const DelegateProviders* delegate_providers) {
|
||||||
// Initialize TfliteInferenceStage with the user-provided
|
// Initialize TfliteInferenceStage with the user-provided
|
||||||
// TfliteInferenceParams.
|
// TfliteInferenceParams.
|
||||||
test_stage_.reset(new TfliteInferenceStage(config_));
|
test_stage_.reset(new TfliteInferenceStage(config_));
|
||||||
if (test_stage_->Init() != kTfLiteOk) return kTfLiteError;
|
if (test_stage_->Init(delegate_providers) != kTfLiteOk) return kTfLiteError;
|
||||||
LOG(INFO) << "Test interpreter has been initialized.";
|
LOG(INFO) << "Test interpreter has been initialized.";
|
||||||
|
|
||||||
// Initialize a reference TfliteInferenceStage that uses the given model &
|
// Initialize a reference TfliteInferenceStage that uses the given model &
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/util/stats_calculator.h"
|
#include "tensorflow/core/util/stats_calculator.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
|
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
|
||||||
@ -39,7 +40,8 @@ class InferenceProfilerStage : public EvaluationStage {
|
|||||||
explicit InferenceProfilerStage(const EvaluationStageConfig& config)
|
explicit InferenceProfilerStage(const EvaluationStageConfig& config)
|
||||||
: EvaluationStage(config) {}
|
: EvaluationStage(config) {}
|
||||||
|
|
||||||
TfLiteStatus Init() override;
|
TfLiteStatus Init() override { return Init(nullptr); }
|
||||||
|
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||||
|
|
||||||
// New Gaussian random data is used as input for each Run.
|
// New Gaussian random data is used as input for each Run.
|
||||||
TfLiteStatus Run() override;
|
TfLiteStatus Run() override;
|
||||||
|
@ -26,7 +26,8 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace evaluation {
|
namespace evaluation {
|
||||||
|
|
||||||
TfLiteStatus ObjectDetectionStage::Init() {
|
TfLiteStatus ObjectDetectionStage::Init(
|
||||||
|
const DelegateProviders* delegate_providers) {
|
||||||
// Ensure inference params are provided.
|
// Ensure inference params are provided.
|
||||||
if (!config_.specification().has_object_detection_params()) {
|
if (!config_.specification().has_object_detection_params()) {
|
||||||
LOG(ERROR) << "ObjectDetectionParams not provided";
|
LOG(ERROR) << "ObjectDetectionParams not provided";
|
||||||
@ -48,7 +49,7 @@ TfLiteStatus ObjectDetectionStage::Init() {
|
|||||||
*tflite_inference_config.mutable_specification()
|
*tflite_inference_config.mutable_specification()
|
||||||
->mutable_tflite_inference_params() = params.inference_params();
|
->mutable_tflite_inference_params() = params.inference_params();
|
||||||
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
|
inference_stage_.reset(new TfliteInferenceStage(tflite_inference_config));
|
||||||
TF_LITE_ENSURE_STATUS(inference_stage_->Init());
|
TF_LITE_ENSURE_STATUS(inference_stage_->Init(delegate_providers));
|
||||||
|
|
||||||
// Validate model inputs.
|
// Validate model inputs.
|
||||||
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
|
const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||||
@ -43,7 +44,8 @@ class ObjectDetectionStage : public EvaluationStage {
|
|||||||
explicit ObjectDetectionStage(const EvaluationStageConfig& config)
|
explicit ObjectDetectionStage(const EvaluationStageConfig& config)
|
||||||
: EvaluationStage(config) {}
|
: EvaluationStage(config) {}
|
||||||
|
|
||||||
TfLiteStatus Init() override;
|
TfLiteStatus Init() override { return Init(nullptr); }
|
||||||
|
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||||
|
|
||||||
TfLiteStatus Run() override;
|
TfLiteStatus Run() override;
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/profiling/time.h"
|
#include "tensorflow/lite/profiling/time.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/utils.h"
|
#include "tensorflow/lite/tools/evaluation/utils.h"
|
||||||
|
|
||||||
@ -71,7 +70,8 @@ TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus TfliteInferenceStage::Init() {
|
TfLiteStatus TfliteInferenceStage::Init(
|
||||||
|
const DelegateProviders* delegate_providers) {
|
||||||
if (!config_.specification().has_tflite_inference_params()) {
|
if (!config_.specification().has_tflite_inference_params()) {
|
||||||
LOG(ERROR) << "TfliteInferenceParams not provided";
|
LOG(ERROR) << "TfliteInferenceParams not provided";
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
@ -96,14 +96,19 @@ TfLiteStatus TfliteInferenceStage::Init() {
|
|||||||
}
|
}
|
||||||
interpreter_->SetNumThreads(params.num_threads());
|
interpreter_->SetNumThreads(params.num_threads());
|
||||||
|
|
||||||
std::string error_message;
|
if (!delegate_providers) {
|
||||||
auto delegate = CreateTfLiteDelegate(params, &error_message);
|
std::string error_message;
|
||||||
if (delegate) {
|
auto delegate = CreateTfLiteDelegate(params, &error_message);
|
||||||
delegates_.push_back(std::move(delegate));
|
if (delegate) {
|
||||||
LOG(INFO) << "Successfully created "
|
delegates_.push_back(std::move(delegate));
|
||||||
<< params.Delegate_Name(params.delegate()) << " delegate.";
|
LOG(INFO) << "Successfully created "
|
||||||
|
<< params.Delegate_Name(params.delegate()) << " delegate.";
|
||||||
|
} else {
|
||||||
|
LOG(WARNING) << error_message;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
LOG(WARNING) << error_message;
|
auto delegates = delegate_providers->CreateAllDelegates(params);
|
||||||
|
for (auto& one : delegates) delegates_.push_back(std::move(one));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < delegates_.size(); ++i) {
|
for (int i = 0; i < delegates_.size(); ++i) {
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||||
|
|
||||||
@ -41,14 +42,15 @@ class TfliteInferenceStage : public EvaluationStage {
|
|||||||
explicit TfliteInferenceStage(const EvaluationStageConfig& config)
|
explicit TfliteInferenceStage(const EvaluationStageConfig& config)
|
||||||
: EvaluationStage(config) {}
|
: EvaluationStage(config) {}
|
||||||
|
|
||||||
TfLiteStatus Init() override;
|
TfLiteStatus Init() override { return Init(nullptr); }
|
||||||
|
TfLiteStatus Init(const DelegateProviders* delegate_providers);
|
||||||
|
|
||||||
TfLiteStatus Run() override;
|
TfLiteStatus Run() override;
|
||||||
|
|
||||||
// EvaluationStageMetrics.num_runs denotes the number of inferences run.
|
// EvaluationStageMetrics.num_runs denotes the number of inferences run.
|
||||||
EvaluationStageMetrics LatestMetrics() override;
|
EvaluationStageMetrics LatestMetrics() override;
|
||||||
|
|
||||||
~TfliteInferenceStage() {}
|
~TfliteInferenceStage() override {}
|
||||||
|
|
||||||
// Call before Run().
|
// Call before Run().
|
||||||
// This class does not take ownership of raw_input_ptrs.
|
// This class does not take ownership of raw_input_ptrs.
|
||||||
|
@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
|
|||||||
const std::vector<std::string>& image_paths,
|
const std::vector<std::string>& image_paths,
|
||||||
const std::string& ground_truth_proto_file,
|
const std::string& ground_truth_proto_file,
|
||||||
std::string delegate, std::string output_file_path,
|
std::string delegate, std::string output_file_path,
|
||||||
int num_interpreter_threads, bool debug_mode) {
|
int num_interpreter_threads, bool debug_mode,
|
||||||
|
const DelegateProviders& delegate_providers) {
|
||||||
EvaluationStageConfig eval_config;
|
EvaluationStageConfig eval_config;
|
||||||
eval_config.set_name("object_detection");
|
eval_config.set_name("object_detection");
|
||||||
auto* detection_params =
|
auto* detection_params =
|
||||||
@ -74,7 +75,7 @@ bool EvaluateModel(const std::string& model_file_path,
|
|||||||
ObjectDetectionStage eval(eval_config);
|
ObjectDetectionStage eval(eval_config);
|
||||||
|
|
||||||
eval.SetAllLabels(model_labels);
|
eval.SetAllLabels(model_labels);
|
||||||
if (eval.Init() != kTfLiteOk) return false;
|
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||||
|
|
||||||
// Open output file for writing.
|
// Open output file for writing.
|
||||||
std::ofstream ofile;
|
std::ofstream ofile;
|
||||||
@ -156,6 +157,8 @@ int Main(int argc, char* argv[]) {
|
|||||||
"Must be one of {'nnapi', 'gpu'}"),
|
"Must be one of {'nnapi', 'gpu'}"),
|
||||||
};
|
};
|
||||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||||
|
DelegateProviders delegate_providers;
|
||||||
|
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||||
|
|
||||||
// Process images in filename-sorted order.
|
// Process images in filename-sorted order.
|
||||||
std::vector<std::string> image_paths;
|
std::vector<std::string> image_paths;
|
||||||
@ -170,7 +173,7 @@ int Main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
if (!EvaluateModel(model_file_path, model_labels, image_paths,
|
if (!EvaluateModel(model_file_path, model_labels, image_paths,
|
||||||
ground_truth_proto_file, delegate, output_file_path,
|
ground_truth_proto_file, delegate, output_file_path,
|
||||||
num_interpreter_threads, debug_mode)) {
|
num_interpreter_threads, debug_mode, delegate_providers)) {
|
||||||
LOG(ERROR) << "Could not evaluate model";
|
LOG(ERROR) << "Could not evaluate model";
|
||||||
return EXIT_FAILURE;
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,8 @@ bool EvaluateModel(const std::string& model_file_path,
|
|||||||
const std::vector<ImageLabel>& image_labels,
|
const std::vector<ImageLabel>& image_labels,
|
||||||
const std::vector<std::string>& model_labels,
|
const std::vector<std::string>& model_labels,
|
||||||
std::string delegate, std::string output_file_path,
|
std::string delegate, std::string output_file_path,
|
||||||
int num_interpreter_threads) {
|
int num_interpreter_threads,
|
||||||
|
const DelegateProviders& delegate_providers) {
|
||||||
EvaluationStageConfig eval_config;
|
EvaluationStageConfig eval_config;
|
||||||
eval_config.set_name("image_classification");
|
eval_config.set_name("image_classification");
|
||||||
auto* classification_params = eval_config.mutable_specification()
|
auto* classification_params = eval_config.mutable_specification()
|
||||||
@ -69,7 +70,7 @@ bool EvaluateModel(const std::string& model_file_path,
|
|||||||
ImageClassificationStage eval(eval_config);
|
ImageClassificationStage eval(eval_config);
|
||||||
|
|
||||||
eval.SetAllLabels(model_labels);
|
eval.SetAllLabels(model_labels);
|
||||||
if (eval.Init() != kTfLiteOk) return false;
|
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||||
|
|
||||||
const int step = image_labels.size() / 100;
|
const int step = image_labels.size() / 100;
|
||||||
for (int i = 0; i < image_labels.size(); ++i) {
|
for (int i = 0; i < image_labels.size(); ++i) {
|
||||||
@ -135,6 +136,8 @@ int Main(int argc, char* argv[]) {
|
|||||||
"Must be one of {'nnapi', 'gpu'}"),
|
"Must be one of {'nnapi', 'gpu'}"),
|
||||||
};
|
};
|
||||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||||
|
DelegateProviders delegate_providers;
|
||||||
|
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||||
|
|
||||||
// Process images in filename-sorted order.
|
// Process images in filename-sorted order.
|
||||||
std::vector<std::string> image_files, ground_truth_image_labels;
|
std::vector<std::string> image_files, ground_truth_image_labels;
|
||||||
@ -168,7 +171,8 @@ int Main(int argc, char* argv[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
|
if (!EvaluateModel(model_file_path, image_labels, model_labels, delegate,
|
||||||
output_file_path, num_interpreter_threads)) {
|
output_file_path, num_interpreter_threads,
|
||||||
|
delegate_providers)) {
|
||||||
LOG(ERROR) << "Could not evaluate model";
|
LOG(ERROR) << "Could not evaluate model";
|
||||||
return EXIT_FAILURE;
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,8 @@ constexpr char kDelegateFlag[] = "delegate";
|
|||||||
bool EvaluateModel(const std::string& model_file_path,
|
bool EvaluateModel(const std::string& model_file_path,
|
||||||
const std::string& delegate, int num_runs,
|
const std::string& delegate, int num_runs,
|
||||||
const std::string& output_file_path,
|
const std::string& output_file_path,
|
||||||
int num_interpreter_threads) {
|
int num_interpreter_threads,
|
||||||
|
const DelegateProviders& delegate_providers) {
|
||||||
// Initialize evaluation stage.
|
// Initialize evaluation stage.
|
||||||
EvaluationStageConfig eval_config;
|
EvaluationStageConfig eval_config;
|
||||||
eval_config.set_name("inference_profiling");
|
eval_config.set_name("inference_profiling");
|
||||||
@ -54,7 +55,7 @@ bool EvaluateModel(const std::string& model_file_path,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
InferenceProfilerStage eval(eval_config);
|
InferenceProfilerStage eval(eval_config);
|
||||||
if (eval.Init() != kTfLiteOk) return false;
|
if (eval.Init(&delegate_providers) != kTfLiteOk) return false;
|
||||||
|
|
||||||
// Run inference & check diff for specified number of runs.
|
// Run inference & check diff for specified number of runs.
|
||||||
for (int i = 0; i < num_runs; ++i) {
|
for (int i = 0; i < num_runs; ++i) {
|
||||||
@ -94,8 +95,10 @@ int Main(int argc, char* argv[]) {
|
|||||||
};
|
};
|
||||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||||
|
|
||||||
|
DelegateProviders delegate_providers;
|
||||||
|
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
|
||||||
if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
|
if (!EvaluateModel(model_file_path, delegate, num_runs, output_file_path,
|
||||||
num_interpreter_threads)) {
|
num_interpreter_threads, delegate_providers)) {
|
||||||
LOG(ERROR) << "Could not evaluate model!";
|
LOG(ERROR) << "Could not evaluate model!";
|
||||||
return EXIT_FAILURE;
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user