Transcode input audio to 16 kHz when necessary

This commit is contained in:
Guilherme Nardari 2018-02-02 16:27:01 -02:00
parent 1aca3a32d9
commit 1abcedfc99
1 changed files with 21 additions and 3 deletions

View File

@ -4,9 +4,10 @@ from __future__ import absolute_import, division, print_function
from timeit import default_timer as timer from timeit import default_timer as timer
import argparse import argparse
import subprocess
import sys import sys
import scipy.io.wavfile as wav import scipy.io.wavfile as wav
import numpy as np
from deepspeech.model import Model from deepspeech.model import Model
# These constants control the beam search decoder # These constants control the beam search decoder
@ -35,6 +36,23 @@ N_FEATURES = 26
# Size of the context window used for producing timesteps in the input vector # Size of the context window used for producing timesteps in the input vector
N_CONTEXT = 9 N_CONTEXT = 9
def convert_samplerate(audio_path):
sox_cmd = 'sox --norm {} -b 16 -t wav - channels 1 rate 16000'.format(audio_path)
try:
p = subprocess.Popen(sox_cmd.split(),
stderr=subprocess.PIPE, stdout=subprocess.PIPE)
output, err = p.communicate()
if p.returncode:
raise RuntimeError('SoX returned non-zero status')
except OSError as e:
raise OSError('SoX not found, use 16kHz files or install it')
# we already know the header information, get only the data from output
audio = np.fromstring(output.split('data')[1], dtype=np.int16)
return 16000, audio
def main(): def main():
parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.') parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
parser.add_argument('model', type=str, parser.add_argument('model', type=str,
@ -64,9 +82,9 @@ def main():
print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr) print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr)
fs, audio = wav.read(args.audio) fs, audio = wav.read(args.audio)
# We can assume 16kHz if fs != 16000:
fs, audio = convert_samplerate(args.audio)
audio_length = len(audio) * ( 1 / 16000) audio_length = len(audio) * ( 1 / 16000)
assert fs == 16000, "Only 16000Hz input WAV files are supported for now!"
print('Running inference.', file=sys.stderr) print('Running inference.', file=sys.stderr)
inference_start = timer() inference_start = timer()