Fix #2180 - Added wav_filename to WER report
This commit is contained in:
		
							parent
							
								
									84e1fa98b9
								
							
						
					
					
						commit
						007e512c00
					
				@ -214,7 +214,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
 | 
				
			|||||||
    the decoded result and the batch's original Y.
 | 
					    the decoded result and the batch's original Y.
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
    # Obtain the next batch of data
 | 
					    # Obtain the next batch of data
 | 
				
			||||||
    (batch_x, batch_seq_len), batch_y = iterator.get_next()
 | 
					    _, (batch_x, batch_seq_len), batch_y = iterator.get_next()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if FLAGS.use_cudnn_rnn:
 | 
					    if FLAGS.use_cudnn_rnn:
 | 
				
			||||||
        rnn_impl = rnn_impl_cudnn_rnn
 | 
					        rnn_impl = rnn_impl_cudnn_rnn
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										11
									
								
								evaluate.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								evaluate.py
									
									
									
									
									
								
							@ -52,7 +52,7 @@ def evaluate(test_csvs, create_model, try_loading):
 | 
				
			|||||||
                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))
 | 
					                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))
 | 
				
			||||||
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
 | 
					    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    (batch_x, batch_x_len), batch_y = iterator.get_next()
 | 
					    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # One rate per layer
 | 
					    # One rate per layer
 | 
				
			||||||
    no_dropout = [None] * 6
 | 
					    no_dropout = [None] * 6
 | 
				
			||||||
@ -89,6 +89,7 @@ def evaluate(test_csvs, create_model, try_loading):
 | 
				
			|||||||
            exit(1)
 | 
					            exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def run_test(init_op, dataset):
 | 
					        def run_test(init_op, dataset):
 | 
				
			||||||
 | 
					            wav_filenames = []
 | 
				
			||||||
            losses = []
 | 
					            losses = []
 | 
				
			||||||
            predictions = []
 | 
					            predictions = []
 | 
				
			||||||
            ground_truths = []
 | 
					            ground_truths = []
 | 
				
			||||||
