Merge pull request #3137 from tilmankamp/fix_missing_alphabet
Fix: #3130 - Missing deepspeech_training.util.text.Alphabet
This commit is contained in:
commit
6882248ab0
|
@ -7,7 +7,6 @@ DeepSpeech.py
|
|||
Use "python3 import_cv2.py -h" for help
|
||||
"""
|
||||
import csv
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import unicodedata
|
||||
|
@ -24,27 +23,39 @@ from deepspeech_training.util.importers import (
|
|||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from ds_ctcdecoder import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
PARAMS = None
|
||||
FILTER_OBJ = None
|
||||
|
||||
|
||||
def _preprocess_data(tsv_dir, audio_dir, filter_obj, space_after_every_character=False):
|
||||
exclude = []
|
||||
for dataset in ["test", "dev", "train", "validated", "other"]:
|
||||
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character)
|
||||
if dataset in ["test", "dev"]:
|
||||
exclude += set_samples
|
||||
if dataset == "validated":
|
||||
_maybe_convert_set("train-all", tsv_dir, audio_dir, filter_obj, space_after_every_character,
|
||||
rows=set_samples, exclude=exclude)
|
||||
class LabelFilter:
|
||||
def __init__(self, normalize, alphabet, validate_fun):
|
||||
self.normalize = normalize
|
||||
self.alphabet = alphabet
|
||||
self.validate_fun = validate_fun
|
||||
|
||||
def filter(self, label):
|
||||
if self.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
label = self.validate_fun(label)
|
||||
if self.alphabet and label and not self.alphabet.CanEncode(label):
|
||||
label = None
|
||||
return label
|
||||
|
||||
|
||||
def one_sample(args):
|
||||
def init_worker(params):
|
||||
global FILTER_OBJ # pylint: disable=global-statement
|
||||
validate_label = get_validate_label(params)
|
||||
alphabet = Alphabet(params.filter_alphabet) if params.filter_alphabet else None
|
||||
FILTER_OBJ = LabelFilter(params.normalize, alphabet, validate_label)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take an audio file, and optionally convert it to 16kHz WAV """
|
||||
sample, filter_obj = args
|
||||
mp3_filename = sample[0]
|
||||
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
|
||||
mp3_filename += ".mp3"
|
||||
|
@ -60,7 +71,7 @@ def one_sample(args):
|
|||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = filter_obj.filter(sample[1])
|
||||
label = FILTER_OBJ.filter(sample[1])
|
||||
rows = []
|
||||
counter = get_counter()
|
||||
if file_size == -1:
|
||||
|
@ -110,10 +121,9 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever
|
|||
num_samples = len(samples)
|
||||
|
||||
print("Importing mp3 files...")
|
||||
pool = Pool()
|
||||
pool = Pool(initializer=init_worker, initargs=(PARAMS,))
|
||||
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
|
||||
samples_with_context = itertools.zip_longest(samples, [], fillvalue=filter_obj)
|
||||
for i, processed in enumerate(pool.imap_unordered(one_sample, samples_with_context), start=1):
|
||||
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
|
||||
counter += processed[0]
|
||||
rows += processed[1]
|
||||
bar.update(i)
|
||||
|
@ -155,6 +165,17 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever
|
|||
return rows
|
||||
|
||||
|
||||
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
||||
exclude = []
|
||||
for dataset in ["test", "dev", "train", "validated", "other"]:
|
||||
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character)
|
||||
if dataset in ["test", "dev"]:
|
||||
exclude += set_samples
|
||||
if dataset == "validated":
|
||||
_maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character,
|
||||
rows=set_samples, exclude=exclude)
|
||||
|
||||
|
||||
def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
|
@ -164,28 +185,8 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
|
|||
except sox.core.SoxError:
|
||||
pass
|
||||
|
||||
class LabelFilter:
|
||||
def __init__(self, normalize, alphabet, validate_fun):
|
||||
self.normalize = normalize
|
||||
self.alphabet = alphabet
|
||||
self.validate_fun = validate_fun
|
||||
|
||||
def filter(self, label):
|
||||
if self.normalize:
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = self.validate_fun(label)
|
||||
if self.alphabet and label:
|
||||
try:
|
||||
self.alphabet.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
return label
|
||||
|
||||
def main():
|
||||
def parse_args():
|
||||
parser = get_importers_parser(description="Import CommonVoice v2.0 corpora")
|
||||
parser.add_argument("tsv_dir", help="Directory containing tsv files")
|
||||
parser.add_argument(
|
||||
|
@ -206,18 +207,14 @@ def main():
|
|||
action="store_true",
|
||||
help="To help transcript join by white space",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
params = parser.parse_args()
|
||||
validate_label = get_validate_label(params)
|
||||
|
||||
audio_dir = (
|
||||
params.audio_dir if params.audio_dir else os.path.join(params.tsv_dir, "clips")
|
||||
)
|
||||
alphabet = Alphabet(params.filter_alphabet) if params.filter_alphabet else None
|
||||
def main():
|
||||
audio_dir = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
|
||||
_preprocess_data(PARAMS.tsv_dir, audio_dir, PARAMS.space_after_every_character)
|
||||
|
||||
filter_obj = LabelFilter(params.normalize, alphabet, validate_label)
|
||||
_preprocess_data(params.tsv_dir, audio_dir, filter_obj,
|
||||
params.space_after_every_character)
|
||||
|
||||
if __name__ == "__main__":
|
||||
PARAMS = parse_args()
|
||||
main()
|
||||
|
|
|
@ -20,7 +20,7 @@ from deepspeech_training.util.importers import (
|
|||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from ds_ctcdecoder import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
|
@ -198,7 +198,7 @@ def handle_args():
|
|||
"--iso639-3", type=str, required=True, help="ISO639-3 language code"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--english-name", type=str, required=True, help="Enligh name of the language"
|
||||
"--english-name", type=str, required=True, help="English name of the language"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
|
@ -242,11 +242,8 @@ if __name__ == "__main__":
|
|||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
ALPHABET.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
if ALPHABET and label and not ALPHABET.CanEncode(label):
|
||||
label = None
|
||||
return label
|
||||
|
||||
ARCHIVE_NAME = ARCHIVE_NAME.format(
|
||||
|
|
|
@ -18,7 +18,7 @@ from deepspeech_training.util.importers import (
|
|||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from ds_ctcdecoder import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
|
@ -215,11 +215,8 @@ if __name__ == "__main__":
|
|||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
ALPHABET.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
if ALPHABET and label and not ALPHABET.CanEncode(label):
|
||||
label = None
|
||||
return label
|
||||
|
||||
ARCHIVE_DIR_NAME = ARCHIVE_DIR_NAME.format(language=CLI_ARGS.language)
|
||||
|
|
|
@ -1,16 +1,13 @@
|
|||
#!/usr/bin/env python3
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import zipfile
|
||||
from glob import glob
|
||||
from multiprocessing import Pool
|
||||
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
|
@ -20,7 +17,7 @@ from deepspeech_training.util.importers import (
|
|||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from ds_ctcdecoder import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
|
@ -227,11 +224,8 @@ if __name__ == "__main__":
|
|||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
ALPHABET.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
if ALPHABET and label and not ALPHABET.CanEncode(label):
|
||||
label = None
|
||||
return label
|
||||
|
||||
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)
|
||||
|
|
|
@ -24,7 +24,7 @@ import sox
|
|||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from ds_ctcdecoder import Alphabet
|
||||
|
||||
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
||||
SWC_ARCHIVE = "SWC_{language}.tar"
|
||||
|
@ -170,7 +170,8 @@ def read_token(token):
|
|||
|
||||
|
||||
def in_alphabet(alphabet, c):
|
||||
return True if alphabet is None else alphabet.has_char(c)
|
||||
return alphabet.CanEncode(c) if alphabet else True
|
||||
|
||||
|
||||
|
||||
ALPHABETS = {}
|
||||
|
@ -201,16 +202,8 @@ def label_filter(label, language):
|
|||
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
|
||||
alphabet = get_alphabet(language)
|
||||
for c in label:
|
||||
if (
|
||||
CLI_ARGS.normalize
|
||||
and c not in dont_normalize
|
||||
and not in_alphabet(alphabet, c)
|
||||
):
|
||||
c = (
|
||||
unicodedata.normalize("NFKD", c)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
|
||||
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
for sc in c:
|
||||
if not in_alphabet(alphabet, sc):
|
||||
return None, "illegal character"
|
||||
|
|
|
@ -16,7 +16,7 @@ import progressbar
|
|||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from ds_ctcdecoder import Alphabet
|
||||
|
||||
TUDA_VERSION = "v2"
|
||||
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
|
||||
|
@ -46,22 +46,18 @@ def maybe_extract(archive):
|
|||
return extracted
|
||||
|
||||
|
||||
def in_alphabet(c):
|
||||
return ALPHABET.CanEncode(c) if ALPHABET else True
|
||||
|
||||
|
||||
def check_and_prepare_sentence(sentence):
|
||||
sentence = sentence.lower().replace("co2", "c o zwei")
|
||||
chars = []
|
||||
for c in sentence:
|
||||
if (
|
||||
CLI_ARGS.normalize
|
||||
and c not in "äöüß"
|
||||
and (ALPHABET is None or not ALPHABET.has_char(c))
|
||||
):
|
||||
c = (
|
||||
unicodedata.normalize("NFKD", c)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
if CLI_ARGS.normalize and c not in "äöüß" and not in_alphabet(c):
|
||||
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
for sc in c:
|
||||
if ALPHABET is not None and not ALPHABET.has_char(c):
|
||||
if not in_alphabet(c):
|
||||
return None
|
||||
chars.append(sc)
|
||||
return validate_label("".join(chars))
|
||||
|
@ -122,6 +118,7 @@ def write_csvs(extracted):
|
|||
sentence = list(meta.iter("cleaned_sentence"))[0].text
|
||||
sentence = check_and_prepare_sentence(sentence)
|
||||
if sentence is None:
|
||||
reasons['alphabet filter'] += 1
|
||||
continue
|
||||
for wav_name in wav_names:
|
||||
sample_counter += 1
|
||||
|
|
Loading…
Reference in New Issue