Merge pull request #1203 from gnardari/master

Transcode input audio to 16 kHz #1022
This commit is contained in:
lissyx 2018-02-05 15:44:34 +01:00 committed by GitHub
commit 3bd917d9e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 3 deletions

View File

@ -4,9 +4,10 @@ from __future__ import absolute_import, division, print_function
from timeit import default_timer as timer from timeit import default_timer as timer
import argparse import argparse
import subprocess
import sys import sys
import scipy.io.wavfile as wav import scipy.io.wavfile as wav
import numpy as np
from deepspeech.model import Model from deepspeech.model import Model
# These constants control the beam search decoder # These constants control the beam search decoder
@ -35,6 +36,23 @@ N_FEATURES = 26
# Size of the context window used for producing timesteps in the input vector # Size of the context window used for producing timesteps in the input vector
N_CONTEXT = 9 N_CONTEXT = 9
def convert_samplerate(audio_path):
sox_cmd = 'sox --norm {} -b 16 -t wav - channels 1 rate 16000'.format(audio_path)
try:
p = subprocess.Popen(sox_cmd.split(),
stderr=subprocess.PIPE, stdout=subprocess.PIPE)
output, err = p.communicate()
if p.returncode:
raise RuntimeError('SoX returned non-zero status')
except OSError as e:
raise OSError('SoX not found, use 16kHz files or install it')
# we already know the header information, get only the data from output
audio = np.fromstring(output.split('data')[1], dtype=np.int16)
return 16000, audio
def main(): def main():
parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.') parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
parser.add_argument('model', type=str, parser.add_argument('model', type=str,
@ -64,9 +82,9 @@ def main():
print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr) print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr)
fs, audio = wav.read(args.audio) fs, audio = wav.read(args.audio)
# We can assume 16kHz if fs != 16000:
fs, audio = convert_samplerate(args.audio)
audio_length = len(audio) * ( 1 / 16000) audio_length = len(audio) * ( 1 / 16000)
assert fs == 16000, "Only 16000Hz input WAV files are supported for now!"
print('Running inference.', file=sys.stderr) print('Running inference.', file=sys.stderr)
inference_start = timer() inference_start = timer()