evaluate_tflite: Fix shared Queue
Also dump output to a file Fixed some trivial pylint issues at the same time Signed-off-by: Li Li <eggonlea@msn.com>
This commit is contained in:
parent
94df405ec4
commit
863c5544ca
@ -6,14 +6,13 @@ import argparse
|
||||
import numpy as np
|
||||
import wave
|
||||
import csv
|
||||
import sys
|
||||
import os
|
||||
|
||||
from six.moves import zip, range
|
||||
from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from deepspeech import Model
|
||||
|
||||
from util.evaluate_tools import process_decode_result, calculate_report
|
||||
from util.evaluate_tools import calculate_report
|
||||
|
||||
r'''
|
||||
This module should be self-contained:
|
||||
@ -41,15 +40,17 @@ def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out, gpu_mask):
|
||||
while True:
|
||||
msg = queue_in.get()
|
||||
|
||||
fin = wave.open(msg['filename'], 'rb')
|
||||
filename = msg['filename']
|
||||
wavname = os.path.splitext(os.path.basename(filename))[0]
|
||||
fin = wave.open(filename, 'rb')
|
||||
fs = fin.getframerate()
|
||||
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
||||
audio_length = fin.getnframes() * (1/16000)
|
||||
fin.close()
|
||||
|
||||
|
||||
decoded = ds.stt(audio, fs)
|
||||
|
||||
queue_out.put({'prediction': decoded, 'ground_truth': msg['transcript']})
|
||||
|
||||
queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']})
|
||||
print(queue_out.qsize(), end='\r') # Update the current progress
|
||||
queue_in.task_done()
|
||||
|
||||
def main():
|
||||
@ -66,10 +67,13 @@ def main():
|
||||
help='Path to the CSV source file')
|
||||
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
|
||||
help='Number of processes to spawn, defaulting to number of CPUs')
|
||||
parser.add_argument('--dump', required=False, action='store_true', default=False,
|
||||
help='Dump the results as text file, with one line for each wav: "wav transcription"')
|
||||
args = parser.parse_args()
|
||||
|
||||
manager = Manager()
|
||||
work_todo = JoinableQueue() # this is where we are going to store input data
|
||||
work_done = Queue() # this where we are gonna push them out
|
||||
work_done = manager.Queue() # this where we are gonna push them out
|
||||
|
||||
processes = []
|
||||
for i in range(args.proc):
|
||||
@ -79,27 +83,41 @@ def main():
|
||||
|
||||
print([x.name for x in processes])
|
||||
|
||||
wavlist = []
|
||||
ground_truths = []
|
||||
predictions = []
|
||||
losses = []
|
||||
|
||||
with open(args.csv, 'r') as csvfile:
|
||||
csvreader = csv.DictReader(csvfile)
|
||||
count = 0
|
||||
for row in csvreader:
|
||||
count += 1
|
||||
work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
|
||||
print('Totally %d wav entries found in csv\n' % count)
|
||||
work_todo.join()
|
||||
print('\nTotally %d wav file transcripted' % work_done.qsize())
|
||||
|
||||
while (not work_done.empty()):
|
||||
while not work_done.empty():
|
||||
msg = work_done.get()
|
||||
losses.append(0.0)
|
||||
ground_truths.append(msg['ground_truth'])
|
||||
predictions.append(msg['prediction'])
|
||||
wavlist.append(msg['wav'])
|
||||
|
||||
wer, cer, samples = calculate_report(ground_truths, predictions, losses)
|
||||
wer, cer, _ = calculate_report(ground_truths, predictions, losses)
|
||||
mean_loss = np.mean(losses)
|
||||
|
||||
print('Test - WER: %f, CER: %f, loss: %f' %
|
||||
(wer, cer, mean_loss))
|
||||
|
||||
if args.dump:
|
||||
with open(args.csv + '.txt', 'w') as ftxt, open(args.csv + '.out', 'w') as fout:
|
||||
for wav, txt, out in zip(wavlist, ground_truths, predictions):
|
||||
ftxt.write('%s %s\n' % (wav, txt))
|
||||
fout.write('%s %s\n' % (wav, out))
|
||||
print('Reference texts dumped to %s.txt' % args.csv)
|
||||
print('Transcription dumped to %s.out' % args.csv)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user