The CTC decoder timesteps now corresponds to the timesteps of the most

probable CTC path, instead of the earliest timesteps of all possible paths.
This commit is contained in:
godeffroy 2020-08-25 12:08:14 +02:00
parent a54b198d1e
commit 04a36fbf68
4 changed files with 64 additions and 33 deletions

View File

@ -96,24 +96,52 @@ DecoderState::next(const double *probs,
if (full_beam && log_prob_c + prefix->score < min_cutoff) { if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break; break;
} }
if (prefix->score == -NUM_FLT_INF) {
continue;
}
if (!prefix->is_empty() && prefix->timesteps.empty()) {
// This should never happen. But we report it if it does.
std::cerr<<"error: non-empty prefix has empty timestep sequence"<<std::endl;
continue;
}
// blank // blank
if (c == blank_id_) { if (c == blank_id_) {
// compute probability of current path
float log_p = log_prob_c + prefix->score;
// combine current path with previous ones with the same prefix
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
if (prefix->log_prob_nb_cur < log_p) {
prefix->timesteps_cur = prefix->timesteps;
}
prefix->log_prob_b_cur = prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); log_sum_exp(prefix->log_prob_b_cur, log_p);
continue; continue;
} }
// repeated character // repeated character
if (c == prefix->character) { if (c == prefix->character) {
// compute probability of current path
float log_p = log_prob_c + prefix->log_prob_nb_prev;
// combine current path with previous ones with the same prefix
if (prefix->log_prob_nb_cur < log_p) {
prefix->timesteps_cur = prefix->timesteps;
}
prefix->log_prob_nb_cur = log_sum_exp( prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); prefix->log_prob_nb_cur, log_p);
} }
// get new prefix // get new prefix
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c); auto prefix_new = prefix->get_path_trie(c, log_prob_c);
if (prefix_new != nullptr) { if (prefix_new != nullptr) {
// compute timesteps of current path
std::vector<unsigned int> timesteps_new=prefix->timesteps;
timesteps_new.push_back(abs_time_step_);
// compute probability of current path
float log_p = -NUM_FLT_INF; float log_p = -NUM_FLT_INF;
if (c == prefix->character && if (c == prefix->character &&
@ -144,6 +172,10 @@ DecoderState::next(const double *probs,
} }
} }
// combine current path with previous ones with the same prefix
if (prefix_new->log_prob_nb_cur < log_p) {
prefix_new->timesteps_cur = timesteps_new;
}
prefix_new->log_prob_nb_cur = prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p); log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
} }
@ -205,11 +237,13 @@ DecoderState::decode(size_t num_results) const
std::vector<Output> outputs; std::vector<Output> outputs;
outputs.reserve(num_returned); outputs.reserve(num_returned);
for (size_t i = 0; i < num_returned; ++i) { for (PathTrie* prefix : prefixes_copy) {
Output output; Output output;
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps); output.tokens = prefix->get_path_vec();
output.confidence = scores[prefixes_copy[i]]; output.timesteps = prefix->timesteps;
output.confidence = scores[prefix];
outputs.push_back(output); outputs.push_back(output);
if(outputs.size()>=num_returned) break;
} }
return outputs; return outputs;

View File

