Expose and use model sample rate in Python

This commit is contained in:
Reuben Morais 2019-10-10 21:50:15 +02:00
parent c1ed6d711d
commit afea2b4231
2 changed files with 19 additions and 11 deletions

View File

@ -44,6 +44,15 @@ class Model(object):
deepspeech.impl.FreeModel(self._impl)
self._impl = None
def sampleRate(self):
"""
Return the sample rate expected by the model.
:return: Sample rate.
:type: int
"""
return deepspeech.impl.GetModelSampleRate(self._impl)
def enableDecoderWithLM(self, *args, **kwargs):
"""
Enable decoding using beam scoring with a KenLM language model.

View File

@ -17,9 +17,6 @@ try:
except ImportError:
from pipes import quote
# Define the sample rate for audio
SAMPLE_RATE = 16000
# These constants control the beam search decoder
# Beam width used in the CTC decoder when building candidate transcriptions
@ -32,16 +29,16 @@ LM_ALPHA = 0.75
LM_BETA = 1.85
def convert_samplerate(audio_path):
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), SAMPLE_RATE)
def convert_samplerate(audio_path, desired_sample_rate):
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
try:
output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
except subprocess.CalledProcessError as e:
raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
except OSError as e:
raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(SAMPLE_RATE, e.strerror))
raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(desired_sample_rate, e.strerror))
return SAMPLE_RATE, np.frombuffer(output, np.int16)
return desired_sample_rate, np.frombuffer(output, np.int16)
def metadata_to_string(metadata):
@ -81,6 +78,8 @@ def main():
model_load_end = timer() - model_load_start
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
desired_sample_rate = ds.sampleRate()
if args.lm and args.trie:
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
lm_load_start = timer()
@ -90,13 +89,13 @@ def main():
fin = wave.open(args.audio, 'rb')
fs = fin.getframerate()
if fs != SAMPLE_RATE:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, SAMPLE_RATE), file=sys.stderr)
fs, audio = convert_samplerate(args.audio)
if fs != desired_sample_rate:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, desired_sample_rate), file=sys.stderr)
fs, audio = convert_samplerate(args.audio, desired_sample_rate)
else:
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
audio_length = fin.getnframes() * (1/SAMPLE_RATE)
audio_length = fin.getnframes() * (1/fs)
fin.close()
print('Running inference.', file=sys.stderr)