Merge pull request #30414 from Dayananda-V:tflite_tools_accuracy_num_ranks

PiperOrigin-RevId: 257208730
This commit is contained in:
TensorFlower Gardener 2019-07-09 09:34:56 -07:00
commit 67a48c79a3
2 changed files with 8 additions and 0 deletions

View File

@ -56,6 +56,10 @@ and the following optional parameters:
Optionally, the computed accuracies can be output to a file as a Optionally, the computed accuracies can be output to a file as a
string-serialized instance of tflite::evaluation::TopkAccuracyEvalMetrics. string-serialized instance of tflite::evaluation::TopkAccuracyEvalMetrics.
* `num_ranks`: `int` (default=10) \
The number of top-K accuracies to return. For example, if num_ranks=5, top-1
to top-5 accuracy fractions are returned.
The following optional parameters can be used to modify the inference runtime: The following optional parameters can be used to modify the inference runtime:
* `num_interpreter_threads`: `int` (default=1) \ * `num_interpreter_threads`: `int` (default=1) \

View File

@ -49,6 +49,7 @@ constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
constexpr char kDelegateFlag[] = "delegate"; constexpr char kDelegateFlag[] = "delegate";
constexpr char kNnapiDelegate[] = "nnapi"; constexpr char kNnapiDelegate[] = "nnapi";
constexpr char kGpuDelegate[] = "gpu"; constexpr char kGpuDelegate[] = "gpu";
constexpr char kNumRanksFlag[] = "num_ranks";
template <typename T> template <typename T>
std::vector<T> GetFirstN(const std::vector<T>& v, int n) { std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
@ -144,6 +145,9 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer {
tflite::Flag::CreateFlag(kDelegateFlag, &params.delegate, tflite::Flag::CreateFlag(kDelegateFlag, &params.delegate,
"Delegate to use for inference, if available. " "Delegate to use for inference, if available. "
"Must be one of {'nnapi', 'gpu'}"), "Must be one of {'nnapi', 'gpu'}"),
tflite::Flag::CreateFlag(kNumRanksFlag, &params.num_ranks,
"Generates the top-1 to top-k accuracy values"
"where k = num_ranks. Default: 10"),
}; };
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list); tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);