@ -18,7 +18,6 @@ PathTrie::PathTrie() {
ROOT_ = -1; ROOT_ = -1;
character = ROOT_; character = ROOT_;
timestep = 0;
exists_ = true; exists_ = true;
parent = nullptr; parent = nullptr;
@ -35,7 +34,7 @@ PathTrie::~PathTrie() {
} }
} }
PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) { PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset) {
auto child = children_.begin(); auto child = children_.begin();
for (; child != children_.end(); ++child) { for (; child != children_.end(); ++child) {
if (child->first == new_char) { if (child->first == new_char) {
@ -67,7 +66,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
} else { } else {
PathTrie* new_path = new PathTrie; PathTrie* new_path = new PathTrie;
new_path->character = new_char; new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this; new_path->parent = this;
new_path->dictionary_ = dictionary_; new_path->dictionary_ = dictionary_;
new_path->has_dictionary_ = true; new_path->has_dictionary_ = true;
@ -93,7 +91,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
} else { } else {
PathTrie* new_path = new PathTrie; PathTrie* new_path = new PathTrie;
new_path->character = new_char; new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this; new_path->parent = this;
new_path->log_prob_c = cur_log_prob_c; new_path->log_prob_c = cur_log_prob_c;
children_.push_back(std::make_pair(new_char, new_path)); children_.push_back(std::make_pair(new_char, new_path));
@ -102,20 +99,18 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
} }
} }
void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) { std::vector<unsigned int> PathTrie::get_path_vec() {
// Recursive call: recurse back until stop condition, then append data in if (parent == nullptr) {
// correct order as we walk back down the stack in the lines below. return std::vector<unsigned int>{};
if (parent != nullptr) {
parent->get_path_vec(output, timesteps);
} }
std::vector<unsigned int> output_tokens=parent->get_path_vec();
if (character != ROOT_) { if (character != ROOT_) {
output.push_back(character); output_tokens.push_back(character);
timesteps.push_back(timestep);
} }
return output_tokens;
} }
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output, PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet) const Alphabet& alphabet)
{ {
PathTrie* stop = this; PathTrie* stop = this;
@ -125,10 +120,9 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
// Recursive call: recurse back until stop condition, then append data in // Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below. // correct order as we walk back down the stack in the lines below.
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) { if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
stop = parent->get_prev_grapheme(output, timesteps, alphabet); stop = parent->get_prev_grapheme(output, alphabet);
} }
output.push_back(character); output.push_back(character);
timesteps.push_back(timestep);
return stop; return stop;
} }
@ -147,7 +141,6 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
} }
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output, PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet) const Alphabet& alphabet)
{ {
PathTrie* stop = this; PathTrie* stop = this;
@ -157,10 +150,9 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
// Recursive call: recurse back until stop condition, then append data in // Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below. // correct order as we walk back down the stack in the lines below.
if (parent != nullptr) { if (parent != nullptr) {
stop = parent->get_prev_word(output, timesteps, alphabet); stop = parent->get_prev_word(output, alphabet);
} }
output.push_back(character); output.push_back(character);
timesteps.push_back(timestep);
return stop; return stop;
} }
@ -173,6 +165,10 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
timesteps = std::move(timesteps_cur);
timesteps_cur.clear();
output.push_back(this); output.push_back(this);
} }
for (auto child : children_) { for (auto child : children_) {
@ -229,8 +225,8 @@ void PathTrie::print(const Alphabet& a) {
} }
} }
printf("\ntimesteps:\t "); printf("\ntimesteps:\t ");
for (PathTrie* el : chain) { for (unsigned int timestep : timesteps) {
printf("%d ", el->timestep); printf("%d ", timestep);
} }
printf("\n"); printf("\n");
printf("transcript:\t %s\n", tr.c_str()); printf("transcript:\t %s\n", tr.c_str());

View File

@ -21,14 +21,13 @@ public:
~PathTrie(); ~PathTrie();
// get new prefix after appending new char // get new prefix after appending new char
PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true); PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true);
// get the prefix data in correct time order from root to current node // get the prefix data in correct time order from root to current node
void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps); std::vector<unsigned int> get_path_vec();
// get the prefix data in correct time order from beginning of last grapheme to current node // get the prefix data in correct time order from beginning of last grapheme to current node
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output, PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet); const Alphabet& alphabet);
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary // get the distance from current node to the first codepoint boundary, and the byte value at the boundary
@ -36,7 +35,6 @@ public:
// get the prefix data in correct time order from beginning of last word to current node // get the prefix data in correct time order from beginning of last word to current node
PathTrie* get_prev_word(std::vector<unsigned int>& output, PathTrie* get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet); const Alphabet& alphabet);
// update log probs // update log probs
@ -65,7 +63,10 @@ public:
float score; float score;
float approx_ctc; float approx_ctc;
unsigned int character; unsigned int character;
unsigned int timestep; std::vector<unsigned int> timesteps;
// `timesteps_cur` is a temporary storage for each decoding step.
// At the end of a decoding step, it is moved to `timesteps`.
std::vector<unsigned int> timesteps_cur;
PathTrie* parent; PathTrie* parent;
private: private:

View File

@ -293,12 +293,12 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
} }
std::vector<unsigned int> prefix_vec; std::vector<unsigned int> prefix_vec;
std::vector<unsigned int> prefix_steps; std::vector<unsigned int> prefix_steps = current_node->timesteps;
if (is_utf8_mode_) { if (is_utf8_mode_) {
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_); new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
} else { } else {
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_); new_node = current_node->get_prev_word(prefix_vec, alphabet_);
} }
current_node = new_node->parent; current_node = new_node->parent;