PR #3279 - avoid unnecessary copies of timesteps vectors
This commit is contained in:
parent
363121235e
commit
1f89bef5f0
@ -106,7 +106,8 @@ DecoderState::next(const double *probs,
|
|||||||
// combine current path with previous ones with the same prefix
|
// combine current path with previous ones with the same prefix
|
||||||
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
|
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
|
||||||
if (prefix->log_prob_nb_cur < log_p) {
|
if (prefix->log_prob_nb_cur < log_p) {
|
||||||
prefix->timesteps_cur = prefix->timesteps;
|
// keep current timesteps
|
||||||
|
prefix->previous_timesteps = nullptr;
|
||||||
}
|
}
|
||||||
prefix->log_prob_b_cur =
|
prefix->log_prob_b_cur =
|
||||||
log_sum_exp(prefix->log_prob_b_cur, log_p);
|
log_sum_exp(prefix->log_prob_b_cur, log_p);
|
||||||
@ -120,7 +121,8 @@ DecoderState::next(const double *probs,
|
|||||||
|
|
||||||
// combine current path with previous ones with the same prefix
|
// combine current path with previous ones with the same prefix
|
||||||
if (prefix->log_prob_nb_cur < log_p) {
|
if (prefix->log_prob_nb_cur < log_p) {
|
||||||
prefix->timesteps_cur = prefix->timesteps;
|
// keep current timesteps
|
||||||
|
prefix->previous_timesteps = nullptr;
|
||||||
}
|
}
|
||||||
prefix->log_prob_nb_cur = log_sum_exp(
|
prefix->log_prob_nb_cur = log_sum_exp(
|
||||||
prefix->log_prob_nb_cur, log_p);
|
prefix->log_prob_nb_cur, log_p);
|
||||||
@ -130,10 +132,6 @@ DecoderState::next(const double *probs,
|
|||||||
auto prefix_new = prefix->get_path_trie(c, log_prob_c);
|
auto prefix_new = prefix->get_path_trie(c, log_prob_c);
|
||||||
|
|
||||||
if (prefix_new != nullptr) {
|
if (prefix_new != nullptr) {
|
||||||
// compute timesteps of current path
|
|
||||||
std::vector<unsigned int> timesteps_new=prefix->timesteps;
|
|
||||||
timesteps_new.push_back(abs_time_step_);
|
|
||||||
|
|
||||||
// compute probability of current path
|
// compute probability of current path
|
||||||
float log_p = -NUM_FLT_INF;
|
float log_p = -NUM_FLT_INF;
|
||||||
|
|
||||||
@ -167,7 +165,10 @@ DecoderState::next(const double *probs,
|
|||||||
|
|
||||||
// combine current path with previous ones with the same prefix
|
// combine current path with previous ones with the same prefix
|
||||||
if (prefix_new->log_prob_nb_cur < log_p) {
|
if (prefix_new->log_prob_nb_cur < log_p) {
|
||||||
prefix_new->timesteps_cur = timesteps_new;
|
// record data needed to update timesteps
|
||||||
|
// the actual update will be done if nothing better is found
|
||||||
|
prefix_new->previous_timesteps = &prefix->timesteps;
|
||||||
|
prefix_new->new_timestep = abs_time_step_;
|
||||||
}
|
}
|
||||||
prefix_new->log_prob_nb_cur =
|
prefix_new->log_prob_nb_cur =
|
||||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||||
|
@ -157,6 +157,11 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
||||||
|
// previous_timesteps might point to ancestors' timesteps
|
||||||
|
// therefore, children must be uptaded first
|
||||||
|
for (auto child : children_) {
|
||||||
|
child.second->iterate_to_vec(output);
|
||||||
|
}
|
||||||
if (exists_) {
|
if (exists_) {
|
||||||
log_prob_b_prev = log_prob_b_cur;
|
log_prob_b_prev = log_prob_b_cur;
|
||||||
log_prob_nb_prev = log_prob_nb_cur;
|
log_prob_nb_prev = log_prob_nb_cur;
|
||||||
@ -166,14 +171,14 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
|||||||
|
|
||||||
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
||||||
|
|
||||||
timesteps = std::move(timesteps_cur);
|
if (previous_timesteps != nullptr) {
|
||||||
timesteps_cur.clear();
|
timesteps = *previous_timesteps;
|
||||||
|
timesteps.push_back(new_timestep);
|
||||||
|
}
|
||||||
|
previous_timesteps=nullptr;
|
||||||
|
|
||||||
output.push_back(this);
|
output.push_back(this);
|
||||||
}
|
}
|
||||||
for (auto child : children_) {
|
|
||||||
child.second->iterate_to_vec(output);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PathTrie::remove() {
|
void PathTrie::remove() {
|
||||||
|
@ -64,9 +64,11 @@ public:
|
|||||||
float approx_ctc;
|
float approx_ctc;
|
||||||
unsigned int character;
|
unsigned int character;
|
||||||
std::vector<unsigned int> timesteps;
|
std::vector<unsigned int> timesteps;
|
||||||
// `timesteps_cur` is a temporary storage for each decoding step.
|
|
||||||
// At the end of a decoding step, it is moved to `timesteps`.
|
// timestep temporary storage for each decoding step.
|
||||||
std::vector<unsigned int> timesteps_cur;
|
std::vector<unsigned int>* previous_timesteps=nullptr;
|
||||||
|
unsigned int new_timestep;
|
||||||
|
|
||||||
PathTrie* parent;
|
PathTrie* parent;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Loading…
Reference in New Issue
Block a user