Replace incomplete sorts with partial sorts

This commit is contained in:
Reuben Morais 2019-10-14 11:46:39 +02:00
parent 739841d731
commit 3015237e8d

View File

@ -59,8 +59,10 @@ DecoderState::next(const double *probs,
bool full_beam = false;
if (ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
std::sort(
prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare);
std::partial_sort(prefixes_.begin(),
prefixes_.begin() + num_prefixes,
prefixes_.end(),
prefix_compare);
min_cutoff = prefixes_[num_prefixes - 1]->score +
std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
@ -177,7 +179,10 @@ DecoderState::decode() const
using namespace std::placeholders;
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
std::partial_sort(prefixes_copy.begin(),
prefixes_copy.begin() + num_prefixes,
prefixes_copy.end(),
std::bind(prefix_compare_external, _1, _2, scores));
//TODO: expose this as an API parameter
const int top_paths = 1;