Make evaluate_tflite.py work with v0.6.1 calculate_report

This commit is contained in:
Reuben Morais 2020-01-12 11:01:33 +01:00
parent 42fe30d572
commit 33e725bb25

View File

@ -2,17 +2,21 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import absl.app
import argparse
import numpy as np
import wave
import csv
import os
import sys
from functools import partial
from six.moves import zip, range
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from deepspeech import Model
from util.evaluate_tools import calculate_report
from util.flags import create_flags
r'''
This module should be self-contained:
@ -54,22 +58,7 @@ def tflite_worker(model, lm, trie, queue_in, queue_out, gpu_mask):
print(queue_out.qsize(), end='\r') # Update the current progress
queue_in.task_done()
def main():
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
parser.add_argument('--model', required=True,
help='Path to the model (protocol buffer binary file)')
parser.add_argument('--lm', required=True,
help='Path to the language model binary file')
parser.add_argument('--trie', required=True,
help='Path to the language model trie file created with native_client/generate_trie')
parser.add_argument('--csv', required=True,
help='Path to the CSV source file')
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
help='Number of processes to spawn, defaulting to number of CPUs')
parser.add_argument('--dump', required=False, action='store_true', default=False,
help='Dump the results as text file, with one line for each wav: "wav transcription"')
args = parser.parse_args()
def main(args, _):
manager = Manager()
work_todo = JoinableQueue() # this is where we are going to store input data
work_done = manager.Queue() # this where we are gonna push them out
@ -117,12 +106,29 @@ def main():
(wer, cer, mean_loss))
if args.dump:
with open(args.csv + '.txt', 'w') as ftxt, open(args.csv + '.out', 'w') as fout:
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
for wav, txt, out in zip(wavlist, ground_truths, predictions):
ftxt.write('%s %s\n' % (wav, txt))
fout.write('%s %s\n' % (wav, out))
print('Reference texts dumped to %s.txt' % args.csv)
print('Transcription dumped to %s.out' % args.csv)
print('Reference texts dumped to %s.txt' % args.dump)
print('Transcription dumped to %s.out' % args.dump)
if __name__ == '__main__':
main()
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
parser.add_argument('--model', required=True,
help='Path to the model (protocol buffer binary file)')
parser.add_argument('--lm', required=True,
help='Path to the language model binary file')
parser.add_argument('--trie', required=True,
help='Path to the language model trie file created with native_client/generate_trie')
parser.add_argument('--csv', required=True,
help='Path to the CSV source file')
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
help='Number of processes to spawn, defaulting to number of CPUs')
parser.add_argument('--dump', required=False,
help='Path to dump the results as text file, with one line for each wav: "wav transcription".')
args, unknown = parser.parse_known_args()
# Reconstruct argv for absl.flags
sys.argv = [sys.argv[0]] + unknown
create_flags()
absl.app.run(partial(main, args))