Transcode input audio to 16 kHz when necessary
This commit is contained in:
parent
1aca3a32d9
commit
1abcedfc99
|
@ -4,9 +4,10 @@ from __future__ import absolute_import, division, print_function
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import scipy.io.wavfile as wav
|
import scipy.io.wavfile as wav
|
||||||
|
import numpy as np
|
||||||
from deepspeech.model import Model
|
from deepspeech.model import Model
|
||||||
|
|
||||||
# These constants control the beam search decoder
|
# These constants control the beam search decoder
|
||||||
|
@ -35,6 +36,23 @@ N_FEATURES = 26
|
||||||
# Size of the context window used for producing timesteps in the input vector
|
# Size of the context window used for producing timesteps in the input vector
|
||||||
N_CONTEXT = 9
|
N_CONTEXT = 9
|
||||||
|
|
||||||
|
def convert_samplerate(audio_path):
|
||||||
|
sox_cmd = 'sox --norm {} -b 16 -t wav - channels 1 rate 16000'.format(audio_path)
|
||||||
|
try:
|
||||||
|
p = subprocess.Popen(sox_cmd.split(),
|
||||||
|
stderr=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||||
|
output, err = p.communicate()
|
||||||
|
|
||||||
|
if p.returncode:
|
||||||
|
raise RuntimeError('SoX returned non-zero status')
|
||||||
|
|
||||||
|
except OSError as e:
|
||||||
|
raise OSError('SoX not found, use 16kHz files or install it')
|
||||||
|
|
||||||
|
# we already know the header information, get only the data from output
|
||||||
|
audio = np.fromstring(output.split('data')[1], dtype=np.int16)
|
||||||
|
return 16000, audio
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
|
parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
|
||||||
parser.add_argument('model', type=str,
|
parser.add_argument('model', type=str,
|
||||||
|
@ -64,9 +82,9 @@ def main():
|
||||||
print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr)
|
print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr)
|
||||||
|
|
||||||
fs, audio = wav.read(args.audio)
|
fs, audio = wav.read(args.audio)
|
||||||
# We can assume 16kHz
|
if fs != 16000:
|
||||||
|
fs, audio = convert_samplerate(args.audio)
|
||||||
audio_length = len(audio) * ( 1 / 16000)
|
audio_length = len(audio) * ( 1 / 16000)
|
||||||
assert fs == 16000, "Only 16000Hz input WAV files are supported for now!"
|
|
||||||
|
|
||||||
print('Running inference.', file=sys.stderr)
|
print('Running inference.', file=sys.stderr)
|
||||||
inference_start = timer()
|
inference_start = timer()
|
||||||
|
|
Loading…
Reference in New Issue