Update evaluate_tflite with wav_filename
This commit is contained in:
parent
f3694efbca
commit
f3240bffbc
@ -36,18 +36,21 @@ def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out, gpu_mask):
|
|||||||
ds.enableDecoderWithLM(lm, trie, LM_ALPHA, LM_BETA)
|
ds.enableDecoderWithLM(lm, trie, LM_ALPHA, LM_BETA)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
msg = queue_in.get()
|
try:
|
||||||
|
msg = queue_in.get()
|
||||||
|
|
||||||
filename = msg['filename']
|
filename = msg['filename']
|
||||||
wavname = os.path.splitext(os.path.basename(filename))[0]
|
wavname = os.path.splitext(os.path.basename(filename))[0]
|
||||||
fin = wave.open(filename, 'rb')
|
fin = wave.open(filename, 'rb')
|
||||||
fs = fin.getframerate()
|
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
||||||
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
fin.close()
|
||||||
fin.close()
|
|
||||||
|
|
||||||
decoded = ds.stt(audio)
|
decoded = ds.stt(audio)
|
||||||
|
|
||||||
|
queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']})
|
||||||
|
except FileNotFoundError as ex:
|
||||||
|
print('FileNotFoundError: ', ex)
|
||||||
|
|
||||||
queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']})
|
|
||||||
print(queue_out.qsize(), end='\r') # Update the current progress
|
print(queue_out.qsize(), end='\r') # Update the current progress
|
||||||
queue_in.task_done()
|
queue_in.task_done()
|
||||||
|
|
||||||
@ -85,6 +88,7 @@ def main():
|
|||||||
ground_truths = []
|
ground_truths = []
|
||||||
predictions = []
|
predictions = []
|
||||||
losses = []
|
losses = []
|
||||||
|
wav_filenames = []
|
||||||
|
|
||||||
with open(args.csv, 'r') as csvfile:
|
with open(args.csv, 'r') as csvfile:
|
||||||
csvreader = csv.DictReader(csvfile)
|
csvreader = csv.DictReader(csvfile)
|
||||||
@ -92,6 +96,8 @@ def main():
|
|||||||
for row in csvreader:
|
for row in csvreader:
|
||||||
count += 1
|
count += 1
|
||||||
work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
|
work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
|
||||||
|
wav_filenames.extend(row['wav_filename'])
|
||||||
|
|
||||||
print('Totally %d wav entries found in csv\n' % count)
|
print('Totally %d wav entries found in csv\n' % count)
|
||||||
work_todo.join()
|
work_todo.join()
|
||||||
print('\nTotally %d wav file transcripted' % work_done.qsize())
|
print('\nTotally %d wav file transcripted' % work_done.qsize())
|
||||||
@ -103,7 +109,7 @@ def main():
|
|||||||
predictions.append(msg['prediction'])
|
predictions.append(msg['prediction'])
|
||||||
wavlist.append(msg['wav'])
|
wavlist.append(msg['wav'])
|
||||||
|
|
||||||
wer, cer, _ = calculate_report(ground_truths, predictions, losses)
|
wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
|
||||||
mean_loss = np.mean(losses)
|
mean_loss = np.mean(losses)
|
||||||
|
|
||||||
print('Test - WER: %f, CER: %f, loss: %f' %
|
print('Test - WER: %f, CER: %f, loss: %f' %
|
||||||
|
Loading…
x
Reference in New Issue
Block a user