diff --git a/DeepSpeech.py b/DeepSpeech.py
index 19e16d3b..84a784d5 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -214,7 +214,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
         the decoded result and the batch's original Y.
     '''
     # Obtain the next batch of data
-    (batch_x, batch_seq_len), batch_y = iterator.get_next()
+    _, (batch_x, batch_seq_len), batch_y = iterator.get_next()
 
     if FLAGS.use_cudnn_rnn:
         rnn_impl = rnn_impl_cudnn_rnn
diff --git a/evaluate.py b/evaluate.py
index a8de7dc7..95289e61 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -52,7 +52,7 @@ def evaluate(test_csvs, create_model, try_loading):
                                                output_classes=tfv1.data.get_output_classes(test_sets[0]))
     test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
 
-    (batch_x, batch_x_len), batch_y = iterator.get_next()
+    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
 
     # One rate per layer
     no_dropout = [None] * 6
@@ -89,6 +89,7 @@ def evaluate(test_csvs, create_model, try_loading):
             exit(1)
 
         def run_test(init_op, dataset):
+            wav_filenames = []
             losses = []
             predictions = []
             ground_truths = []
@@ -105,8 +106,8 @@
             # First pass, compute losses and transposed logits for decoding
             while True:
                 try:
-                    batch_logits, batch_loss, batch_lengths, batch_transcripts = \
-                        session.run([transposed, loss, batch_x_len, batch_y])
+                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
+                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                 except tf.errors.OutOfRangeError:
                     break
 
@@ -114,6 +115,7 @@
                                                           num_processes=num_processes, scorer=scorer)
                 predictions.extend(d[0][1] for d in decoded)
                 ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
+                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                 losses.extend(batch_loss)
 
                 step_count += 1
@@ -121,7 +123,7 @@
 
             bar.finish()
 
-            wer, cer, samples = calculate_report(ground_truths, predictions, losses)
+            wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
             mean_loss = np.mean(losses)
 
             # Take only the first report_count items
@@ -133,6 +135,7 @@
             for sample in report_samples:
                 print('WER: %f, CER: %f, loss: %f' %
                       (sample.wer, sample.cer, sample.loss))
+                print(' - wav: file://%s' % sample.wav_filename)
                 print(' - src: "%s"' % sample.src)
                 print(' - res: "%s"' % sample.res)
                 print('-' * 80)
diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index 1ad91f46..ed69578a 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -34,12 +34,13 @@ def wer_cer_batch(samples):
 
 
 def process_decode_result(item):
-    ground_truth, prediction, loss = item
+    wav_filename, ground_truth, prediction, loss = item
     char_distance = levenshtein(ground_truth, prediction)
     char_length = len(ground_truth)
     word_distance = levenshtein(ground_truth.split(), prediction.split())
     word_length = len(ground_truth.split())
     return AttrDict({
+        'wav_filename': wav_filename,
         'src': ground_truth,
         'res': prediction,
         'loss': loss,
@@ -52,13 +53,13 @@ def process_decode_result(item):
     })
 
 
-def calculate_report(labels, decodings, losses):
+def calculate_report(wav_filenames, labels, decodings, losses):
     r'''
     This routine will calculate a WER report.
     It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
     loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
     '''
-    samples = pmap(process_decode_result, zip(labels, decodings, losses))
+    samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))
 
     # Getting the WER and CER from the accumulated edit distances and lengths
     samples_wer, samples_cer = wer_cer_batch(samples)
diff --git a/util/feeding.py b/util/feeding.py
index a0d045de..67c6cd93 100644
--- a/util/feeding.py
+++ b/util/feeding.py
@@ -52,7 +52,7 @@ def audiofile_to_features(wav_filename):
 def entry_to_features(wav_filename, transcript):
     # https://bugs.python.org/issue32117
     features, features_len = audiofile_to_features(wav_filename)
-    return features, features_len, tf.SparseTensor(*transcript)
+    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
 
 
 def to_sparse_tuple(sequence):
@@ -82,12 +82,13 @@ def create_dataset(csvs, batch_size, cache_path=''):
         shape = sparse.dense_shape
         return tf.sparse.reshape(sparse, [shape[0], shape[2]])
 
-    def batch_fn(features, features_len, transcripts):
+    def batch_fn(wav_filenames, features, features_len, transcripts):
         features = tf.data.Dataset.zip((features, features_len))
         features = features.padded_batch(batch_size,
                                          padded_shapes=([None, Config.n_input], []))
         transcripts = transcripts.batch(batch_size).map(sparse_reshape)
-        return tf.data.Dataset.zip((features, transcripts))
+        wav_filenames = wav_filenames.batch(batch_size)
+        return tf.data.Dataset.zip((wav_filenames, features, transcripts))
 
     num_gpus = len(Config.available_devices)
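
Below is a minimal, self-contained sketch (not part of the patch) of the window()/flat_map() batching pattern that the feeding.py change extends. It shows why batch_fn receives one tf.data.Dataset per tuple component, and how the filename component can be batched in lockstep with the padded features. It is written against TF 2 eager execution for brevity (the patch itself targets the TF 1 session API), and the generator, shapes, and names such as generate_values and N_INPUT are illustrative assumptions, not code from this repository.

    import tensorflow as tf

    BATCH_SIZE = 2
    N_INPUT = 3  # stand-in for Config.n_input

    def generate_values():
        # Fake (wav_filename, features, features_len) entries of varying length.
        for i in range(4):
            length = i + 1
            features = [[float(i)] * N_INPUT for _ in range(length)]
            yield 'wav_%d.wav' % i, features, length

    dataset = tf.data.Dataset.from_generator(
        generate_values,
        output_types=(tf.string, tf.float32, tf.int32),
        output_shapes=([], [None, N_INPUT], []))

    def batch_fn(wav_filenames, features, features_len):
        # After .window(), each argument is a Dataset covering one window of
        # elements, so every component can be batched independently and then
        # zipped back together without losing alignment.
        padded = tf.data.Dataset.zip((features, features_len)).padded_batch(
            BATCH_SIZE, padded_shapes=([None, N_INPUT], []))
        wav_filenames = wav_filenames.batch(BATCH_SIZE)
        return tf.data.Dataset.zip((wav_filenames, padded))

    dataset = dataset.window(BATCH_SIZE, drop_remainder=True).flat_map(batch_fn)

    # Each element now mirrors the patched pipeline's structure:
    # (wav_filenames, (features, features_len))
    for wav_names, (feats, feat_lens) in dataset:
        print(wav_names.numpy(), feats.shape, feat_lens.numpy())

Zipping the filename batch in as the first tuple component is what lets consumers that do not need it (DeepSpeech.py above) discard it with a single underscore, while evaluate.py carries it through to calculate_report and the per-sample output.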