PR #3279 - avoid unnecessary copies of timesteps vectors
This commit is contained in:
parent
363121235e
commit
1f89bef5f0
@ -106,7 +106,8 @@ DecoderState::next(const double *probs,
|
||||
// combine current path with previous ones with the same prefix
|
||||
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
|
||||
if (prefix->log_prob_nb_cur < log_p) {
|
||||
prefix->timesteps_cur = prefix->timesteps;
|
||||
// keep current timesteps
|
||||
prefix->previous_timesteps = nullptr;
|
||||
}
|
||||
prefix->log_prob_b_cur =
|
||||
log_sum_exp(prefix->log_prob_b_cur, log_p);
|
||||
@ -120,7 +121,8 @@ DecoderState::next(const double *probs,
|
||||
|
||||
// combine current path with previous ones with the same prefix
|
||||
if (prefix->log_prob_nb_cur < log_p) {
|
||||
prefix->timesteps_cur = prefix->timesteps;
|
||||
// keep current timesteps
|
||||
prefix->previous_timesteps = nullptr;
|
||||
}
|
||||
prefix->log_prob_nb_cur = log_sum_exp(
|
||||
prefix->log_prob_nb_cur, log_p);
|
||||
@ -130,10 +132,6 @@ DecoderState::next(const double *probs,
|
||||
auto prefix_new = prefix->get_path_trie(c, log_prob_c);
|
||||
|
||||
if (prefix_new != nullptr) {
|
||||
// compute timesteps of current path
|
||||
std::vector<unsigned int> timesteps_new=prefix->timesteps;
|
||||
timesteps_new.push_back(abs_time_step_);
|
||||
|
||||
// compute probability of current path
|
||||
float log_p = -NUM_FLT_INF;
|
||||
|
||||
@ -167,7 +165,10 @@ DecoderState::next(const double *probs,
|
||||
|
||||
// combine current path with previous ones with the same prefix
|
||||
if (prefix_new->log_prob_nb_cur < log_p) {
|
||||
prefix_new->timesteps_cur = timesteps_new;
|
||||
// record data needed to update timesteps
|
||||
// the actual update will be done if nothing better is found
|
||||
prefix_new->previous_timesteps = &prefix->timesteps;
|
||||
prefix_new->new_timestep = abs_time_step_;
|
||||
}
|
||||
prefix_new->log_prob_nb_cur =
|
||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||
|
@ -157,6 +157,11 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
||||
}
|
||||
|
||||
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
||||
// previous_timesteps might point to ancestors' timesteps
|
||||
// therefore, children must be uptaded first
|
||||
for (auto child : children_) {
|
||||
child.second->iterate_to_vec(output);
|
||||
}
|
||||
if (exists_) {
|
||||
log_prob_b_prev = log_prob_b_cur;
|
||||
log_prob_nb_prev = log_prob_nb_cur;
|
||||
@ -166,14 +171,14 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
||||
|
||||
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
||||
|
||||
timesteps = std::move(timesteps_cur);
|
||||
timesteps_cur.clear();
|
||||
if (previous_timesteps != nullptr) {
|
||||
timesteps = *previous_timesteps;
|
||||
timesteps.push_back(new_timestep);
|
||||
}
|
||||
previous_timesteps=nullptr;
|
||||
|
||||
output.push_back(this);
|
||||
}
|
||||
for (auto child : children_) {
|
||||
child.second->iterate_to_vec(output);
|
||||
}
|
||||
}
|
||||
|
||||
void PathTrie::remove() {
|
||||
|
@ -64,9 +64,11 @@ public:
|
||||
float approx_ctc;
|
||||
unsigned int character;
|
||||
std::vector<unsigned int> timesteps;
|
||||
// `timesteps_cur` is a temporary storage for each decoding step.
|
||||
// At the end of a decoding step, it is moved to `timesteps`.
|
||||
std::vector<unsigned int> timesteps_cur;
|
||||
|
||||
// timestep temporary storage for each decoding step.
|
||||
std::vector<unsigned int>* previous_timesteps=nullptr;
|
||||
unsigned int new_timestep;
|
||||
|
||||
PathTrie* parent;
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user