TUDA importer
This commit is contained in:
		
							parent
							
								
									44a605c8b7
								
							
						
					
					
						commit
						2cdfcff4c6
					
				
							
								
								
									
										165
									
								
								bin/import_tuda.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										165
									
								
								bin/import_tuda.py
									
									
									
									
									
										Executable file
									
								
							| @ -0,0 +1,165 @@ | ||||
| #!/usr/bin/env python | ||||
| ''' | ||||
| Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py | ||||
| Use "python3 import_tuda.py -h" for help | ||||
| ''' | ||||
| from __future__ import absolute_import, division, print_function | ||||
| 
 | ||||
| # Make sure we can import stuff from util/ | ||||
| # This script needs to be run from the root of the DeepSpeech repository | ||||
| import os | ||||
| import sys | ||||
| sys.path.insert(1, os.path.join(sys.path[0], '..')) | ||||
| 
 | ||||
| import csv | ||||
| import wave | ||||
| import tarfile | ||||
| import argparse | ||||
| import progressbar | ||||
| import unicodedata | ||||
| import xml.etree.cElementTree as ET | ||||
| 
 | ||||
| from os import path | ||||
| from collections import Counter | ||||
| from util.text import Alphabet, validate_label | ||||
| from util.downloader import maybe_download, SIMPLE_BAR | ||||
| 
 | ||||
| TUDA_VERSION = 'v2' | ||||
| TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION) | ||||
| TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE) | ||||
| TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE) | ||||
| 
 | ||||
| CHANNELS = 1 | ||||
| SAMPLE_WIDTH = 2 | ||||
| SAMPLE_RATE = 16000 | ||||
| 
 | ||||
| FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] | ||||
| 
 | ||||
| 
 | ||||
| def maybe_extract(archive): | ||||
|     extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE) | ||||
|     if path.isdir(extracted): | ||||
|         print('Found directory "{}" - not extracting.'.format(extracted)) | ||||
|     else: | ||||
|         print('Extracting "{}"...'.format(archive)) | ||||
|         with tarfile.open(archive) as tar: | ||||
|             members = tar.getmembers() | ||||
|             bar = progressbar.ProgressBar(max_value=len(members), widgets=SIMPLE_BAR) | ||||
|             for member in bar(members): | ||||
|                 tar.extract(member=member, path=CLI_ARGS.base_dir) | ||||
|     return extracted | ||||
| 
 | ||||
| 
 | ||||
| def check_and_prepare_sentence(sentence): | ||||
|     sentence = sentence.lower().replace('co2', 'c o zwei') | ||||
|     chars = [] | ||||
|     for c in sentence: | ||||
|         if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)): | ||||
|             c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") | ||||
|         for sc in c: | ||||
|             if ALPHABET is not None and not ALPHABET.has_char(c): | ||||
|                 return None | ||||
|             chars.append(sc) | ||||
|     return validate_label(''.join(chars)) | ||||
| 
 | ||||
| 
 | ||||
| def check_wav_file(wav_path, sentence):  # pylint: disable=too-many-return-statements | ||||
|     try: | ||||
|         with wave.open(wav_path, 'r') as src_wav_file: | ||||
|             rate = src_wav_file.getframerate() | ||||
|             channels = src_wav_file.getnchannels() | ||||
|             sample_width = src_wav_file.getsampwidth() | ||||
|             milliseconds = int(src_wav_file.getnframes() * 1000 / rate) | ||||
|         if rate != SAMPLE_RATE: | ||||
|             return False, 'wrong sample rate' | ||||
|         if channels != CHANNELS: | ||||
|             return False, 'wrong number of channels' | ||||
|         if sample_width != SAMPLE_WIDTH: | ||||
|             return False, 'wrong sample width' | ||||
|         if milliseconds / len(sentence) < 20: | ||||
|             return False, 'too short' | ||||
|         if milliseconds > CLI_ARGS.max_duration > 0: | ||||
|             return False, 'too long' | ||||
|     except wave.Error: | ||||
|         return False, 'invalid wav file' | ||||
|     except EOFError: | ||||
|         return False, 'premature EOF' | ||||
|     return True, 'OK' | ||||
| 
 | ||||
| 
 | ||||
| def write_csvs(extracted): | ||||
|     sample_counter = 0 | ||||
|     reasons = Counter() | ||||
|     for sub_set in ['train', 'dev', 'test']: | ||||
|         set_path = path.join(extracted, sub_set) | ||||
|         set_files = os.listdir(set_path) | ||||
|         recordings = {} | ||||
|         for file in set_files: | ||||
|             if file.endswith('.xml'): | ||||
|                 recordings[file[:-4]] = [] | ||||
|         for file in set_files: | ||||
|             if file.endswith('.wav') and '_' in file: | ||||
|                 prefix = file.split('_')[0] | ||||
|                 if prefix in recordings: | ||||
|                     recordings[prefix].append(file) | ||||
|         recordings = recordings.items() | ||||
|         csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set)) | ||||
|         print('Writing "{}"...'.format(csv_path)) | ||||
|         with open(csv_path, 'w') as csv_file: | ||||
|             writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES) | ||||
|             writer.writeheader() | ||||
|             set_dir = path.join(extracted, sub_set) | ||||
|             bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR) | ||||
|             for prefix, wav_names in bar(recordings): | ||||
|                 xml_path = path.join(set_dir, prefix + '.xml') | ||||
|                 meta = ET.parse(xml_path).getroot() | ||||
|                 sentence = list(meta.iter('cleaned_sentence'))[0].text | ||||
|                 sentence = check_and_prepare_sentence(sentence) | ||||
|                 if sentence is None: | ||||
|                     continue | ||||
|                 for wav_name in wav_names: | ||||
|                     sample_counter += 1 | ||||
|                     wav_path = path.join(set_path, wav_name) | ||||
|                     keep, reason = check_wav_file(wav_path, sentence) | ||||
|                     if keep: | ||||
|                         writer.writerow({ | ||||
|                             'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir), | ||||
|                             'wav_filesize': path.getsize(wav_path), | ||||
|                             'transcript': sentence.lower() | ||||
|                         }) | ||||
|                     else: | ||||
|                         reasons[reason] += 1 | ||||
|     if len(reasons.keys()) > 0: | ||||
|         print('Excluded samples:') | ||||
|         for reason, n in reasons.most_common(): | ||||
|             print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter)) | ||||
| 
 | ||||
| 
 | ||||
| def cleanup(archive): | ||||
|     if not CLI_ARGS.keep_archive: | ||||
|         print('Removing archive "{}"...'.format(archive)) | ||||
|         os.remove(archive) | ||||
| 
 | ||||
| 
 | ||||
| def download_and_prepare(): | ||||
|     archive = maybe_download(TUDA_ARCHIVE, CLI_ARGS.base_dir, TUDA_URL) | ||||
|     extracted = maybe_extract(archive) | ||||
|     write_csvs(extracted) | ||||
|     cleanup(archive) | ||||
| 
 | ||||
| 
 | ||||
| def handle_args(): | ||||
|     parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)') | ||||
|     parser.add_argument('base_dir', help='Directory containing all data') | ||||
|     parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds') | ||||
|     parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') | ||||
|     parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file') | ||||
|     parser.add_argument('--keep_archive', type=bool, default=True, | ||||
|                         help='If downloaded archives should be kept') | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     CLI_ARGS = handle_args() | ||||
|     ALPHABET = Alphabet(CLI_ARGS.alphabet) if CLI_ARGS.alphabet else None | ||||
|     download_and_prepare() | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user