Merge pull request #1203 from gnardari/master
Transcode input audio to 16 kHz #1022
This commit is contained in:
		
						commit
						3bd917d9e3
					
				@ -4,9 +4,10 @@ from __future__ import absolute_import, division, print_function
 | 
				
			|||||||
from timeit import default_timer as timer
 | 
					from timeit import default_timer as timer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					import subprocess
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import scipy.io.wavfile as wav
 | 
					import scipy.io.wavfile as wav
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
from deepspeech.model import Model
 | 
					from deepspeech.model import Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# These constants control the beam search decoder
 | 
					# These constants control the beam search decoder
 | 
				
			||||||
@ -35,6 +36,23 @@ N_FEATURES = 26
 | 
				
			|||||||
# Size of the context window used for producing timesteps in the input vector
 | 
					# Size of the context window used for producing timesteps in the input vector
 | 
				
			||||||
N_CONTEXT = 9
 | 
					N_CONTEXT = 9
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def convert_samplerate(audio_path):
 | 
				
			||||||
 | 
					    sox_cmd = 'sox --norm {} -b 16 -t wav - channels 1 rate 16000'.format(audio_path)
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        p = subprocess.Popen(sox_cmd.split(),
 | 
				
			||||||
 | 
					                             stderr=subprocess.PIPE, stdout=subprocess.PIPE)
 | 
				
			||||||
 | 
					        output, err = p.communicate()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if p.returncode:
 | 
				
			||||||
 | 
					            raise RuntimeError('SoX returned non-zero status')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    except OSError as e:
 | 
				
			||||||
 | 
					        raise OSError('SoX not found, use 16kHz files or install it')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # we already know the header information, get only the data from output
 | 
				
			||||||
 | 
					    audio = np.fromstring(output.split('data')[1], dtype=np.int16)
 | 
				
			||||||
 | 
					    return 16000, audio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
 | 
					    parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
 | 
				
			||||||
    parser.add_argument('model', type=str,
 | 
					    parser.add_argument('model', type=str,
 | 
				
			||||||
@ -64,9 +82,9 @@ def main():
 | 
				
			|||||||
        print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr)
 | 
					        print('Loaded language model in %0.3fs.' % (lm_load_end), file=sys.stderr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    fs, audio = wav.read(args.audio)
 | 
					    fs, audio = wav.read(args.audio)
 | 
				
			||||||
    # We can assume 16kHz
 | 
					    if fs != 16000:
 | 
				
			||||||
 | 
					        fs, audio = convert_samplerate(args.audio)
 | 
				
			||||||
    audio_length = len(audio) * ( 1 / 16000)
 | 
					    audio_length = len(audio) * ( 1 / 16000)
 | 
				
			||||||
    assert fs == 16000, "Only 16000Hz input WAV files are supported for now!"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print('Running inference.', file=sys.stderr)
 | 
					    print('Running inference.', file=sys.stderr)
 | 
				
			||||||
    inference_start = timer()
 | 
					    inference_start = timer()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user