Abstract sample rate / update beam width (#2199)

* Abstract sample rate
This commit is contained in:
Richard Hamnett 2019-07-04 15:58:45 +01:00 committed by Reuben Morais
parent 06056949bb
commit a5c61ea588

View File

@ -17,6 +17,9 @@ try:
except ImportError:
from pipes import quote
# Sample rate of the audio fed to inference. It is the rate SoX is asked to
# resample to, and the rate input WAV files are compared against.
SAMPLE_RATE = 16000
# These constants control the beam search decoder
# Beam width used in the CTC decoder when building candidate transcriptions
@ -41,15 +44,15 @@ N_CONTEXT = 9
def convert_samplerate(audio_path):
    """Resample an audio file to SAMPLE_RATE with SoX and return its samples.

    SoX is run as a subprocess and told to write raw, mono, 16-bit signed
    little-endian audio to stdout; those bytes are then viewed as int16
    samples.

    Args:
        audio_path: path of the audio file handed to SoX.

    Returns:
        A ``(SAMPLE_RATE, samples)`` tuple, where ``samples`` is an
        ``np.int16`` array built from SoX's stdout.

    Raises:
        RuntimeError: SoX exited with a non-zero status.
        OSError: the ``sox`` executable could not be launched.
    """
    # The path is shell-quoted because the command string is tokenized again
    # by shlex.split() below.
    sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), SAMPLE_RATE)

    try:
        # stderr is captured so it can be reported if SoX fails.
        output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
    except OSError as e:
        raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(SAMPLE_RATE, e.strerror))

    return SAMPLE_RATE, np.frombuffer(output, np.int16)
def metadata_to_string(metadata):
@ -98,13 +101,13 @@ def main():
fin = wave.open(args.audio, 'rb')
fs = fin.getframerate()
if fs != 16000:
print('Warning: original sample rate ({}) is different than 16kHz. Resampling might produce erratic speech recognition.'.format(fs), file=sys.stderr)
if fs != SAMPLE_RATE:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, SAMPLE_RATE), file=sys.stderr)
fs, audio = convert_samplerate(args.audio)
else:
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
audio_length = fin.getnframes() * (1/16000)
audio_length = fin.getnframes() * (1/SAMPLE_RATE)
fin.close()
print('Running inference.', file=sys.stderr)