Merge pull request #2165 from mozilla/keep-stream-absolute-timestep
Keep absolute per-stream time step in DecoderState (Fixes #2163)
This commit is contained in:
		
						commit
						6f8c902f25
					
				| @ -25,6 +25,7 @@ DecoderState* decoder_init(const Alphabet &alphabet, | |||||||
| 
 | 
 | ||||||
|   // assign special ids
 |   // assign special ids
 | ||||||
|   DecoderState *state = new DecoderState; |   DecoderState *state = new DecoderState; | ||||||
|  |   state->time_step = 0; | ||||||
|   state->space_id = alphabet.GetSpaceLabel(); |   state->space_id = alphabet.GetSpaceLabel(); | ||||||
|   state->blank_id = alphabet.GetSize(); |   state->blank_id = alphabet.GetSize(); | ||||||
| 
 | 
 | ||||||
| @ -57,8 +58,8 @@ void decoder_next(const double *probs, | |||||||
|                   Scorer *ext_scorer) { |                   Scorer *ext_scorer) { | ||||||
| 
 | 
 | ||||||
|   // prefix search over time 
 |   // prefix search over time 
 | ||||||
|   for (size_t time_step = 0; time_step < time_dim; ++time_step) { |   for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) { | ||||||
|     auto *prob = &probs[time_step*class_dim]; |     auto *prob = &probs[rel_time_step*class_dim]; | ||||||
| 
 | 
 | ||||||
|     float min_cutoff = -NUM_FLT_INF; |     float min_cutoff = -NUM_FLT_INF; | ||||||
|     bool full_beam = false; |     bool full_beam = false; | ||||||
| @ -99,7 +100,7 @@ void decoder_next(const double *probs, | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // get new prefix
 |         // get new prefix
 | ||||||
|         auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c); |         auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c); | ||||||
| 
 | 
 | ||||||
|         if (prefix_new != nullptr) { |         if (prefix_new != nullptr) { | ||||||
|           float log_p = -NUM_FLT_INF; |           float log_p = -NUM_FLT_INF; | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ | |||||||
| /* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */ | /* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */ | ||||||
| 
 | 
 | ||||||
| struct DecoderState { | struct DecoderState { | ||||||
|  |   int time_step; | ||||||
|   int space_id; |   int space_id; | ||||||
|   int blank_id; |   int blank_id; | ||||||
|   std::vector<PathTrie*> prefixes; |   std::vector<PathTrie*> prefixes; | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user