The CTC decoder timesteps now corresponds to the timesteps of the most
probable CTC path, instead of the earliest timesteps of all possible paths.
This commit is contained in:
parent
a54b198d1e
commit
04a36fbf68
@ -96,24 +96,52 @@ DecoderState::next(const double *probs,
|
|||||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (prefix->score == -NUM_FLT_INF) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!prefix->is_empty() && prefix->timesteps.empty()) {
|
||||||
|
// This should never happen. But we report it if it does.
|
||||||
|
std::cerr<<"error: non-empty prefix has empty timestep sequence"<<std::endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// blank
|
// blank
|
||||||
if (c == blank_id_) {
|
if (c == blank_id_) {
|
||||||
|
// compute probability of current path
|
||||||
|
float log_p = log_prob_c + prefix->score;
|
||||||
|
|
||||||
|
// combine current path with previous ones with the same prefix
|
||||||
|
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
|
||||||
|
if (prefix->log_prob_nb_cur < log_p) {
|
||||||
|
prefix->timesteps_cur = prefix->timesteps;
|
||||||
|
}
|
||||||
prefix->log_prob_b_cur =
|
prefix->log_prob_b_cur =
|
||||||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
log_sum_exp(prefix->log_prob_b_cur, log_p);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// repeated character
|
// repeated character
|
||||||
if (c == prefix->character) {
|
if (c == prefix->character) {
|
||||||
|
// compute probability of current path
|
||||||
|
float log_p = log_prob_c + prefix->log_prob_nb_prev;
|
||||||
|
|
||||||
|
// combine current path with previous ones with the same prefix
|
||||||
|
if (prefix->log_prob_nb_cur < log_p) {
|
||||||
|
prefix->timesteps_cur = prefix->timesteps;
|
||||||
|
}
|
||||||
prefix->log_prob_nb_cur = log_sum_exp(
|
prefix->log_prob_nb_cur = log_sum_exp(
|
||||||
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
|
prefix->log_prob_nb_cur, log_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
// get new prefix
|
// get new prefix
|
||||||
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c);
|
auto prefix_new = prefix->get_path_trie(c, log_prob_c);
|
||||||
|
|
||||||
if (prefix_new != nullptr) {
|
if (prefix_new != nullptr) {
|
||||||
|
// compute timesteps of current path
|
||||||
|
std::vector<unsigned int> timesteps_new=prefix->timesteps;
|
||||||
|
timesteps_new.push_back(abs_time_step_);
|
||||||
|
|
||||||
|
// compute probability of current path
|
||||||
float log_p = -NUM_FLT_INF;
|
float log_p = -NUM_FLT_INF;
|
||||||
|
|
||||||
if (c == prefix->character &&
|
if (c == prefix->character &&
|
||||||
@ -144,6 +172,10 @@ DecoderState::next(const double *probs,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// combine current path with previous ones with the same prefix
|
||||||
|
if (prefix_new->log_prob_nb_cur < log_p) {
|
||||||
|
prefix_new->timesteps_cur = timesteps_new;
|
||||||
|
}
|
||||||
prefix_new->log_prob_nb_cur =
|
prefix_new->log_prob_nb_cur =
|
||||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||||
}
|
}
|
||||||
@ -205,11 +237,13 @@ DecoderState::decode(size_t num_results) const
|
|||||||
std::vector<Output> outputs;
|
std::vector<Output> outputs;
|
||||||
outputs.reserve(num_returned);
|
outputs.reserve(num_returned);
|
||||||
|
|
||||||
for (size_t i = 0; i < num_returned; ++i) {
|
for (PathTrie* prefix : prefixes_copy) {
|
||||||
Output output;
|
Output output;
|
||||||
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
|
output.tokens = prefix->get_path_vec();
|
||||||
output.confidence = scores[prefixes_copy[i]];
|
output.timesteps = prefix->timesteps;
|
||||||
|
output.confidence = scores[prefix];
|
||||||
outputs.push_back(output);
|
outputs.push_back(output);
|
||||||
|
if(outputs.size()>=num_returned) break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return outputs;
|
return outputs;
|
||||||
|
@ -18,7 +18,6 @@ PathTrie::PathTrie() {
|
|||||||
|
|
||||||
ROOT_ = -1;
|
ROOT_ = -1;
|
||||||
character = ROOT_;
|
character = ROOT_;
|
||||||
timestep = 0;
|
|
||||||
exists_ = true;
|
exists_ = true;
|
||||||
parent = nullptr;
|
parent = nullptr;
|
||||||
|
|
||||||
@ -35,7 +34,7 @@ PathTrie::~PathTrie() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
|
PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset) {
|
||||||
auto child = children_.begin();
|
auto child = children_.begin();
|
||||||
for (; child != children_.end(); ++child) {
|
for (; child != children_.end(); ++child) {
|
||||||
if (child->first == new_char) {
|
if (child->first == new_char) {
|
||||||
@ -67,7 +66,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
|
|||||||
} else {
|
} else {
|
||||||
PathTrie* new_path = new PathTrie;
|
PathTrie* new_path = new PathTrie;
|
||||||
new_path->character = new_char;
|
new_path->character = new_char;
|
||||||
new_path->timestep = new_timestep;
|
|
||||||
new_path->parent = this;
|
new_path->parent = this;
|
||||||
new_path->dictionary_ = dictionary_;
|
new_path->dictionary_ = dictionary_;
|
||||||
new_path->has_dictionary_ = true;
|
new_path->has_dictionary_ = true;
|
||||||
@ -93,7 +91,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
|
|||||||
} else {
|
} else {
|
||||||
PathTrie* new_path = new PathTrie;
|
PathTrie* new_path = new PathTrie;
|
||||||
new_path->character = new_char;
|
new_path->character = new_char;
|
||||||
new_path->timestep = new_timestep;
|
|
||||||
new_path->parent = this;
|
new_path->parent = this;
|
||||||
new_path->log_prob_c = cur_log_prob_c;
|
new_path->log_prob_c = cur_log_prob_c;
|
||||||
children_.push_back(std::make_pair(new_char, new_path));
|
children_.push_back(std::make_pair(new_char, new_path));
|
||||||
@ -102,20 +99,18 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
|
std::vector<unsigned int> PathTrie::get_path_vec() {
|
||||||
// Recursive call: recurse back until stop condition, then append data in
|
if (parent == nullptr) {
|
||||||
// correct order as we walk back down the stack in the lines below.
|
return std::vector<unsigned int>{};
|
||||||
if (parent != nullptr) {
|
|
||||||
parent->get_path_vec(output, timesteps);
|
|
||||||
}
|
}
|
||||||
|
std::vector<unsigned int> output_tokens=parent->get_path_vec();
|
||||||
if (character != ROOT_) {
|
if (character != ROOT_) {
|
||||||
output.push_back(character);
|
output_tokens.push_back(character);
|
||||||
timesteps.push_back(timestep);
|
|
||||||
}
|
}
|
||||||
|
return output_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
|
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
|
||||||
std::vector<unsigned int>& timesteps,
|
|
||||||
const Alphabet& alphabet)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
PathTrie* stop = this;
|
PathTrie* stop = this;
|
||||||
@ -125,10 +120,9 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
|
|||||||
// Recursive call: recurse back until stop condition, then append data in
|
// Recursive call: recurse back until stop condition, then append data in
|
||||||
// correct order as we walk back down the stack in the lines below.
|
// correct order as we walk back down the stack in the lines below.
|
||||||
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
|
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
|
||||||
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
|
stop = parent->get_prev_grapheme(output, alphabet);
|
||||||
}
|
}
|
||||||
output.push_back(character);
|
output.push_back(character);
|
||||||
timesteps.push_back(timestep);
|
|
||||||
return stop;
|
return stop;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,7 +141,6 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
|
|||||||
}
|
}
|
||||||
|
|
||||||
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
||||||
std::vector<unsigned int>& timesteps,
|
|
||||||
const Alphabet& alphabet)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
PathTrie* stop = this;
|
PathTrie* stop = this;
|
||||||
@ -157,10 +150,9 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
|||||||
// Recursive call: recurse back until stop condition, then append data in
|
// Recursive call: recurse back until stop condition, then append data in
|
||||||
// correct order as we walk back down the stack in the lines below.
|
// correct order as we walk back down the stack in the lines below.
|
||||||
if (parent != nullptr) {
|
if (parent != nullptr) {
|
||||||
stop = parent->get_prev_word(output, timesteps, alphabet);
|
stop = parent->get_prev_word(output, alphabet);
|
||||||
}
|
}
|
||||||
output.push_back(character);
|
output.push_back(character);
|
||||||
timesteps.push_back(timestep);
|
|
||||||
return stop;
|
return stop;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,6 +165,10 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
|||||||
log_prob_nb_cur = -NUM_FLT_INF;
|
log_prob_nb_cur = -NUM_FLT_INF;
|
||||||
|
|
||||||
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
||||||
|
|
||||||
|
timesteps = std::move(timesteps_cur);
|
||||||
|
timesteps_cur.clear();
|
||||||
|
|
||||||
output.push_back(this);
|
output.push_back(this);
|
||||||
}
|
}
|
||||||
for (auto child : children_) {
|
for (auto child : children_) {
|
||||||
@ -229,8 +225,8 @@ void PathTrie::print(const Alphabet& a) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\ntimesteps:\t ");
|
printf("\ntimesteps:\t ");
|
||||||
for (PathTrie* el : chain) {
|
for (unsigned int timestep : timesteps) {
|
||||||
printf("%d ", el->timestep);
|
printf("%d ", timestep);
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("transcript:\t %s\n", tr.c_str());
|
printf("transcript:\t %s\n", tr.c_str());
|
||||||
|
@ -21,14 +21,13 @@ public:
|
|||||||
~PathTrie();
|
~PathTrie();
|
||||||
|
|
||||||
// get new prefix after appending new char
|
// get new prefix after appending new char
|
||||||
PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);
|
PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true);
|
||||||
|
|
||||||
// get the prefix data in correct time order from root to current node
|
// get the prefix data in correct time order from root to current node
|
||||||
void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);
|
std::vector<unsigned int> get_path_vec();
|
||||||
|
|
||||||
// get the prefix data in correct time order from beginning of last grapheme to current node
|
// get the prefix data in correct time order from beginning of last grapheme to current node
|
||||||
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
|
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
|
||||||
std::vector<unsigned int>& timesteps,
|
|
||||||
const Alphabet& alphabet);
|
const Alphabet& alphabet);
|
||||||
|
|
||||||
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
||||||
@ -36,7 +35,6 @@ public:
|
|||||||
|
|
||||||
// get the prefix data in correct time order from beginning of last word to current node
|
// get the prefix data in correct time order from beginning of last word to current node
|
||||||
PathTrie* get_prev_word(std::vector<unsigned int>& output,
|
PathTrie* get_prev_word(std::vector<unsigned int>& output,
|
||||||
std::vector<unsigned int>& timesteps,
|
|
||||||
const Alphabet& alphabet);
|
const Alphabet& alphabet);
|
||||||
|
|
||||||
// update log probs
|
// update log probs
|
||||||
@ -65,7 +63,10 @@ public:
|
|||||||
float score;
|
float score;
|
||||||
float approx_ctc;
|
float approx_ctc;
|
||||||
unsigned int character;
|
unsigned int character;
|
||||||
unsigned int timestep;
|
std::vector<unsigned int> timesteps;
|
||||||
|
// `timesteps_cur` is a temporary storage for each decoding step.
|
||||||
|
// At the end of a decoding step, it is moved to `timesteps`.
|
||||||
|
std::vector<unsigned int> timesteps_cur;
|
||||||
PathTrie* parent;
|
PathTrie* parent;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -293,12 +293,12 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<unsigned int> prefix_vec;
|
std::vector<unsigned int> prefix_vec;
|
||||||
std::vector<unsigned int> prefix_steps;
|
std::vector<unsigned int> prefix_steps = current_node->timesteps;
|
||||||
|
|
||||||
if (is_utf8_mode_) {
|
if (is_utf8_mode_) {
|
||||||
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
|
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
|
||||||
} else {
|
} else {
|
||||||
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_);
|
new_node = current_node->get_prev_word(prefix_vec, alphabet_);
|
||||||
}
|
}
|
||||||
current_node = new_node->parent;
|
current_node = new_node->parent;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user