Use globally-registered delegate providers to initialize TFLite delegates in the ImageNet-based image classification accuracy evaluation tool.
PiperOrigin-RevId: 306357221
Change-Id: I34a9fbf22e3d928c1c2b376b93c6d792c74fa94c

Parent: 258d16bc66
Commit: 5b5876b58d
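For orientation, the flow this change sets up in the tool's main() is roughly the condensed sketch below. It is not part of the patch: the exact include list, and the assumption that CreateImagenetModelEvaluator is declared in imagenet_accuracy_eval.h, are inferred from the includes and calls visible in the diff, and the output/observer flags are omitted. The tool parses its own flags first, lets every globally-registered delegate provider consume its flags from the remaining arguments, and then threads the providers through EvaluateModel().

#include <cstdlib>
#include <memory>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"

int main(int argc, char* argv[]) {
  int num_eval_threads = 4;
  std::vector<tflite::Flag> flag_list = {
      tflite::Flag::CreateFlag("num_eval_threads", &num_eval_threads,
                               "Number of threads used for evaluation."),
  };
  // Consume the tool-specific flags first; delegate-related flags are left in
  // argv for the delegate providers to pick up.
  tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);

  // Every delegate provider linked into the binary has registered itself
  // globally; this lets each of them parse its own flags from the remaining
  // command-line arguments.
  tflite::evaluation::DelegateProviders delegate_providers;
  delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));

  std::unique_ptr<tensorflow::metrics::ImagenetModelEvaluator> evaluator =
      tensorflow::metrics::CreateImagenetModelEvaluator(&argc, argv,
                                                        num_eval_threads);
  if (!evaluator) return EXIT_FAILURE;

  // The providers are forwarded into the evaluation and, per shard, into
  // ImageClassificationStage::Init(), which uses them to create and apply the
  // requested TFLite delegates.
  if (evaluator->EvaluateModel(&delegate_providers) != kTfLiteOk) {
    return EXIT_FAILURE;
  }
  return EXIT_SUCCESS;
}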
@@ -58,5 +58,6 @@ cc_binary(
     deps = [
         ":imagenet_accuracy_eval_lib",
         "//tensorflow/lite/tools:command_line_flags",
+        "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
     ],
 )
@@ -15,19 +15,20 @@ limitations under the License.
 
 #include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h"
 #include "tensorflow/lite/tools/command_line_flags.h"
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 
 namespace {
-constexpr char kNumThreadsFlag[] = "num_threads";
+constexpr char kNumEvalThreadsFlag[] = "num_eval_threads";
 constexpr char kOutputFilePathFlag[] = "output_file_path";
 constexpr char kProtoOutputFilePathFlag[] = "proto_output_file_path";
 }  // namespace
 
 int main(int argc, char* argv[]) {
   std::string output_file_path, proto_output_file_path;
-  int num_threads = 4;
+  int num_eval_threads = 4;
   std::vector<tflite::Flag> flag_list = {
-      tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
-                               "Number of threads."),
+      tflite::Flag::CreateFlag(kNumEvalThreadsFlag, &num_eval_threads,
+                               "Number of threads used for evaluation."),
       tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
                                "Path to output file."),
       tflite::Flag::CreateFlag(kProtoOutputFilePathFlag,
@@ -36,14 +37,17 @@ int main(int argc, char* argv[]) {
   };
   tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
 
-  if (num_threads <= 0) {
+  if (num_eval_threads <= 0) {
     LOG(ERROR) << "Invalid number of threads.";
     return EXIT_FAILURE;
   }
 
+  tflite::evaluation::DelegateProviders delegate_providers;
+  delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
+
   std::unique_ptr<tensorflow::metrics::ImagenetModelEvaluator> evaluator =
       tensorflow::metrics::CreateImagenetModelEvaluator(&argc, argv,
-                                                        num_threads);
+                                                        num_eval_threads);
 
   if (!evaluator) {
     LOG(ERROR) << "Fail to create the ImagenetModelEvaluator.";
@@ -59,8 +63,8 @@ int main(int argc, char* argv[]) {
   }
 
   evaluator->AddObserver(writer.get());
-  LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
-  if (evaluator->EvaluateModel() != kTfLiteOk) {
+  LOG(ERROR) << "Starting evaluation with: " << num_eval_threads << " threads.";
+  if (evaluator->EvaluateModel(&delegate_providers) != kTfLiteOk) {
     LOG(ERROR) << "Failed to evaluate the model!";
     return EXIT_FAILURE;
   }
@@ -26,7 +26,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/tools/command_line_flags.h"
-#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
@@ -155,12 +154,12 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer {
   return kTfLiteOk;
 }
 
-TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
-                                   const std::vector<ImageLabel>& image_labels,
-                                   const std::vector<std::string>& model_labels,
-                                   const ImagenetModelEvaluator::Params& params,
-                                   ImagenetModelEvaluator::Observer* observer,
-                                   int num_ranks) {
+TfLiteStatus EvaluateModelForShard(
+    const uint64_t shard_id, const std::vector<ImageLabel>& image_labels,
+    const std::vector<std::string>& model_labels,
+    const ImagenetModelEvaluator::Params& params,
+    ImagenetModelEvaluator::Observer* observer, int num_ranks,
+    const tflite::evaluation::DelegateProviders* delegate_providers) {
   tflite::evaluation::EvaluationStageConfig eval_config;
   eval_config.set_name("image_classification");
   auto* classification_params = eval_config.mutable_specification()
@@ -174,7 +173,7 @@ TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
 
   tflite::evaluation::ImageClassificationStage eval(eval_config);
   eval.SetAllLabels(model_labels);
-  TF_LITE_ENSURE_STATUS(eval.Init());
+  TF_LITE_ENSURE_STATUS(eval.Init(delegate_providers));
 
   for (const auto& image_label : image_labels) {
     eval.SetInputs(image_label.image, image_label.label);
@@ -191,7 +190,8 @@ TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
   return kTfLiteOk;
 }
 
-TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
+TfLiteStatus ImagenetModelEvaluator::EvaluateModel(
+    const tflite::evaluation::DelegateProviders* delegate_providers) const {
   const std::string data_path = tflite::evaluation::StripTrailingSlashes(
                                     params_.ground_truth_images_path) +
                                 "/";
@@ -252,9 +252,10 @@ TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
     const uint64_t shard_id = i + 1;
     shard_id_image_count_map[shard_id] = image_label.size();
     auto func = [shard_id, &image_label, &model_labels, this, &observer,
-                 &all_okay]() {
+                 &all_okay, delegate_providers]() {
       if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
-                                &observer, params_.num_ranks) != kTfLiteOk) {
+                                &observer, params_.num_ranks,
+                                delegate_providers) != kTfLiteOk) {
         all_okay = false;
       }
     };
@@ -21,6 +21,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 
 namespace tensorflow {
@@ -120,7 +121,8 @@ class ImagenetModelEvaluator {
   const Params& params() const { return params_; }
 
   // Evaluates the provided model over the dataset.
-  TfLiteStatus EvaluateModel() const;
+  TfLiteStatus EvaluateModel(const tflite::evaluation::DelegateProviders*
+                                 delegate_providers = nullptr) const;
 
  private:
   const Params params_;
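Because delegate_providers defaults to nullptr in the header above, existing callers of EvaluateModel() keep compiling unchanged; only callers that want delegates applied pass the providers explicitly. A minimal caller-side illustration follows (a hypothetical snippet, not from the patch; it assumes an `evaluator` and the parsed `argc`/`argv` from a main() like the one earlier in this change):

// Hypothetical snippet: `evaluator`, `argc`, and `argv` are assumed to exist
// as in the tool's main() shown earlier in this patch.
tflite::evaluation::DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));

// Old call form still compiles: delegate_providers defaults to nullptr and the
// evaluation behaves as it did before this change.
TfLiteStatus plain_status = evaluator->EvaluateModel();

// New call form forwards the globally-registered delegate providers.
TfLiteStatus delegated_status = evaluator->EvaluateModel(&delegate_providers);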