Apply a sequence mask for the gradient in ctc_loss_dense.
PiperOrigin-RevId: 295767405
Change-Id: I80a53508288cdc505f876901fde5fa46a7645bca
parent 884a14ac9a
commit b04371bc95
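For context, the change is built on sequence masking: per-example lengths are expanded into a per-time-step 0/1 mask. A tiny illustration using the public tf.sequence_mask (the commit itself calls the internal array_ops.sequence_mask; the lengths here are made up for the example):

    import tensorflow as tf

    # Per-example lengths -> boolean mask over time steps, shape [batch, max_time].
    mask = tf.sequence_mask([2, 4], maxlen=5)
    # [[ True,  True, False, False, False],
    #  [ True,  True,  True,  True, False]]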
					
@@ -367,7 +367,8 @@ class CTCLossTestV2(test.TestCase):
       batch_size = 8
       num_labels = 6
       label_length = 5
-      num_frames = 12
+      minimum_logits_length = 10
+      num_frames = minimum_logits_length + batch_size
       logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
       labels = random_ops.random_uniform(
           [batch_size, label_length], minval=1, maxval=num_labels,
@@ -379,7 +380,7 @@ class CTCLossTestV2(test.TestCase):
           label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
       labels *= label_mask

-      logit_lengths = [num_frames] * batch_size
+      logit_lengths = math_ops.range(batch_size) + minimum_logits_length

       ctc_loss = ctc_ops.ctc_loss_dense(
           labels=labels,
@@ -410,8 +411,8 @@ class CTCLossTestV2(test.TestCase):
           self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
           self.assertAllClose(
               *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
-              rtol=2e-06,
-              atol=2e-06)
+              rtol=4e-06,
+              atol=4e-06)

   @test_util.run_v1_only("b/120545219")
   def testCtcLossDenseUniqueFastPathIsSameAsCtcLoss(self):
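A rough way to see what the updated test now exercises: with per-example logit lengths, the gradient of the dense CTC loss with respect to frames at or beyond each example's length should be zero. Below is a sketch using the public tf.nn.ctc_loss API in TF 2.x eager mode, not the test itself; the shapes mirror the test (with smaller numbers for brevity), and whether the tail gradients are exactly zero depends on the TensorFlow version, since this commit is what enforces it.

    import tensorflow as tf

    batch_size, num_labels, label_length = 4, 6, 5
    minimum_logits_length = 10
    num_frames = minimum_logits_length + batch_size

    logits = tf.Variable(tf.random.uniform([num_frames, batch_size, num_labels]))
    labels = tf.random.uniform([batch_size, label_length],
                               minval=1, maxval=num_labels, dtype=tf.int32)
    label_lengths = tf.fill([batch_size], label_length)
    logit_lengths = tf.range(batch_size) + minimum_logits_length  # [10, 11, 12, 13]

    with tf.GradientTape() as tape:
      loss = tf.nn.ctc_loss(labels, logits, label_lengths, logit_lengths,
                            logits_time_major=True, blank_index=0)
    grad = tape.gradient(loss, logits)  # shape [num_frames, batch_size, num_labels]

    # With the sequence mask applied, frames past each example's logit length
    # contribute nothing to the gradient.
    for b in range(batch_size):
      tail = grad[int(logit_lengths[b]):, b, :]
      print(b, float(tf.reduce_max(tf.abs(tail))))  # expected: 0.0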

@@ -658,6 +658,17 @@ def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
     olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)

   grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)

+  # Apply the sequence mask to the gradient. Masking ilabel_log_probs alone
+  # would be enough, since olabel_log_probs already takes the mask into
+  # account, but masking the gradient directly is safer and cleaner.
+  max_logit_length = _get_dim(logits, 0)
+  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
+                                       dtypes.float32)
+  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
+  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
+  grad *= logit_mask
+
   loss = -log_likelihood
   return loss, grad
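For readers without the TensorFlow internals at hand, the masking added above can be sketched with public ops. The shapes assume the time-major layout [max_time, batch, num_labels] that ctc_loss_and_grad works with; the concrete numbers are illustrative only.

    import tensorflow as tf

    logit_length = tf.constant([10, 11, 12, 13])   # per-example lengths, [batch]
    max_logit_length = 14                          # time dimension of the logits

    logit_mask = tf.sequence_mask(logit_length, max_logit_length, tf.float32)  # [batch, time]
    logit_mask = tf.transpose(logit_mask, perm=[1, 0])                         # [time, batch]
    logit_mask = tf.expand_dims(logit_mask, axis=2)                            # [time, batch, 1]

    grad = tf.random.uniform([max_logit_length, 4, 6])  # stand-in for the CTC gradient
    masked_grad = grad * logit_mask                      # broadcasts over the label axis

The expand_dims step is what lets the [time, batch] mask broadcast across the label axis, so each frame beyond an example's length is zeroed as a whole rather than per label.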