Moved result limiting to ModelState instead of CTC decoder

This commit is contained in:
dabinat 2020-02-14 19:17:52 -08:00 committed by Reuben Morais
parent 969b2ac4ba
commit e0c42f01a4
4 changed files with 24 additions and 37 deletions

View File

@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
} }
std::vector<Output> std::vector<Output>
DecoderState::decode(size_t top_paths) const DecoderState::decode() const
{ {
std::vector<PathTrie*> prefixes_copy = prefixes_; std::vector<PathTrie*> prefixes_copy = prefixes_;
std::unordered_map<const PathTrie*, float> scores; std::unordered_map<const PathTrie*, float> scores;
@ -167,7 +167,7 @@ DecoderState::decode(size_t top_paths) const
// score the last word of each prefix that doesn't end with space // score the last word of each prefix that doesn't end with space
if (ext_scorer_) { if (ext_scorer_) {
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) { for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i]; auto prefix = prefixes_copy[i];
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) { if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
float score = 0.0; float score = 0.0;
@ -181,13 +181,13 @@ DecoderState::decode(size_t top_paths) const
} }
using namespace std::placeholders; using namespace std::placeholders;
size_t num_prefixes = std::min(prefixes_copy.size(), top_paths); size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
std::partial_sort(prefixes_copy.begin(), std::partial_sort(prefixes_copy.begin(),
prefixes_copy.begin() + num_prefixes, prefixes_copy.begin() + num_prefixes,
prefixes_copy.end(), prefixes_copy.end(),
std::bind(prefix_compare_external, _1, _2, scores)); std::bind(prefix_compare_external, _1, _2, scores));
size_t num_returned = std::min(num_prefixes, top_paths); size_t num_returned = std::min(num_prefixes, beam_size_);
std::vector<Output> outputs; std::vector<Output> outputs;
outputs.reserve(num_returned); outputs.reserve(num_returned);
@ -218,7 +218,6 @@ std::vector<Output> ctc_beam_search_decoder(
int class_dim, int class_dim,
const Alphabet &alphabet, const Alphabet &alphabet,
size_t beam_size, size_t beam_size,
size_t top_paths,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer) std::shared_ptr<Scorer> ext_scorer)
@ -226,7 +225,7 @@ std::vector<Output> ctc_beam_search_decoder(
DecoderState state; DecoderState state;
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer); state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
state.next(probs, time_dim, class_dim); state.next(probs, time_dim, class_dim);
return state.decode(top_paths); return state.decode();
} }
std::vector<std::vector<Output>> std::vector<std::vector<Output>>
@ -239,7 +238,6 @@ ctc_beam_search_decoder_batch(
int seq_lengths_size, int seq_lengths_size,
const Alphabet &alphabet, const Alphabet &alphabet,
size_t beam_size, size_t beam_size,
size_t top_paths,
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
@ -259,7 +257,6 @@ ctc_beam_search_decoder_batch(
class_dim, class_dim,
alphabet, alphabet,
beam_size, beam_size,
top_paths,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
ext_scorer)); ext_scorer));

View File

@ -148,21 +148,7 @@ StreamingState::finishStreamWithMetadata(unsigned int num_results)
{ {
finalizeStream(); finalizeStream();
vector<Metadata*> metadata = model_->decode_metadata(decoder_state_, numResults); return model_->decode_metadata(decoder_state_, num_results);
std::unique_ptr<Result> result(new Result());
result->num_transcriptions = metadata.size();
std::unique_ptr<Metadata[]> items(new Metadata[result->num_transcriptions]);
for (int i = 0; i < result->num_transcriptions; ++i) {
std::unique_ptr<Metadata> pointer(new Metadata(*metadata[i]));
items[i] = *pointer.release();
}
result->transcriptions = items.release();
return result.release();
} }
void void

View File

@ -32,22 +32,25 @@ ModelState::init(const char* model_path)
char* char*
ModelState::decode(const DecoderState& state) const ModelState::decode(const DecoderState& state) const
{ {
vector<Output> out = state.decode(1); vector<Output> out = state.decode();
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str()); return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
} }
vector<Metadata*> Result*
ModelState::decode_metadata(const DecoderState& state, ModelState::decode_metadata(const DecoderState& state,
size_t top_paths) size_t num_results)
{ {
vector<Output> out = state.decode(top_paths); vector<Output> out = state.decode();
vector<Metadata*> meta_out; size_t max_results = std::min(num_results, out.size());
size_t max_results = std::min(top_paths, out.size()); std::unique_ptr<Result> result(new Result());
result->num_transcriptions = max_results;
std::unique_ptr<Metadata[]> transcripts(new Metadata[max_results]());
for (int j = 0; j < max_results; ++j) { for (int j = 0; j < max_results; ++j) {
std::unique_ptr<Metadata> metadata(new Metadata()); Metadata* metadata = &transcripts[j];
metadata->num_items = out[j].tokens.size(); metadata->num_items = out[j].tokens.size();
metadata->confidence = out[j].confidence; metadata->confidence = out[j].confidence;
@ -65,8 +68,9 @@ ModelState::decode_metadata(const DecoderState& state,
} }
metadata->items = items.release(); metadata->items = items.release();
meta_out.push_back(metadata.release());
} }
return meta_out; result->transcriptions = transcripts.release();
return result.release();
} }

View File

@ -66,14 +66,14 @@ struct ModelState {
* @brief Return character-level metadata including letter timings. * @brief Return character-level metadata including letter timings.
* *
* @param state Decoder state to use when decoding. * @param state Decoder state to use when decoding.
* @param top_paths Number of alternate results to return. * @param num_results Number of alternate results to return.
* *
* @return Vector of Metadata structs containing MetadataItem structs for each character. * @return A Result struct containing Metadata structs.
* Each represents an alternate transcription, with the first ranked most probable. * Each represents an alternate transcription, with the first ranked most probable.
* The user is responsible for freeing Metadata by calling DS_FreeMetadata() on each item. * The user is responsible for freeing Result by calling DS_FreeResult().
*/ */
virtual std::vector<Metadata*> decode_metadata(const DecoderState& state, virtual Result* decode_metadata(const DecoderState& state,
size_t top_paths); size_t num_results);
}; };
#endif // MODELSTATE_H #endif // MODELSTATE_H