@ -105,8 +106,8 @@ def evaluate(test_csvs, create_model, try_loading):
 | 
				
			|||||||
            # First pass, compute losses and transposed logits for decoding
 | 
					            # First pass, compute losses and transposed logits for decoding
 | 
				
			||||||
            while True:
 | 
					            while True:
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    batch_logits, batch_loss, batch_lengths, batch_transcripts = \
 | 
					                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
 | 
				
			||||||
                        session.run([transposed, loss, batch_x_len, batch_y])
 | 
					                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
 | 
				
			||||||
                except tf.errors.OutOfRangeError:
 | 
					                except tf.errors.OutOfRangeError:
 | 
				
			||||||
                    break
 | 
					                    break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -114,6 +115,7 @@ def evaluate(test_csvs, create_model, try_loading):
 | 
				
			|||||||
                                                        num_processes=num_processes, scorer=scorer)
 | 
					                                                        num_processes=num_processes, scorer=scorer)
 | 
				
			||||||
                predictions.extend(d[0][1] for d in decoded)
 | 
					                predictions.extend(d[0][1] for d in decoded)
 | 
				
			||||||
                ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
 | 
					                ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
 | 
				
			||||||
 | 
					                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
 | 
				
			||||||
                losses.extend(batch_loss)
 | 
					                losses.extend(batch_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                step_count += 1
 | 
					                step_count += 1
 | 
				
			||||||
@ -121,7 +123,7 @@ def evaluate(test_csvs, create_model, try_loading):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            bar.finish()
 | 
					            bar.finish()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            wer, cer, samples = calculate_report(ground_truths, predictions, losses)
 | 
					            wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
 | 
				
			||||||
            mean_loss = np.mean(losses)
 | 
					            mean_loss = np.mean(losses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Take only the first report_count items
 | 
					            # Take only the first report_count items
 | 
				
			||||||
@ -133,6 +135,7 @@ def evaluate(test_csvs, create_model, try_loading):
 | 
				
			|||||||
            for sample in report_samples:
 | 
					            for sample in report_samples:
 | 
				
			||||||
                print('WER: %f, CER: %f, loss: %f' %
 | 
					                print('WER: %f, CER: %f, loss: %f' %
 | 
				
			||||||
                      (sample.wer, sample.cer, sample.loss))
 | 
					                      (sample.wer, sample.cer, sample.loss))
 | 
				
			||||||
 | 
					                print(' - wav: file://%s' % sample.wav_filename)
 | 
				
			||||||
                print(' - src: "%s"' % sample.src)
 | 
					                print(' - src: "%s"' % sample.src)
 | 
				
			||||||
                print(' - res: "%s"' % sample.res)
 | 
					                print(' - res: "%s"' % sample.res)
 | 
				
			||||||
                print('-' * 80)
 | 
					                print('-' * 80)
 | 
				
			||||||
 | 
				
			|||||||
@ -34,12 +34,13 @@ def wer_cer_batch(samples):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def process_decode_result(item):
 | 
					def process_decode_result(item):
 | 
				
			||||||
    ground_truth, prediction, loss = item
 | 
					    wav_filename, ground_truth, prediction, loss = item
 | 
				
			||||||
    char_distance = levenshtein(ground_truth, prediction)
 | 
					    char_distance = levenshtein(ground_truth, prediction)
 | 
				
			||||||
    char_length = len(ground_truth)
 | 
					    char_length = len(ground_truth)
 | 
				
			||||||
    word_distance = levenshtein(ground_truth.split(), prediction.split())
 | 
					    word_distance = levenshtein(ground_truth.split(), prediction.split())
 | 
				
			||||||
    word_length = len(ground_truth.split())
 | 
					    word_length = len(ground_truth.split())
 | 
				
			||||||
    return AttrDict({
 | 
					    return AttrDict({
 | 
				
			||||||
 | 
					        'wav_filename': wav_filename,
 | 
				
			||||||
        'src': ground_truth,
 | 
					        'src': ground_truth,
 | 
				
			||||||
        'res': prediction,
 | 
					        'res': prediction,
 | 
				
			||||||
        'loss': loss,
 | 
					        'loss': loss,
 | 
				
			||||||
@ -52,13 +53,13 @@ def process_decode_result(item):
 | 
				
			|||||||
    })
 | 
					    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def calculate_report(labels, decodings, losses):
 | 
					def calculate_report(wav_filenames, labels, decodings, losses):
 | 
				
			||||||
    r'''
 | 
					    r'''
 | 
				
			||||||
    This routine will calculate a WER report.
 | 
					    This routine will calculate a WER report.
 | 
				
			||||||
    It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
 | 
					    It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
 | 
				
			||||||
    loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
 | 
					    loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
    samples = pmap(process_decode_result, zip(labels, decodings, losses))
 | 
					    samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Getting the WER and CER from the accumulated edit distances and lengths
 | 
					    # Getting the WER and CER from the accumulated edit distances and lengths
 | 
				
			||||||
    samples_wer, samples_cer = wer_cer_batch(samples)
 | 
					    samples_wer, samples_cer = wer_cer_batch(samples)
 | 
				
			||||||
 | 
				
			|||||||
@ -52,7 +52,7 @@ def audiofile_to_features(wav_filename):
 | 
				
			|||||||
def entry_to_features(wav_filename, transcript):
 | 
					def entry_to_features(wav_filename, transcript):
 | 
				
			||||||
    # https://bugs.python.org/issue32117
 | 
					    # https://bugs.python.org/issue32117
 | 
				
			||||||
    features, features_len = audiofile_to_features(wav_filename)
 | 
					    features, features_len = audiofile_to_features(wav_filename)
 | 
				
			||||||
    return features, features_len, tf.SparseTensor(*transcript)
 | 
					    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def to_sparse_tuple(sequence):
 | 
					def to_sparse_tuple(sequence):
 | 
				
			||||||
@ -82,12 +82,13 @@ def create_dataset(csvs, batch_size, cache_path=''):
 | 
				
			|||||||
        shape = sparse.dense_shape
 | 
					        shape = sparse.dense_shape
 | 
				
			||||||
        return tf.sparse.reshape(sparse, [shape[0], shape[2]])
 | 
					        return tf.sparse.reshape(sparse, [shape[0], shape[2]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def batch_fn(features, features_len, transcripts):
 | 
					    def batch_fn(wav_filenames, features, features_len, transcripts):
 | 
				
			||||||
        features = tf.data.Dataset.zip((features, features_len))
 | 
					        features = tf.data.Dataset.zip((features, features_len))
 | 
				
			||||||
        features = features.padded_batch(batch_size,
 | 
					        features = features.padded_batch(batch_size,
 | 
				
			||||||
                                         padded_shapes=([None, Config.n_input], []))
 | 
					                                         padded_shapes=([None, Config.n_input], []))
 | 
				
			||||||
        transcripts = transcripts.batch(batch_size).map(sparse_reshape)
 | 
					        transcripts = transcripts.batch(batch_size).map(sparse_reshape)
 | 
				
			||||||
        return tf.data.Dataset.zip((features, transcripts))
 | 
					        wav_filenames = wav_filenames.batch(batch_size)
 | 
				
			||||||
 | 
					        return tf.data.Dataset.zip((wav_filenames, features, transcripts))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    num_gpus = len(Config.available_devices)
 | 
					    num_gpus = len(Config.available_devices)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user