The CTC decoder timesteps now correspond to the timesteps of the most
probable CTC path, instead of the earliest timesteps of all possible paths.
This commit is contained in:
parent
a54b198d1e
commit
04a36fbf68
@ -96,24 +96,52 @@ DecoderState::next(const double *probs,
|
||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||
break;
|
||||
}
|
||||
if (prefix->score == -NUM_FLT_INF) {
|
||||
continue;
|
||||
}
|
||||
if (!prefix->is_empty() && prefix->timesteps.empty()) {
|
||||
// This should never happen. But we report it if it does.
|
||||
std::cerr<<"error: non-empty prefix has empty timestep sequence"<<std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
// blank
|
||||
if (c == blank_id_) {
|
||||
// compute probability of current path
|
||||
float log_p = log_prob_c + prefix->score;
|
||||
|
||||
// combine current path with previous ones with the same prefix
|
||||
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
|
||||
if (prefix->log_prob_nb_cur < log_p) {
|
||||
prefix->timesteps_cur = prefix->timesteps;
|
||||
}
|
||||
prefix->log_prob_b_cur =
|
||||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
||||
log_sum_exp(prefix->log_prob_b_cur, log_p);
|
||||
continue;
|
||||
}
|
||||
|
||||
// repeated character
|
||||
if (c == prefix->character) {
|
||||
// compute probability of current path
|
||||
float log_p = log_prob_c + prefix->log_prob_nb_prev;
|
||||
|
||||
// combine current path with previous ones with the same prefix
|
||||
if (prefix->log_prob_nb_cur < log_p) {
|
||||
prefix->timesteps_cur = prefix->timesteps;
|
||||
}
|
||||
prefix->log_prob_nb_cur = log_sum_exp(
|
||||
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
|
||||
prefix->log_prob_nb_cur, log_p);
|
||||
}
|
||||
|
||||
// get new prefix
|
||||
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c);
|
||||
auto prefix_new = prefix->get_path_trie(c, log_prob_c);
|
||||
|
||||
if (prefix_new != nullptr) {
|
||||
// compute timesteps of current path
|
||||
std::vector<unsigned int> timesteps_new=prefix->timesteps;
|
||||
timesteps_new.push_back(abs_time_step_);
|
||||
|
||||
// compute probability of current path
|
||||
float log_p = -NUM_FLT_INF;
|
||||
|
||||
if (c == prefix->character &&
|
||||
@ -144,6 +172,10 @@ DecoderState::next(const double *probs,
|
||||
}
|
||||
}
|
||||
|
||||
// combine current path with previous ones with the same prefix
|
||||
if (prefix_new->log_prob_nb_cur < log_p) {
|
||||
prefix_new->timesteps_cur = timesteps_new;
|
||||
}
|
||||
prefix_new->log_prob_nb_cur =
|
||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||
}
|
||||
@ -205,11 +237,13 @@ DecoderState::decode(size_t num_results) const
|
||||
std::vector<Output> outputs;
|
||||
outputs.reserve(num_returned);
|
||||
|
||||
for (size_t i = 0; i < num_returned; ++i) {
|
||||
for (PathTrie* prefix : prefixes_copy) {
|
||||
Output output;
|
||||
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
|
||||
output.confidence = scores[prefixes_copy[i]];
|
||||
output.tokens = prefix->get_path_vec();
|
||||
output.timesteps = prefix->timesteps;
|
||||
output.confidence = scores[prefix];
|
||||
outputs.push_back(output);
|
||||
if(outputs.size()>=num_returned) break;
|
||||
}
|
||||
|
||||
return outputs;
|
||||
|
@ -18,7 +18,6 @@ PathTrie::PathTrie() {
|
||||
|
||||
ROOT_ = -1;
|
||||
character = ROOT_;
|
||||
timestep = 0;
|
||||
exists_ = true;
|
||||
parent = nullptr;
|
||||
|
||||
@ -35,7 +34,7 @@ PathTrie::~PathTrie() {
|
||||
}
|
||||
}
|
||||
|
||||
PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
|
||||
PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset) {
|
||||
auto child = children_.begin();
|
||||
for (; child != children_.end(); ++child) {
|
||||
if (child->first == new_char) {
|
||||
@ -67,7 +66,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
|
||||
} else {
|
||||
PathTrie* new_path = new PathTrie;
|
||||
new_path->character = new_char;
|
||||
new_path->timestep = new_timestep;
|
||||
new_path->parent = this;
|
||||
new_path->dictionary_ = dictionary_;
|
||||
new_path->has_dictionary_ = true;
|
||||
@ -93,7 +91,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
|
||||
} else {
|
||||
PathTrie* new_path = new PathTrie;
|
||||
new_path->character = new_char;
|
||||
new_path->timestep = new_timestep;
|
||||
new_path->parent = this;
|
||||
new_path->log_prob_c = cur_log_prob_c;
|
||||
children_.push_back(std::make_pair(new_char, new_path));
|
||||
@ -102,20 +99,18 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
|
||||
}
|
||||
}
|
||||
|
||||
void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
|
||||
// Recursive call: recurse back until stop condition, then append data in
|
||||
// correct order as we walk back down the stack in the lines below.
|
||||
if (parent != nullptr) {
|
||||
parent->get_path_vec(output, timesteps);
|
||||
std::vector<unsigned int> PathTrie::get_path_vec() {
|
||||
if (parent == nullptr) {
|
||||
return std::vector<unsigned int>{};
|
||||
}
|
||||
std::vector<unsigned int> output_tokens=parent->get_path_vec();
|
||||
if (character != ROOT_) {
|
||||
output.push_back(character);
|
||||
timesteps.push_back(timestep);
|
||||
output_tokens.push_back(character);
|
||||
}
|
||||
return output_tokens;
|
||||
}
|
||||
|
||||
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
|
||||
std::vector<unsigned int>& timesteps,
|
||||
const Alphabet& alphabet)
|
||||
{
|
||||
PathTrie* stop = this;
|
||||
@ -125,10 +120,9 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
|
||||
// Recursive call: recurse back until stop condition, then append data in
|
||||
// correct order as we walk back down the stack in the lines below.
|
||||
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
|
||||
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
|
||||
stop = parent->get_prev_grapheme(output, alphabet);
|
||||
}
|
||||
output.push_back(character);
|
||||
timesteps.push_back(timestep);
|
||||
return stop;
|
||||
}
|
||||
|
||||
@ -147,7 +141,6 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
|
||||
}
|
||||
|
||||
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
||||
std::vector<unsigned int>& timesteps,
|
||||
const Alphabet& alphabet)
|
||||
{
|
||||
PathTrie* stop = this;
|
||||
@ -157,10 +150,9 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
||||
// Recursive call: recurse back until stop condition, then append data in
|
||||
// correct order as we walk back down the stack in the lines below.
|
||||
if (parent != nullptr) {
|
||||
stop = parent->get_prev_word(output, timesteps, alphabet);
|
||||
stop = parent->get_prev_word(output, alphabet);
|
||||
}
|
||||
output.push_back(character);
|
||||
timesteps.push_back(timestep);
|
||||
return stop;
|
||||
}
|
||||
|
||||
@ -173,6 +165,10 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
||||
log_prob_nb_cur = -NUM_FLT_INF;
|
||||
|
||||
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
||||
|
||||
timesteps = std::move(timesteps_cur);
|
||||
timesteps_cur.clear();
|
||||
|
||||
output.push_back(this);
|
||||
}
|
||||
for (auto child : children_) {
|
||||
@ -229,8 +225,8 @@ void PathTrie::print(const Alphabet& a) {
|
||||
}
|
||||
}
|
||||
printf("\ntimesteps:\t ");
|
||||
for (PathTrie* el : chain) {
|
||||
printf("%d ", el->timestep);
|
||||
for (unsigned int timestep : timesteps) {
|
||||
printf("%d ", timestep);
|
||||
}
|
||||
printf("\n");
|
||||
printf("transcript:\t %s\n", tr.c_str());
|
||||
|
@ -21,14 +21,13 @@ public:
|
||||
~PathTrie();
|
||||
|
||||
// get new prefix after appending new char
|
||||
PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);
|
||||
PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true);
|
||||
|
||||
// get the prefix data in correct time order from root to current node
|
||||
void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);
|
||||
std::vector<unsigned int> get_path_vec();
|
||||
|
||||
// get the prefix data in correct time order from beginning of last grapheme to current node
|
||||
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
|
||||
std::vector<unsigned int>& timesteps,
|
||||
const Alphabet& alphabet);
|
||||
|
||||
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
||||
@ -36,7 +35,6 @@ public:
|
||||
|
||||
// get the prefix data in correct time order from beginning of last word to current node
|
||||
PathTrie* get_prev_word(std::vector<unsigned int>& output,
|
||||
std::vector<unsigned int>& timesteps,
|
||||
const Alphabet& alphabet);
|
||||
|
||||
// update log probs
|
||||
@ -65,7 +63,10 @@ public:
|
||||
float score;
|
||||
float approx_ctc;
|
||||
unsigned int character;
|
||||
unsigned int timestep;
|
||||
std::vector<unsigned int> timesteps;
|
||||
// `timesteps_cur` is a temporary storage for each decoding step.
|
||||
// At the end of a decoding step, it is moved to `timesteps`.
|
||||
std::vector<unsigned int> timesteps_cur;
|
||||
PathTrie* parent;
|
||||
|
||||
private:
|
||||
|
@ -293,12 +293,12 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
||||
}
|
||||
|
||||
std::vector<unsigned int> prefix_vec;
|
||||
std::vector<unsigned int> prefix_steps;
|
||||
std::vector<unsigned int> prefix_steps = current_node->timesteps;
|
||||
|
||||
if (is_utf8_mode_) {
|
||||
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
|
||||
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
|
||||
} else {
|
||||
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_);
|
||||
new_node = current_node->get_prev_word(prefix_vec, alphabet_);
|
||||
}
|
||||
current_node = new_node->parent;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user