From 1abcedfc993919b8baf2500aade5c595ed458781 Mon Sep 17 00:00:00 2001 From: Guilherme Nardari Date: Fri, 2 Feb 2018 16:27:01 -0200 Subject: [PATCH] Transcode input audio to 16 kHz when necessary --- native_client/python/client.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/native_client/python/client.py b/native_client/python/client.py index dff5fc81..a6ba5a6c 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -4,9 +4,10 @@ from __future__ import absolute_import, division, print_function from timeit import default_timer as timer import argparse +import subprocess import sys import scipy.io.wavfile as wav - +import numpy as np from deepspeech.model import Model # These constants control the beam search decoder @@ -35,6 +36,23 @@ N_FEATURES = 26 # Size of the context window used for producing timesteps in the input vector N_CONTEXT = 9 +def convert_samplerate(audio_path): + sox_cmd = 'sox --norm {} -b 16 -t wav - channels 1 rate 16000'.format(audio_path) + try: + p = subprocess.Popen(sox_cmd.split(), + stderr=subprocess.PIPE, stdout=subprocess.PIPE) + output, err = p.communicate() + + if p.returncode: + raise RuntimeError('SoX returned non-zero status') + + except OSError as e: + raise OSError('SoX not found, use 16kHz files or install it') + + # we already know the header information, get only the data from output + audio = np.fromstring(output.split('data')[1], dtype=np.int16) + return 16000, audio + def main(): parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.') parser.add_argument('model', type=str, @@ -64,9 +82,9 @@ def main(): print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr) fs, audio = wav.read(args.audio) - # We can assume 16kHz + if fs != 16000: + fs, audio = convert_samplerate(args.audio) audio_length = len(audio) * ( 1 / 16000) - assert fs == 16000, "Only 16000Hz input WAV files are supported for now!" print('Running inference.', file=sys.stderr) inference_start = timer()