Merge pull request #45934 from Intel-tensorflow:sriniva2/cnmspq

PiperOrigin-RevId: 351395114
Change-Id: Iac5cda7f1ca745318f350c59abe45bc32069fbad
This commit is contained in:
TensorFlower Gardener 2021-01-12 10:17:00 -08:00
commit 0c68f6b042

View File

@ -330,12 +330,13 @@ void DoNMSPerClass(int batch_idx, int class_idx, const float* boxes_data,
float score;
};
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
return bs_i.score > bs_j.score;
return bs_i.score < bs_j.score;
};
std::vector<Candidate> candidate_vector;
for (int i = 0; i < class_scores_data.size(); ++i) {
std::priority_queue<Candidate, std::vector<Candidate>, decltype(cmp)>
candidate_priority_queue(cmp);
for (int i = 0; i < num_boxes; ++i) {
if (class_scores_data[i] > score_threshold) {
candidate_vector.emplace_back(Candidate({i, class_scores_data[i]}));
candidate_priority_queue.emplace(Candidate({i, class_scores_data[i]}));
}
}
@ -343,17 +344,15 @@ void DoNMSPerClass(int batch_idx, int class_idx, const float* boxes_data,
std::vector<float> selected_boxes;
Candidate next_candidate;
std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
// Move class_boxes_data to a tensor
Eigen::array<Eigen::DenseIndex, 2> boxesShape = {num_boxes, 4};
typename TTypes<float, 2>::ConstTensor boxes_data_t(class_boxes_data.data(),
boxesShape);
int candidate_idx = 0;
float iou;
while (selected.size() < size_per_class &&
candidate_idx < candidate_vector.size()) {
next_candidate = candidate_vector[candidate_idx++];
!candidate_priority_queue.empty()) {
next_candidate = candidate_priority_queue.top();
candidate_priority_queue.pop();
// Overlapping boxes are likely to have similar scores,
// therefore we iterate through the previously selected boxes backwards
// in order to see if `next_candidate` should be suppressed.