Moved result limiting to ModelState instead of CTC decoder
This commit is contained in:
parent
969b2ac4ba
commit
e0c42f01a4
@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Output>
|
std::vector<Output>
|
||||||
DecoderState::decode(size_t top_paths) const
|
DecoderState::decode() const
|
||||||
{
|
{
|
||||||
std::vector<PathTrie*> prefixes_copy = prefixes_;
|
std::vector<PathTrie*> prefixes_copy = prefixes_;
|
||||||
std::unordered_map<const PathTrie*, float> scores;
|
std::unordered_map<const PathTrie*, float> scores;
|
||||||
@ -167,7 +167,7 @@ DecoderState::decode(size_t top_paths) const
|
|||||||
|
|
||||||
// score the last word of each prefix that doesn't end with space
|
// score the last word of each prefix that doesn't end with space
|
||||||
if (ext_scorer_) {
|
if (ext_scorer_) {
|
||||||
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
|
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
|
||||||
auto prefix = prefixes_copy[i];
|
auto prefix = prefixes_copy[i];
|
||||||
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
|
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
|
||||||
float score = 0.0;
|
float score = 0.0;
|
||||||
@ -181,13 +181,13 @@ DecoderState::decode(size_t top_paths) const
|
|||||||
}
|
}
|
||||||
|
|
||||||
using namespace std::placeholders;
|
using namespace std::placeholders;
|
||||||
size_t num_prefixes = std::min(prefixes_copy.size(), top_paths);
|
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
|
||||||
std::partial_sort(prefixes_copy.begin(),
|
std::partial_sort(prefixes_copy.begin(),
|
||||||
prefixes_copy.begin() + num_prefixes,
|
prefixes_copy.begin() + num_prefixes,
|
||||||
prefixes_copy.end(),
|
prefixes_copy.end(),
|
||||||
std::bind(prefix_compare_external, _1, _2, scores));
|
std::bind(prefix_compare_external, _1, _2, scores));
|
||||||
|
|
||||||
size_t num_returned = std::min(num_prefixes, top_paths);
|
size_t num_returned = std::min(num_prefixes, beam_size_);
|
||||||
|
|
||||||
std::vector<Output> outputs;
|
std::vector<Output> outputs;
|
||||||
outputs.reserve(num_returned);
|
outputs.reserve(num_returned);
|
||||||
@ -218,7 +218,6 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||||||
int class_dim,
|
int class_dim,
|
||||||
const Alphabet &alphabet,
|
const Alphabet &alphabet,
|
||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
size_t top_paths,
|
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
std::shared_ptr<Scorer> ext_scorer)
|
std::shared_ptr<Scorer> ext_scorer)
|
||||||
@ -226,7 +225,7 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||||||
DecoderState state;
|
DecoderState state;
|
||||||
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
||||||
state.next(probs, time_dim, class_dim);
|
state.next(probs, time_dim, class_dim);
|
||||||
return state.decode(top_paths);
|
return state.decode();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<Output>>
|
std::vector<std::vector<Output>>
|
||||||
@ -239,7 +238,6 @@ ctc_beam_search_decoder_batch(
|
|||||||
int seq_lengths_size,
|
int seq_lengths_size,
|
||||||
const Alphabet &alphabet,
|
const Alphabet &alphabet,
|
||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
size_t top_paths,
|
|
||||||
size_t num_processes,
|
size_t num_processes,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
@ -259,7 +257,6 @@ ctc_beam_search_decoder_batch(
|
|||||||
class_dim,
|
class_dim,
|
||||||
alphabet,
|
alphabet,
|
||||||
beam_size,
|
beam_size,
|
||||||
top_paths,
|
|
||||||
cutoff_prob,
|
cutoff_prob,
|
||||||
cutoff_top_n,
|
cutoff_top_n,
|
||||||
ext_scorer));
|
ext_scorer));
|
||||||
|
@ -148,21 +148,7 @@ StreamingState::finishStreamWithMetadata(unsigned int num_results)
|
|||||||
{
|
{
|
||||||
finalizeStream();
|
finalizeStream();
|
||||||
|
|
||||||
vector<Metadata*> metadata = model_->decode_metadata(decoder_state_, numResults);
|
return model_->decode_metadata(decoder_state_, num_results);
|
||||||
|
|
||||||
std::unique_ptr<Result> result(new Result());
|
|
||||||
result->num_transcriptions = metadata.size();
|
|
||||||
|
|
||||||
std::unique_ptr<Metadata[]> items(new Metadata[result->num_transcriptions]);
|
|
||||||
|
|
||||||
for (int i = 0; i < result->num_transcriptions; ++i) {
|
|
||||||
std::unique_ptr<Metadata> pointer(new Metadata(*metadata[i]));
|
|
||||||
items[i] = *pointer.release();
|
|
||||||
}
|
|
||||||
|
|
||||||
result->transcriptions = items.release();
|
|
||||||
|
|
||||||
return result.release();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
|
@ -32,22 +32,25 @@ ModelState::init(const char* model_path)
|
|||||||
char*
|
char*
|
||||||
ModelState::decode(const DecoderState& state) const
|
ModelState::decode(const DecoderState& state) const
|
||||||
{
|
{
|
||||||
vector<Output> out = state.decode(1);
|
vector<Output> out = state.decode();
|
||||||
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
|
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<Metadata*>
|
Result*
|
||||||
ModelState::decode_metadata(const DecoderState& state,
|
ModelState::decode_metadata(const DecoderState& state,
|
||||||
size_t top_paths)
|
size_t num_results)
|
||||||
{
|
{
|
||||||
vector<Output> out = state.decode(top_paths);
|
vector<Output> out = state.decode();
|
||||||
|
|
||||||
vector<Metadata*> meta_out;
|
size_t max_results = std::min(num_results, out.size());
|
||||||
|
|
||||||
size_t max_results = std::min(top_paths, out.size());
|
std::unique_ptr<Result> result(new Result());
|
||||||
|
result->num_transcriptions = max_results;
|
||||||
|
|
||||||
|
std::unique_ptr<Metadata[]> transcripts(new Metadata[max_results]());
|
||||||
|
|
||||||
for (int j = 0; j < max_results; ++j) {
|
for (int j = 0; j < max_results; ++j) {
|
||||||
std::unique_ptr<Metadata> metadata(new Metadata());
|
Metadata* metadata = &transcripts[j];
|
||||||
metadata->num_items = out[j].tokens.size();
|
metadata->num_items = out[j].tokens.size();
|
||||||
metadata->confidence = out[j].confidence;
|
metadata->confidence = out[j].confidence;
|
||||||
|
|
||||||
@ -65,8 +68,9 @@ ModelState::decode_metadata(const DecoderState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
metadata->items = items.release();
|
metadata->items = items.release();
|
||||||
meta_out.push_back(metadata.release());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return meta_out;
|
result->transcriptions = transcripts.release();
|
||||||
|
|
||||||
|
return result.release();
|
||||||
}
|
}
|
||||||
|
@ -66,14 +66,14 @@ struct ModelState {
|
|||||||
* @brief Return character-level metadata including letter timings.
|
* @brief Return character-level metadata including letter timings.
|
||||||
*
|
*
|
||||||
* @param state Decoder state to use when decoding.
|
* @param state Decoder state to use when decoding.
|
||||||
* @param top_paths Number of alternate results to return.
|
* @param num_results Number of alternate results to return.
|
||||||
*
|
*
|
||||||
* @return Vector of Metadata structs containing MetadataItem structs for each character.
|
* @return A Result struct containing Metadata structs.
|
||||||
* Each represents an alternate transcription, with the first ranked most probable.
|
* Each represents an alternate transcription, with the first ranked most probable.
|
||||||
* The user is responsible for freeing Metadata by calling DS_FreeMetadata() on each item.
|
* The user is responsible for freeing Result by calling DS_FreeResult().
|
||||||
*/
|
*/
|
||||||
virtual std::vector<Metadata*> decode_metadata(const DecoderState& state,
|
virtual Result* decode_metadata(const DecoderState& state,
|
||||||
size_t top_paths);
|
size_t num_results);
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // MODELSTATE_H
|
#endif // MODELSTATE_H
|
||||||
|
Loading…
Reference in New Issue
Block a user