diff --git a/bin/import_ts.py b/bin/import_ts.py index a2d061ca..a6f1ff00 100755 --- a/bin/import_ts.py +++ b/bin/import_ts.py @@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function # Make sure we can import stuff from util/ # This script needs to be run from the root of the DeepSpeech repository +import argparse import os import re import sys @@ -11,6 +12,7 @@ import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) import csv +import unidecode import zipfile from os import path @@ -25,7 +27,7 @@ ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME ARCHIVE_URL = 'https://s3.eu-west-3.amazonaws.com/audiocorp/releases/' + ARCHIVE_NAME + '.zip' -def _download_and_preprocess_data(target_dir): +def _download_and_preprocess_data(target_dir, english_compatible=False): # Making path absolute target_dir = path.abspath(target_dir) # Conditionally download data @@ -33,7 +35,8 @@ def _download_and_preprocess_data(target_dir): # Conditionally extract archive data _maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path) # Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav - _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME) + _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible) + def _maybe_extract(target_dir, extracted_data, archive_path): # If target_dir/extracted_data does not exist, extract archive in target_dir @@ -47,7 +50,8 @@ def _maybe_extract(target_dir, extracted_data, archive_path): else: print('Found directory "%s" - not extracting it from archive.' % archive_path) -def _maybe_convert_sets(target_dir, extracted_data): + +def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): extracted_dir = path.join(target_dir, extracted_data) # override existing CSV with normalized one target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv') @@ -70,7 +74,7 @@ def _maybe_convert_sets(target_dir, extracted_data): test_writer.writeheader() for i, item in enumerate(data): - transcript = validate_label(cleanup_transcript(item['text'])) + transcript = validate_label(cleanup_transcript(item['text'], english_compatible=english_compatible)) if not transcript: continue wav_filename = os.path.join(target_dir, extracted_data, item['path']) @@ -92,12 +96,22 @@ PUNCTUATIONS_REG = re.compile(r"[°\-,;!?.()\[\]*…—]") MULTIPLE_SPACES_REG = re.compile(r'\s{2,}') -def cleanup_transcript(text): +def cleanup_transcript(text, english_compatible=False): text = text.replace('’', "'").replace('\u00A0', ' ') text = PUNCTUATIONS_REG.sub(' ', text) text = MULTIPLE_SPACES_REG.sub(' ', text) + if english_compatible: + text = unidecode.unidecode(text) return text.strip().lower() +def handle_args(): + parser = argparse.ArgumentParser(description='Importer for TrainingSpeech dataset.') + parser.add_argument(dest='target_dir') + parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diactrics and other non-ascii chars.') + return parser.parse_args() + + if __name__ == "__main__": - _download_and_preprocess_data(sys.argv[1]) + cli_args = handle_args() + _download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible)