PR #3279 - avoid unnecessary copies of timesteps vectors

This commit is contained in:
godeffroy 2020-08-31 19:01:47 +02:00
parent 363121235e
commit 1f89bef5f0
3 changed files with 23 additions and 15 deletions

View File

@ -106,7 +106,8 @@ DecoderState::next(const double *probs,
// combine current path with previous ones with the same prefix
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
if (prefix->log_prob_nb_cur < log_p) {
prefix->timesteps_cur = prefix->timesteps;
// keep current timesteps
prefix->previous_timesteps = nullptr;
}
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_p);
@ -120,7 +121,8 @@ DecoderState::next(const double *probs,
// combine current path with previous ones with the same prefix
if (prefix->log_prob_nb_cur < log_p) {
prefix->timesteps_cur = prefix->timesteps;
// keep current timesteps
prefix->previous_timesteps = nullptr;
}
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_p);
@ -130,10 +132,6 @@ DecoderState::next(const double *probs,
auto prefix_new = prefix->get_path_trie(c, log_prob_c);
if (prefix_new != nullptr) {
// compute timesteps of current path
std::vector<unsigned int> timesteps_new=prefix->timesteps;
timesteps_new.push_back(abs_time_step_);
// compute probability of current path
float log_p = -NUM_FLT_INF;
@ -167,7 +165,10 @@ DecoderState::next(const double *probs,
// combine current path with previous ones with the same prefix
if (prefix_new->log_prob_nb_cur < log_p) {
prefix_new->timesteps_cur = timesteps_new;
// record data needed to update timesteps
// the actual update will be done if nothing better is found
prefix_new->previous_timesteps = &prefix->timesteps;
prefix_new->new_timestep = abs_time_step_;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);

View File

@ -157,6 +157,11 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
}
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
// previous_timesteps might point to ancestors' timesteps
// therefore, children must be uptaded first
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
if (exists_) {
log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur;
@ -166,14 +171,14 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
timesteps = std::move(timesteps_cur);
timesteps_cur.clear();
if (previous_timesteps != nullptr) {
timesteps = *previous_timesteps;
timesteps.push_back(new_timestep);
}
previous_timesteps=nullptr;
output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
}
void PathTrie::remove() {

View File

@ -64,9 +64,11 @@ public:
float approx_ctc;
unsigned int character;
std::vector<unsigned int> timesteps;
// `timesteps_cur` is a temporary storage for each decoding step.
// At the end of a decoding step, it is moved to `timesteps`.
std::vector<unsigned int> timesteps_cur;
// timestep temporary storage for each decoding step.
std::vector<unsigned int>* previous_timesteps=nullptr;
unsigned int new_timestep;
PathTrie* parent;
private: