From 668bd2a76fb45172b32a956c93682aaefd9323dc Mon Sep 17 00:00:00 2001 From: Dayananda-V Date: Thu, 4 Jul 2019 19:32:00 +0530 Subject: [PATCH] [Lite] num_ranks option brings to user through cmd line parameter num_ranks optional parameters provides to user to calculate accuracy ranks on out file. --- tensorflow/lite/tools/accuracy/ilsvrc/README.md | 4 ++++ .../lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/lite/tools/accuracy/ilsvrc/README.md index 6e27c8570f3..4d61aa1d854 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/README.md +++ b/tensorflow/lite/tools/accuracy/ilsvrc/README.md @@ -56,6 +56,10 @@ and the following optional parameters: Optionally, the computed accuracies can be output to a file as a string-serialized instance of tflite::evaluation::TopkAccuracyEvalMetrics. +* `num_ranks`: `int` (default=10) \ + The number of top-K accuracies to return. For example, if num_ranks=5, + top-1 to top-5 accuracy fractions are returned. + The following optional parameters can be used to modify the inference runtime: * `num_interpreter_threads`: `int` (default=1) \ diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc index d7230976961..f296b89b583 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc @@ -49,6 +49,7 @@ constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads"; constexpr char kDelegateFlag[] = "delegate"; constexpr char kNnapiDelegate[] = "nnapi"; constexpr char kGpuDelegate[] = "gpu"; +constexpr char kNumRanksFlag[] = "num_ranks"; template std::vector GetFirstN(const std::vector& v, int n) { @@ -144,6 +145,9 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer { tflite::Flag::CreateFlag(kDelegateFlag, ¶ms.delegate, "Delegate to use for inference, if available. " "Must be one of {'nnapi', 'gpu'}"), + tflite::Flag::CreateFlag(kNumRanksFlag, ¶ms.num_ranks, + "Generates the top-1 to top-k accuracy values" + "where k = num_ranks. Default: 10"), }; tflite::Flags::Parse(&argc, const_cast(argv), flag_list);