Expose and use model sample rate in Python
This commit is contained in:
parent
c1ed6d711d
commit
afea2b4231
@ -44,6 +44,15 @@ class Model(object):
|
|||||||
deepspeech.impl.FreeModel(self._impl)
|
deepspeech.impl.FreeModel(self._impl)
|
||||||
self._impl = None
|
self._impl = None
|
||||||
|
|
||||||
|
def sampleRate(self):
|
||||||
|
"""
|
||||||
|
Return the sample rate expected by the model.
|
||||||
|
|
||||||
|
:return: Sample rate.
|
||||||
|
:type: int
|
||||||
|
"""
|
||||||
|
return deepspeech.impl.GetModelSampleRate(self._impl)
|
||||||
|
|
||||||
def enableDecoderWithLM(self, *args, **kwargs):
|
def enableDecoderWithLM(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Enable decoding using beam scoring with a KenLM language model.
|
Enable decoding using beam scoring with a KenLM language model.
|
||||||
|
@ -17,9 +17,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from pipes import quote
|
from pipes import quote
|
||||||
|
|
||||||
# Define the sample rate for audio
|
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
|
||||||
# These constants control the beam search decoder
|
# These constants control the beam search decoder
|
||||||
|
|
||||||
# Beam width used in the CTC decoder when building candidate transcriptions
|
# Beam width used in the CTC decoder when building candidate transcriptions
|
||||||
@ -32,16 +29,16 @@ LM_ALPHA = 0.75
|
|||||||
LM_BETA = 1.85
|
LM_BETA = 1.85
|
||||||
|
|
||||||
|
|
||||||
def convert_samplerate(audio_path):
|
def convert_samplerate(audio_path, desired_sample_rate):
|
||||||
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), SAMPLE_RATE)
|
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
|
output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
|
raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(SAMPLE_RATE, e.strerror))
|
raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(desired_sample_rate, e.strerror))
|
||||||
|
|
||||||
return SAMPLE_RATE, np.frombuffer(output, np.int16)
|
return desired_sample_rate, np.frombuffer(output, np.int16)
|
||||||
|
|
||||||
|
|
||||||
def metadata_to_string(metadata):
|
def metadata_to_string(metadata):
|
||||||
@ -81,6 +78,8 @@ def main():
|
|||||||
model_load_end = timer() - model_load_start
|
model_load_end = timer() - model_load_start
|
||||||
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
|
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
|
||||||
|
|
||||||
|
desired_sample_rate = ds.sampleRate()
|
||||||
|
|
||||||
if args.lm and args.trie:
|
if args.lm and args.trie:
|
||||||
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
|
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
|
||||||
lm_load_start = timer()
|
lm_load_start = timer()
|
||||||
@ -90,13 +89,13 @@ def main():
|
|||||||
|
|
||||||
fin = wave.open(args.audio, 'rb')
|
fin = wave.open(args.audio, 'rb')
|
||||||
fs = fin.getframerate()
|
fs = fin.getframerate()
|
||||||
if fs != SAMPLE_RATE:
|
if fs != desired_sample_rate:
|
||||||
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, SAMPLE_RATE), file=sys.stderr)
|
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, desired_sample_rate), file=sys.stderr)
|
||||||
fs, audio = convert_samplerate(args.audio)
|
fs, audio = convert_samplerate(args.audio, desired_sample_rate)
|
||||||
else:
|
else:
|
||||||
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
||||||
|
|
||||||
audio_length = fin.getnframes() * (1/SAMPLE_RATE)
|
audio_length = fin.getnframes() * (1/fs)
|
||||||
fin.close()
|
fin.close()
|
||||||
|
|
||||||
print('Running inference.', file=sys.stderr)
|
print('Running inference.', file=sys.stderr)
|
||||||
|
Loading…
Reference in New Issue
Block a user