Merge pull request #2268 from tilmankamp/reportfilename
Fix #2180 - Added wav_filename to WER report
This commit is contained in:
commit
daa6167829
@ -214,7 +214,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
|
||||
the decoded result and the batch's original Y.
|
||||
'''
|
||||
# Obtain the next batch of data
|
||||
(batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
_, (batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
|
||||
if FLAGS.use_cudnn_rnn:
|
||||
rnn_impl = rnn_impl_cudnn_rnn
|
||||
|
11
evaluate.py
11
evaluate.py
@ -52,7 +52,7 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
||||
|
||||
(batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
@ -89,6 +89,7 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
exit(1)
|
||||
|
||||
def run_test(init_op, dataset):
|
||||
wav_filenames = []
|
||||
losses = []
|
||||
predictions = []
|
||||
ground_truths = []
|
||||
@ -105,8 +106,8 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
# First pass, compute losses and transposed logits for decoding
|
||||
while True:
|
||||
try:
|
||||
batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
||||
session.run([transposed, loss, batch_x_len, batch_y])
|
||||
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
||||
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
@ -114,6 +115,7 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
num_processes=num_processes, scorer=scorer)
|
||||
predictions.extend(d[0][1] for d in decoded)
|
||||
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
||||
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
||||
losses.extend(batch_loss)
|
||||
|
||||
step_count += 1
|
||||
@ -121,7 +123,7 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
|
||||
bar.finish()
|
||||
|
||||
wer, cer, samples = calculate_report(ground_truths, predictions, losses)
|
||||
wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
|
||||
mean_loss = np.mean(losses)
|
||||
|
||||
# Take only the first report_count items
|
||||
@ -133,6 +135,7 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
for sample in report_samples:
|
||||
print('WER: %f, CER: %f, loss: %f' %
|
||||
(sample.wer, sample.cer, sample.loss))
|
||||
print(' - wav: file://%s' % sample.wav_filename)
|
||||
print(' - src: "%s"' % sample.src)
|
||||
print(' - res: "%s"' % sample.res)
|
||||
print('-' * 80)
|
||||
|
@ -34,12 +34,13 @@ def wer_cer_batch(samples):
|
||||
|
||||
|
||||
def process_decode_result(item):
|
||||
ground_truth, prediction, loss = item
|
||||
wav_filename, ground_truth, prediction, loss = item
|
||||
char_distance = levenshtein(ground_truth, prediction)
|
||||
char_length = len(ground_truth)
|
||||
word_distance = levenshtein(ground_truth.split(), prediction.split())
|
||||
word_length = len(ground_truth.split())
|
||||
return AttrDict({
|
||||
'wav_filename': wav_filename,
|
||||
'src': ground_truth,
|
||||
'res': prediction,
|
||||
'loss': loss,
|
||||
@ -52,13 +53,13 @@ def process_decode_result(item):
|
||||
})
|
||||
|
||||
|
||||
def calculate_report(labels, decodings, losses):
|
||||
def calculate_report(wav_filenames, labels, decodings, losses):
|
||||
r'''
|
||||
This routine will calculate a WER report.
|
||||
It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
|
||||
loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
|
||||
'''
|
||||
samples = pmap(process_decode_result, zip(labels, decodings, losses))
|
||||
samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))
|
||||
|
||||
# Getting the WER and CER from the accumulated edit distances and lengths
|
||||
samples_wer, samples_cer = wer_cer_batch(samples)
|
||||
|
@ -52,7 +52,7 @@ def audiofile_to_features(wav_filename):
|
||||
def entry_to_features(wav_filename, transcript):
|
||||
# https://bugs.python.org/issue32117
|
||||
features, features_len = audiofile_to_features(wav_filename)
|
||||
return features, features_len, tf.SparseTensor(*transcript)
|
||||
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
|
||||
|
||||
|
||||
def to_sparse_tuple(sequence):
|
||||
@ -82,12 +82,13 @@ def create_dataset(csvs, batch_size, cache_path=''):
|
||||
shape = sparse.dense_shape
|
||||
return tf.sparse.reshape(sparse, [shape[0], shape[2]])
|
||||
|
||||
def batch_fn(features, features_len, transcripts):
|
||||
def batch_fn(wav_filenames, features, features_len, transcripts):
|
||||
features = tf.data.Dataset.zip((features, features_len))
|
||||
features = features.padded_batch(batch_size,
|
||||
padded_shapes=([None, Config.n_input], []))
|
||||
transcripts = transcripts.batch(batch_size).map(sparse_reshape)
|
||||
return tf.data.Dataset.zip((features, transcripts))
|
||||
wav_filenames = wav_filenames.batch(batch_size)
|
||||
return tf.data.Dataset.zip((wav_filenames, features, transcripts))
|
||||
|
||||
num_gpus = len(Config.available_devices)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user