Abstract sample rate / update beam width (#2199)
* Abstract sample rate
This commit is contained in:
parent
06056949bb
commit
a5c61ea588
@ -17,6 +17,9 @@ try:
|
||||
except ImportError:
|
||||
from pipes import quote
|
||||
|
||||
# Define the sample rate for audio
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
# These constants control the beam search decoder
|
||||
|
||||
# Beam width used in the CTC decoder when building candidate transcriptions
|
||||
@ -41,15 +44,15 @@ N_CONTEXT = 9
|
||||
|
||||
|
||||
def convert_samplerate(audio_path):
|
||||
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate 16000 --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path))
|
||||
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), SAMPLE_RATE)
|
||||
try:
|
||||
output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
|
||||
except OSError as e:
|
||||
raise OSError(e.errno, 'SoX not found, use 16kHz files or install it: {}'.format(e.strerror))
|
||||
raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(SAMPLE_RATE, e.strerror))
|
||||
|
||||
return 16000, np.frombuffer(output, np.int16)
|
||||
return SAMPLE_RATE, np.frombuffer(output, np.int16)
|
||||
|
||||
|
||||
def metadata_to_string(metadata):
|
||||
@ -98,13 +101,13 @@ def main():
|
||||
|
||||
fin = wave.open(args.audio, 'rb')
|
||||
fs = fin.getframerate()
|
||||
if fs != 16000:
|
||||
print('Warning: original sample rate ({}) is different than 16kHz. Resampling might produce erratic speech recognition.'.format(fs), file=sys.stderr)
|
||||
if fs != SAMPLE_RATE:
|
||||
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, SAMPLE_RATE), file=sys.stderr)
|
||||
fs, audio = convert_samplerate(args.audio)
|
||||
else:
|
||||
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
||||
|
||||
audio_length = fin.getnframes() * (1/16000)
|
||||
audio_length = fin.getnframes() * (1/SAMPLE_RATE)
|
||||
fin.close()
|
||||
|
||||
print('Running inference.', file=sys.stderr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user