Merge pull request #3137 from tilmankamp/fix_missing_alphabet

Fix: #3130 - Missing deepspeech_training.util.text.Alphabet
This commit is contained in:
Tilman Kamp 2020-07-07 17:50:07 +02:00 committed by GitHub
commit 6882248ab0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 68 additions and 93 deletions

View File

@ -7,7 +7,6 @@ DeepSpeech.py
Use "python3 import_cv2.py -h" for help Use "python3 import_cv2.py -h" for help
""" """
import csv import csv
import itertools
import os import os
import subprocess import subprocess
import unicodedata import unicodedata
@ -24,27 +23,39 @@ from deepspeech_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from deepspeech_training.util.text import Alphabet from ds_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
PARAMS = None
FILTER_OBJ = None
def _preprocess_data(tsv_dir, audio_dir, filter_obj, space_after_every_character=False): class LabelFilter:
exclude = [] def __init__(self, normalize, alphabet, validate_fun):
for dataset in ["test", "dev", "train", "validated", "other"]: self.normalize = normalize
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character) self.alphabet = alphabet
if dataset in ["test", "dev"]: self.validate_fun = validate_fun
exclude += set_samples
if dataset == "validated": def filter(self, label):
_maybe_convert_set("train-all", tsv_dir, audio_dir, filter_obj, space_after_every_character, if self.normalize:
rows=set_samples, exclude=exclude) label = unicodedata.normalize("NFKD", label.strip()).encode("ascii", "ignore").decode("ascii", "ignore")
label = self.validate_fun(label)
if self.alphabet and label and not self.alphabet.CanEncode(label):
label = None
return label
def one_sample(args): def init_worker(params):
global FILTER_OBJ # pylint: disable=global-statement
validate_label = get_validate_label(params)
alphabet = Alphabet(params.filter_alphabet) if params.filter_alphabet else None
FILTER_OBJ = LabelFilter(params.normalize, alphabet, validate_label)
def one_sample(sample):
""" Take an audio file, and optionally convert it to 16kHz WAV """ """ Take an audio file, and optionally convert it to 16kHz WAV """
sample, filter_obj = args
mp3_filename = sample[0] mp3_filename = sample[0]
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3": if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
mp3_filename += ".mp3" mp3_filename += ".mp3"
@ -60,7 +71,7 @@ def one_sample(args):
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
) )
) )
label = filter_obj.filter(sample[1]) label = FILTER_OBJ.filter(sample[1])
rows = [] rows = []
counter = get_counter() counter = get_counter()
if file_size == -1: if file_size == -1:
@ -110,10 +121,9 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever
num_samples = len(samples) num_samples = len(samples)
print("Importing mp3 files...") print("Importing mp3 files...")
pool = Pool() pool = Pool(initializer=init_worker, initargs=(PARAMS,))
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
samples_with_context = itertools.zip_longest(samples, [], fillvalue=filter_obj) for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
for i, processed in enumerate(pool.imap_unordered(one_sample, samples_with_context), start=1):
counter += processed[0] counter += processed[0]
rows += processed[1] rows += processed[1]
bar.update(i) bar.update(i)
@ -155,6 +165,17 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever
return rows return rows
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
exclude = []
for dataset in ["test", "dev", "train", "validated", "other"]:
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character)
if dataset in ["test", "dev"]:
exclude += set_samples
if dataset == "validated":
_maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character,
rows=set_samples, exclude=exclude)
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
if not os.path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
@ -164,28 +185,8 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
except sox.core.SoxError: except sox.core.SoxError:
pass pass
class LabelFilter:
def __init__(self, normalize, alphabet, validate_fun):
self.normalize = normalize
self.alphabet = alphabet
self.validate_fun = validate_fun
def filter(self, label): def parse_args():
if self.normalize:
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = self.validate_fun(label)
if self.alphabet and label:
try:
self.alphabet.encode(label)
except KeyError:
label = None
return label
def main():
parser = get_importers_parser(description="Import CommonVoice v2.0 corpora") parser = get_importers_parser(description="Import CommonVoice v2.0 corpora")
parser.add_argument("tsv_dir", help="Directory containing tsv files") parser.add_argument("tsv_dir", help="Directory containing tsv files")
parser.add_argument( parser.add_argument(
@ -206,18 +207,14 @@ def main():
action="store_true", action="store_true",
help="To help transcript join by white space", help="To help transcript join by white space",
) )
return parser.parse_args()
params = parser.parse_args()
validate_label = get_validate_label(params)
audio_dir = ( def main():
params.audio_dir if params.audio_dir else os.path.join(params.tsv_dir, "clips") audio_dir = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
) _preprocess_data(PARAMS.tsv_dir, audio_dir, PARAMS.space_after_every_character)
alphabet = Alphabet(params.filter_alphabet) if params.filter_alphabet else None
filter_obj = LabelFilter(params.normalize, alphabet, validate_label)
_preprocess_data(params.tsv_dir, audio_dir, filter_obj,
params.space_after_every_character)
if __name__ == "__main__": if __name__ == "__main__":
PARAMS = parse_args()
main() main()

