158 lines
4.9 KiB
Python
158 lines
4.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import argparse
|
|
import csv
|
|
import os
|
|
import sys
|
|
import wave
|
|
from functools import partial
|
|
from multiprocessing import JoinableQueue, Manager, Process, cpu_count
|
|
|
|
import numpy as np
|
|
from coqui_stt_training.util.evaluate_tools import calculate_and_print_report
|
|
from coqui_stt_training.util.config import Config, initialize_globals_from_args
|
|
from six.moves import range, zip
|
|
from stt import Model
|
|
|
|
r"""
|
|
This module should be self-contained:
|
|
- build libstt.so with TFLite:
|
|
- bazel build [...] --define=runtime=tflite [...] //native_client:libstt.so
|
|
- make -C native_client/python/ TFDIR=... bindings
|
|
- setup a virtualenv
|
|
- pip install native_client/python/dist/*.whl
|
|
- pip install -r requirements_eval_tflite.txt
|
|
|
|
Then run with a TFLite model, a scorer and a CSV test file
|
|
"""
|
|
|
|
|
|
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
    """Worker loop: transcribe WAV files taken from a queue, forever.

    Each message read from ``queue_in`` is a dict with the keys ``"filename"``
    and ``"transcript"``. For every file that can be opened, a dict with the
    keys ``"wav"``, ``"prediction"`` and ``"ground_truth"`` is put on
    ``queue_out``. The function never returns; it is meant to run in a
    process that the parent terminates.

    Args:
        model: Value passed to the ``Model`` constructor.
        scorer: Value passed to ``Model.enableExternalScorer``.
        queue_in: Joinable queue of work messages; ``task_done()`` is called
            exactly once for every message taken from it.
        queue_out: Queue receiving one result dict per transcribed file.
        gpu_mask: Value written (as a string) to ``CUDA_VISIBLE_DEVICES``
            before the model is created.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_mask)
    ds = Model(model)
    ds.enableExternalScorer(scorer)

    while True:
        # Fetch outside the try block so that task_done() below is only ever
        # called for a message that was actually received.
        msg = queue_in.get()
        try:
            filename = msg["filename"]
            # The context manager closes the file even if reading fails.
            with wave.open(filename, "rb") as fin:
                audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

            decoded = ds.stt(audio)

            queue_out.put(
                {
                    "wav": filename,
                    "prediction": decoded,
                    "ground_truth": msg["transcript"],
                }
            )
        except FileNotFoundError as ex:
            # Missing files are reported and skipped; no result is produced.
            print("FileNotFoundError: ", ex)
        finally:
            print(queue_out.qsize(), end="\r")  # Update the current progress
            # Always acknowledge the message: if an unexpected error kills
            # this worker, the parent's queue.join() must not wait forever
            # for an item nobody will ever mark as done.
            queue_in.task_done()
|
|
|
|
|
|
def main(args):
    """Transcribe every WAV file listed in a CSV and print a test report.

    Starts ``args.proc`` worker processes running ``tflite_worker``, queues
    one job per CSV row, waits until all jobs are acknowledged, then passes
    the collected results to ``calculate_and_print_report``. When
    ``args.dump`` is set, the reference texts are written to
    ``<dump>.txt`` and the predictions to ``<dump>.out``.

    Args:
        args: Parsed command-line namespace; the attributes ``model``,
            ``scorer``, ``csv``, ``proc`` and ``dump`` are read.
    """
    initialize_globals_from_args()
    manager = Manager()
    work_todo = JoinableQueue()  # this is where we are going to store input data
    work_done = manager.Queue()  # this is where the workers push their results

    processes = []
    for i in range(args.proc):
        worker_process = Process(
            target=tflite_worker,
            args=(args.model, args.scorer, work_todo, work_done, i),
            daemon=True,
            name="tflite_process_{}".format(i),
        )
        worker_process.start()  # Launch tflite_worker() as a separate python process
        processes.append(worker_process)

    print([x.name for x in processes])

    # Relative paths are relative to the folder the CSV file is in
    csv_dir = os.path.dirname(args.csv)
    count = 0
    with open(args.csv, "r") as csvfile:
        for count, row in enumerate(csv.DictReader(csvfile), start=1):
            wav_path = row["wav_filename"]
            if not os.path.isabs(wav_path):
                wav_path = os.path.join(csv_dir, wav_path)
            work_todo.put({"filename": wav_path, "transcript": row["transcript"]})

    print("Totally %d wav entries found in csv\n" % count)
    work_todo.join()
    print("\nTotally %d wav file transcripted" % work_done.qsize())

    # Build every per-sample list from the result messages. Results arrive in
    # completion order and files that could not be opened produce no result,
    # so a filename list built while reading the CSV would not line up with
    # the predictions.
    wav_filenames = []
    ground_truths = []
    predictions = []
    losses = []
    while not work_done.empty():
        msg = work_done.get()
        losses.append(0.0)  # placeholder loss value for every sample
        ground_truths.append(msg["ground_truth"])
        predictions.append(msg["prediction"])
        wav_filenames.append(msg["wav"])

    # Print test summary
    _ = calculate_and_print_report(
        wav_filenames, ground_truths, predictions, losses, args.csv
    )

    if args.dump:
        with open(args.dump + ".txt", "w") as ftxt, open(
            args.dump + ".out", "w"
        ) as fout:
            for wav, txt, out in zip(wav_filenames, ground_truths, predictions):
                ftxt.write("%s %s\n" % (wav, txt))
                fout.write("%s %s\n" % (wav, out))
            print("Reference texts dumped to %s.txt" % args.dump)
            print("Transcription dumped to %s.out" % args.dump)
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options understood by this script.

    Options this parser does not recognise are not an error: they are left
    in ``sys.argv`` (after the program name) for later flag parsing.

    Returns:
        The ``argparse.Namespace`` holding ``model``, ``scorer``, ``csv``,
        ``proc`` and ``dump``.
    """
    option_table = (
        (
            "--model",
            {
                "required": True,
                "help": "Path to the model (protocol buffer binary file)",
            },
        ),
        (
            "--scorer",
            {"required": True, "help": "Path to the external scorer file"},
        ),
        (
            "--csv",
            {"required": True, "help": "Path to the CSV source file"},
        ),
        (
            "--proc",
            {
                "required": False,
                "default": cpu_count(),
                "type": int,
                "help": "Number of processes to spawn, defaulting to number of CPUs",
            },
        ),
        (
            "--dump",
            {
                "required": False,
                "help": 'Path to dump the results as text file, with one line for each wav: "wav transcription".',
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Computing TFLite accuracy")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    known, leftover = cli.parse_known_args()
    # Reconstruct argv for absl.flags
    program_name = sys.argv[0]
    sys.argv = [program_name] + leftover
    return known
|
|
|
|
|
|
# Script entry point: parse the command line, then run the evaluation.
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
|