Make evaluate_tflite.py work with v0.6.1 calculate_report
This commit is contained in:
parent
42fe30d572
commit
33e725bb25
@ -2,17 +2,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import absl.app
|
||||
import argparse
|
||||
import numpy as np
|
||||
import wave
|
||||
import csv
|
||||
import os
|
||||
import sys
|
||||
|
||||
from functools import partial
|
||||
from six.moves import zip, range
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from deepspeech import Model
|
||||
|
||||
from util.evaluate_tools import calculate_report
|
||||
from util.flags import create_flags
|
||||
|
||||
r'''
|
||||
This module should be self-contained:
|
||||
@ -54,22 +58,7 @@ def tflite_worker(model, lm, trie, queue_in, queue_out, gpu_mask):
|
||||
print(queue_out.qsize(), end='\r') # Update the current progress
|
||||
queue_in.task_done()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to the model (protocol buffer binary file)')
|
||||
parser.add_argument('--lm', required=True,
|
||||
help='Path to the language model binary file')
|
||||
parser.add_argument('--trie', required=True,
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--csv', required=True,
|
||||
help='Path to the CSV source file')
|
||||
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
|
||||
help='Number of processes to spawn, defaulting to number of CPUs')
|
||||
parser.add_argument('--dump', required=False, action='store_true', default=False,
|
||||
help='Dump the results as text file, with one line for each wav: "wav transcription"')
|
||||
args = parser.parse_args()
|
||||
|
||||
def main(args, _):
|
||||
manager = Manager()
|
||||
work_todo = JoinableQueue() # this is where we are going to store input data
|
||||
work_done = manager.Queue() # this where we are gonna push them out
|
||||
@ -117,12 +106,29 @@ def main():
|
||||
(wer, cer, mean_loss))
|
||||
|
||||
if args.dump:
|
||||
with open(args.csv + '.txt', 'w') as ftxt, open(args.csv + '.out', 'w') as fout:
|
||||
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
|
||||
for wav, txt, out in zip(wavlist, ground_truths, predictions):
|
||||
ftxt.write('%s %s\n' % (wav, txt))
|
||||
fout.write('%s %s\n' % (wav, out))
|
||||
print('Reference texts dumped to %s.txt' % args.csv)
|
||||
print('Transcription dumped to %s.out' % args.csv)
|
||||
print('Reference texts dumped to %s.txt' % args.dump)
|
||||
print('Transcription dumped to %s.out' % args.dump)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to the model (protocol buffer binary file)')
|
||||
parser.add_argument('--lm', required=True,
|
||||
help='Path to the language model binary file')
|
||||
parser.add_argument('--trie', required=True,
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--csv', required=True,
|
||||
help='Path to the CSV source file')
|
||||
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
|
||||
help='Number of processes to spawn, defaulting to number of CPUs')
|
||||
parser.add_argument('--dump', required=False,
|
||||
help='Path to dump the results as text file, with one line for each wav: "wav transcription".')
|
||||
args, unknown = parser.parse_known_args()
|
||||
# Reconstruct argv for absl.flags
|
||||
sys.argv = [sys.argv[0]] + unknown
|
||||
create_flags()
|
||||
absl.app.run(partial(main, args))
|
||||
|
Loading…
x
Reference in New Issue
Block a user