View File

@ -20,7 +20,7 @@ from deepspeech_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from deepspeech_training.util.text import Alphabet from ds_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -198,7 +198,7 @@ def handle_args():
"--iso639-3", type=str, required=True, help="ISO639-3 language code" "--iso639-3", type=str, required=True, help="ISO639-3 language code"
) )
parser.add_argument( parser.add_argument(
"--english-name", type=str, required=True, help="Enligh name of the language" "--english-name", type=str, required=True, help="English name of the language"
) )
parser.add_argument( parser.add_argument(
"--filter_alphabet", "--filter_alphabet",
@ -242,11 +242,8 @@ if __name__ == "__main__":
.decode("ascii", "ignore") .decode("ascii", "ignore")
) )
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label and not ALPHABET.CanEncode(label):
try: label = None
ALPHABET.encode(label)
except KeyError:
label = None
return label return label
ARCHIVE_NAME = ARCHIVE_NAME.format( ARCHIVE_NAME = ARCHIVE_NAME.format(

View File

@ -18,7 +18,7 @@ from deepspeech_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from deepspeech_training.util.text import Alphabet from ds_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -215,11 +215,8 @@ if __name__ == "__main__":
.decode("ascii", "ignore") .decode("ascii", "ignore")
) )
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label and not ALPHABET.CanEncode(label):
try: label = None
ALPHABET.encode(label)
except KeyError:
label = None
return label return label
ARCHIVE_DIR_NAME = ARCHIVE_DIR_NAME.format(language=CLI_ARGS.language) ARCHIVE_DIR_NAME = ARCHIVE_DIR_NAME.format(language=CLI_ARGS.language)

View File

@ -1,16 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import csv import csv
import os import os
import re
import subprocess import subprocess
import tarfile import tarfile
import unicodedata import unicodedata
import zipfile
from glob import glob from glob import glob
from multiprocessing import Pool from multiprocessing import Pool
import progressbar import progressbar
import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import ( from deepspeech_training.util.importers import (
@ -20,7 +17,7 @@ from deepspeech_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from deepspeech_training.util.text import Alphabet from ds_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -227,11 +224,8 @@ if __name__ == "__main__":
.decode("ascii", "ignore") .decode("ascii", "ignore")
) )
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label and not ALPHABET.CanEncode(label):
try: label = None
ALPHABET.encode(label)
except KeyError:
label = None
return label return label
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir) _download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)

View File

@ -24,7 +24,7 @@ import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet from ds_ctcdecoder import Alphabet
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar" SWC_ARCHIVE = "SWC_{language}.tar"
@ -170,7 +170,8 @@ def read_token(token):
def in_alphabet(alphabet, c): def in_alphabet(alphabet, c):
return True if alphabet is None else alphabet.has_char(c) return alphabet.CanEncode(c) if alphabet else True
ALPHABETS = {} ALPHABETS = {}
@ -201,16 +202,8 @@ def label_filter(label, language):
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else "" dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
alphabet = get_alphabet(language) alphabet = get_alphabet(language)
for c in label: for c in label:
if ( if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
CLI_ARGS.normalize c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
and c not in dont_normalize
and not in_alphabet(alphabet, c)
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if not in_alphabet(alphabet, sc): if not in_alphabet(alphabet, sc):
return None, "illegal character" return None, "illegal character"

View File

@ -16,7 +16,7 @@ import progressbar
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet from ds_ctcdecoder import Alphabet
TUDA_VERSION = "v2" TUDA_VERSION = "v2"
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION) TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
@ -46,22 +46,18 @@ def maybe_extract(archive):
return extracted return extracted
def in_alphabet(c):
return ALPHABET.CanEncode(c) if ALPHABET else True
def check_and_prepare_sentence(sentence): def check_and_prepare_sentence(sentence):
sentence = sentence.lower().replace("co2", "c o zwei") sentence = sentence.lower().replace("co2", "c o zwei")
chars = [] chars = []
for c in sentence: for c in sentence:
if ( if CLI_ARGS.normalize and c not in "äöüß" and not in_alphabet(c):
CLI_ARGS.normalize c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
and c not in "äöüß"
and (ALPHABET is None or not ALPHABET.has_char(c))
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if ALPHABET is not None and not ALPHABET.has_char(c): if not in_alphabet(c):
return None return None
chars.append(sc) chars.append(sc)
return validate_label("".join(chars)) return validate_label("".join(chars))
@ -122,6 +118,7 @@ def write_csvs(extracted):
sentence = list(meta.iter("cleaned_sentence"))[0].text sentence = list(meta.iter("cleaned_sentence"))[0].text
sentence = check_and_prepare_sentence(sentence) sentence = check_and_prepare_sentence(sentence)
if sentence is None: if sentence is None:
reasons['alphabet filter'] += 1
continue continue
for wav_name in wav_names: for wav_name in wav_names:
sample_counter += 1 sample_counter += 1