Utilize globally-registered delegate providers to initialize tflite delegates in the imagenet-based image classification accuracy evaluation tool.

PiperOrigin-RevId: 306357221
Change-Id: I34a9fbf22e3d928c1c2b376b93c6d792c74fa94c
This commit is contained in:
Chao Mei 2020-04-13 19:03:51 -07:00 committed by TensorFlower Gardener
parent 258d16bc66
commit 5b5876b58d
4 changed files with 28 additions and 20 deletions

View File

@ -58,5 +58,6 @@ cc_binary(
deps = [
":imagenet_accuracy_eval_lib",
"//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools/evaluation:evaluation_delegate_provider",
],
)

View File

@ -15,19 +15,20 @@ limitations under the License.
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
namespace {
constexpr char kNumThreadsFlag[] = "num_threads";
constexpr char kNumEvalThreadsFlag[] = "num_eval_threads";
constexpr char kOutputFilePathFlag[] = "output_file_path";
constexpr char kProtoOutputFilePathFlag[] = "proto_output_file_path";
} // namespace
int main(int argc, char* argv[]) {
std::string output_file_path, proto_output_file_path;
int num_threads = 4;
int num_eval_threads = 4;
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
"Number of threads."),
tflite::Flag::CreateFlag(kNumEvalThreadsFlag, &num_eval_threads,
"Number of threads used for evaluation."),
tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
"Path to output file."),
tflite::Flag::CreateFlag(kProtoOutputFilePathFlag,
@ -36,14 +37,17 @@ int main(int argc, char* argv[]) {
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
if (num_threads <= 0) {
if (num_eval_threads <= 0) {
LOG(ERROR) << "Invalid number of threads.";
return EXIT_FAILURE;
}
tflite::evaluation::DelegateProviders delegate_providers;
delegate_providers.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
std::unique_ptr<tensorflow::metrics::ImagenetModelEvaluator> evaluator =
tensorflow::metrics::CreateImagenetModelEvaluator(&argc, argv,
num_threads);
num_eval_threads);
if (!evaluator) {
LOG(ERROR) << "Fail to create the ImagenetModelEvaluator.";
@ -59,8 +63,8 @@ int main(int argc, char* argv[]) {
}
evaluator->AddObserver(writer.get());
LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
if (evaluator->EvaluateModel() != kTfLiteOk) {
LOG(ERROR) << "Starting evaluation with: " << num_eval_threads << " threads.";
if (evaluator->EvaluateModel(&delegate_providers) != kTfLiteOk) {
LOG(ERROR) << "Failed to evaluate the model!";
return EXIT_FAILURE;
}

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h"
@ -155,12 +154,12 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer {
return kTfLiteOk;
}
TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
const std::vector<ImageLabel>& image_labels,
const std::vector<std::string>& model_labels,
const ImagenetModelEvaluator::Params& params,
ImagenetModelEvaluator::Observer* observer,
int num_ranks) {
TfLiteStatus EvaluateModelForShard(
const uint64_t shard_id, const std::vector<ImageLabel>& image_labels,
const std::vector<std::string>& model_labels,
const ImagenetModelEvaluator::Params& params,
ImagenetModelEvaluator::Observer* observer, int num_ranks,
const tflite::evaluation::DelegateProviders* delegate_providers) {
tflite::evaluation::EvaluationStageConfig eval_config;
eval_config.set_name("image_classification");
auto* classification_params = eval_config.mutable_specification()
@ -174,7 +173,7 @@ TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
tflite::evaluation::ImageClassificationStage eval(eval_config);
eval.SetAllLabels(model_labels);
TF_LITE_ENSURE_STATUS(eval.Init());
TF_LITE_ENSURE_STATUS(eval.Init(delegate_providers));
for (const auto& image_label : image_labels) {
eval.SetInputs(image_label.image, image_label.label);
@ -191,7 +190,8 @@ TfLiteStatus EvaluateModelForShard(const uint64_t shard_id,
return kTfLiteOk;
}
TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
TfLiteStatus ImagenetModelEvaluator::EvaluateModel(
const tflite::evaluation::DelegateProviders* delegate_providers) const {
const std::string data_path = tflite::evaluation::StripTrailingSlashes(
params_.ground_truth_images_path) +
"/";
@ -252,9 +252,10 @@ TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
const uint64_t shard_id = i + 1;
shard_id_image_count_map[shard_id] = image_label.size();
auto func = [shard_id, &image_label, &model_labels, this, &observer,
&all_okay]() {
&all_okay, delegate_providers]() {
if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
&observer, params_.num_ranks) != kTfLiteOk) {
&observer, params_.num_ranks,
delegate_providers) != kTfLiteOk) {
all_okay = false;
}
};

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tensorflow {
@ -120,7 +121,8 @@ class ImagenetModelEvaluator {
const Params& params() const { return params_; }
// Evaluates the provided model over the dataset.
TfLiteStatus EvaluateModel() const;
TfLiteStatus EvaluateModel(const tflite::evaluation::DelegateProviders*
delegate_providers = nullptr) const;
private:
const Params params_;