Merge pull request #30414 from Dayananda-V:tflite_tools_accuracy_num_ranks
PiperOrigin-RevId: 257208730
This commit is contained in:
commit
67a48c79a3
@ -56,6 +56,10 @@ and the following optional parameters:
|
|||||||
Optionally, the computed accuracies can be output to a file as a
|
Optionally, the computed accuracies can be output to a file as a
|
||||||
string-serialized instance of tflite::evaluation::TopkAccuracyEvalMetrics.
|
string-serialized instance of tflite::evaluation::TopkAccuracyEvalMetrics.
|
||||||
|
|
||||||
|
* `num_ranks`: `int` (default=10) \
|
||||||
|
The number of top-K accuracies to return. For example, if num_ranks=5, top-1
|
||||||
|
to top-5 accuracy fractions are returned.
|
||||||
|
|
||||||
The following optional parameters can be used to modify the inference runtime:
|
The following optional parameters can be used to modify the inference runtime:
|
||||||
|
|
||||||
* `num_interpreter_threads`: `int` (default=1) \
|
* `num_interpreter_threads`: `int` (default=1) \
|
||||||
|
@ -49,6 +49,7 @@ constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
|
|||||||
constexpr char kDelegateFlag[] = "delegate";
|
constexpr char kDelegateFlag[] = "delegate";
|
||||||
constexpr char kNnapiDelegate[] = "nnapi";
|
constexpr char kNnapiDelegate[] = "nnapi";
|
||||||
constexpr char kGpuDelegate[] = "gpu";
|
constexpr char kGpuDelegate[] = "gpu";
|
||||||
|
constexpr char kNumRanksFlag[] = "num_ranks";
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
|
||||||
@ -144,6 +145,9 @@ class CompositeObserver : public ImagenetModelEvaluator::Observer {
|
|||||||
tflite::Flag::CreateFlag(kDelegateFlag, ¶ms.delegate,
|
tflite::Flag::CreateFlag(kDelegateFlag, ¶ms.delegate,
|
||||||
"Delegate to use for inference, if available. "
|
"Delegate to use for inference, if available. "
|
||||||
"Must be one of {'nnapi', 'gpu'}"),
|
"Must be one of {'nnapi', 'gpu'}"),
|
||||||
|
tflite::Flag::CreateFlag(kNumRanksFlag, ¶ms.num_ranks,
|
||||||
|
"Generates the top-1 to top-k accuracy values"
|
||||||
|
"where k = num_ranks. Default: 10"),
|
||||||
};
|
};
|
||||||
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user