Linter-induced changes
This commit is contained in:
parent
010f24578f
commit
122a007d33
@ -54,7 +54,7 @@ SUBSTITUTIONS = {
|
|||||||
(re.compile(r'eins punkt null null null'), 'ein tausend'),
|
(re.compile(r'eins punkt null null null'), 'ein tausend'),
|
||||||
(re.compile(r'punkt null null null'), 'tausend'),
|
(re.compile(r'punkt null null null'), 'tausend'),
|
||||||
(re.compile(r'punkt null'), None)
|
(re.compile(r'punkt null'), None)
|
||||||
] # TODO: Add Dutch and English
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
DONT_NORMALIZE = {
|
DONT_NORMALIZE = {
|
||||||
@ -77,7 +77,7 @@ class Sample:
|
|||||||
|
|
||||||
def fail(message):
|
def fail(message):
|
||||||
print(message)
|
print(message)
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def group(lst, get_key):
|
def group(lst, get_key):
|
||||||
@ -95,14 +95,11 @@ def get_sample_size(population_size):
|
|||||||
margin_of_error = 0.01
|
margin_of_error = 0.01
|
||||||
fraction_picking = 0.50
|
fraction_picking = 0.50
|
||||||
z_score = 2.58 # Corresponds to confidence level 99%
|
z_score = 2.58 # Corresponds to confidence level 99%
|
||||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
|
||||||
margin_of_error ** 2
|
|
||||||
)
|
|
||||||
sample_size = 0
|
sample_size = 0
|
||||||
for train_size in range(population_size, 0, -1):
|
for train_size in range(population_size, 0, -1):
|
||||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
|
||||||
margin_of_error ** 2 * train_size
|
(margin_of_error ** 2 * train_size)
|
||||||
)
|
|
||||||
sample_size = int(numerator / denominator)
|
sample_size = int(numerator / denominator)
|
||||||
if 2 * sample_size + train_size <= population_size:
|
if 2 * sample_size + train_size <= population_size:
|
||||||
break
|
break
|
||||||
@ -162,14 +159,13 @@ def in_alphabet(alphabet, c):
|
|||||||
return True if alphabet is None else alphabet.has_char(c)
|
return True if alphabet is None else alphabet.has_char(c)
|
||||||
|
|
||||||
|
|
||||||
alphabets = {}
|
ALPHABETS = {}
|
||||||
def get_alphabet(language):
|
def get_alphabet(language):
|
||||||
global alphabets
|
if language in ALPHABETS:
|
||||||
if language in alphabets:
|
return ALPHABETS[language]
|
||||||
return alphabets[language]
|
|
||||||
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
|
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
|
||||||
alphabet = Alphabet(alphabet_path) if alphabet_path else None
|
alphabet = Alphabet(alphabet_path) if alphabet_path else None
|
||||||
alphabets[language] = alphabet
|
ALPHABETS[language] = alphabet
|
||||||
return alphabet
|
return alphabet
|
||||||
|
|
||||||
|
|
||||||
@ -190,9 +186,7 @@ def label_filter(label, language):
|
|||||||
alphabet = get_alphabet(language)
|
alphabet = get_alphabet(language)
|
||||||
for c in label:
|
for c in label:
|
||||||
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
|
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
|
||||||
c = (unicodedata.normalize("NFKD", c)
|
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||||
.encode("ascii", "ignore")
|
|
||||||
.decode("ascii", "ignore"))
|
|
||||||
for sc in c:
|
for sc in c:
|
||||||
if not in_alphabet(alphabet, sc):
|
if not in_alphabet(alphabet, sc):
|
||||||
return None, 'illegal character'
|
return None, 'illegal character'
|
||||||
@ -204,11 +198,38 @@ def label_filter(label, language):
|
|||||||
|
|
||||||
def collect_samples(base_dir, language):
|
def collect_samples(base_dir, language):
|
||||||
roots = []
|
roots = []
|
||||||
for root, dirs, files in os.walk(base_dir):
|
for root, _, files in os.walk(base_dir):
|
||||||
if ALIGNED_NAME in files and WAV_NAME in files:
|
if ALIGNED_NAME in files and WAV_NAME in files:
|
||||||
roots.append(root)
|
roots.append(root)
|
||||||
samples = []
|
samples = []
|
||||||
reasons = Counter()
|
reasons = Counter()
|
||||||
|
|
||||||
|
def add_sample(p_wav_path, p_speaker, p_start, p_end, p_text, p_reason='complete'):
|
||||||
|
if p_start is not None and p_end is not None and p_text is not None:
|
||||||
|
duration = p_end - p_start
|
||||||
|
text, filter_reason = label_filter(p_text, language)
|
||||||
|
skip = False
|
||||||
|
if filter_reason is not None:
|
||||||
|
skip = True
|
||||||
|
p_reason = filter_reason
|
||||||
|
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
|
||||||
|
skip = True
|
||||||
|
p_reason = 'exceeded duration'
|
||||||
|
elif int(duration / 20) < len(text):
|
||||||
|
skip = True
|
||||||
|
p_reason = 'too short to decode'
|
||||||
|
elif duration / len(text) < 10:
|
||||||
|
skip = True
|
||||||
|
p_reason = 'length duration ratio'
|
||||||
|
if skip:
|
||||||
|
reasons[p_reason] += 1
|
||||||
|
else:
|
||||||
|
samples.append(Sample(p_wav_path, p_start, p_end, text, p_speaker))
|
||||||
|
elif p_start is None or p_end is None:
|
||||||
|
reasons['missing timestamps'] += 1
|
||||||
|
else:
|
||||||
|
reasons['missing text'] += 1
|
||||||
|
|
||||||
print('Collecting samples...')
|
print('Collecting samples...')
|
||||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||||
for root in bar(roots):
|
for root in bar(roots):
|
||||||
@ -221,57 +242,34 @@ def collect_samples(base_dir, language):
|
|||||||
speaker = attributes['value']
|
speaker = attributes['value']
|
||||||
break
|
break
|
||||||
for sentence in aligned.iter('s'):
|
for sentence in aligned.iter('s'):
|
||||||
def add_sample(start, end, text, reason='complete'):
|
|
||||||
if start is not None and end is not None and text is not None:
|
|
||||||
duration = end - start
|
|
||||||
text, filter_reason = label_filter(text, language)
|
|
||||||
skip = False
|
|
||||||
if filter_reason is not None:
|
|
||||||
skip = True
|
|
||||||
reason = filter_reason
|
|
||||||
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
|
|
||||||
skip = True
|
|
||||||
reason = 'exceeded duration'
|
|
||||||
elif int(duration / 20) < len(text):
|
|
||||||
skip = True
|
|
||||||
reason = 'too short to decode'
|
|
||||||
elif duration / len(text) < 10:
|
|
||||||
skip = True
|
|
||||||
reason = 'length duration ratio'
|
|
||||||
if skip:
|
|
||||||
reasons[reason] += 1
|
|
||||||
else:
|
|
||||||
samples.append(Sample(wav_path, start, end, text, speaker))
|
|
||||||
elif start is None or end is None:
|
|
||||||
reasons['missing timestamps'] += 1
|
|
||||||
else:
|
|
||||||
reasons['missing text'] += 1
|
|
||||||
if ignored(sentence):
|
if ignored(sentence):
|
||||||
continue
|
continue
|
||||||
split = False
|
split = False
|
||||||
tokens = list(map(lambda token: read_token(token), sentence.findall('t')))
|
tokens = list(map(read_token, sentence.findall('t')))
|
||||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||||
for token_start, token_end, token_text in tokens:
|
for token_start, token_end, token_text in tokens:
|
||||||
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
|
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
|
||||||
add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='has numbers')
|
add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||||
|
p_reason='has numbers')
|
||||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||||
continue
|
continue
|
||||||
if sample_start is None:
|
if sample_start is None:
|
||||||
sample_start = token_start
|
sample_start = token_start
|
||||||
if sample_start is None:
|
if sample_start is None:
|
||||||
continue
|
continue
|
||||||
else:
|
token_texts.append(token_text)
|
||||||
token_texts.append(token_text)
|
|
||||||
if token_end is not None:
|
if token_end is not None:
|
||||||
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
|
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
|
||||||
add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split')
|
add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||||
|
p_reason='split')
|
||||||
sample_start = sample_end
|
sample_start = sample_end
|
||||||
sample_texts = []
|
sample_texts = []
|
||||||
split = True
|
split = True
|
||||||
sample_end = token_end
|
sample_end = token_end
|
||||||
sample_texts.extend(token_texts)
|
sample_texts.extend(token_texts)
|
||||||
token_texts = []
|
token_texts = []
|
||||||
add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split' if split else 'complete')
|
add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||||
|
p_reason='split' if split else 'complete')
|
||||||
print('Skipped samples:')
|
print('Skipped samples:')
|
||||||
for reason, n in reasons.most_common():
|
for reason, n in reasons.most_common():
|
||||||
print(' - {}: {}'.format(reason, n))
|
print(' - {}: {}'.format(reason, n))
|
||||||
@ -279,7 +277,7 @@ def collect_samples(base_dir, language):
|
|||||||
|
|
||||||
|
|
||||||
def maybe_convert_one_to_wav(entry):
|
def maybe_convert_one_to_wav(entry):
|
||||||
root, dirs, files = entry
|
root, _, files = entry
|
||||||
transformer = sox.Transformer()
|
transformer = sox.Transformer()
|
||||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||||
combiner = sox.Combiner()
|
combiner = sox.Combiner()
|
||||||
@ -317,7 +315,7 @@ def maybe_convert_to_wav(base_dir):
|
|||||||
def assign_sub_sets(samples):
|
def assign_sub_sets(samples):
|
||||||
sample_size = get_sample_size(len(samples))
|
sample_size = get_sample_size(len(samples))
|
||||||
speakers = group(samples, lambda sample: sample.speaker).values()
|
speakers = group(samples, lambda sample: sample.speaker).values()
|
||||||
speakers = list(sorted(speakers, key=lambda speaker_samples: len(speaker_samples)))
|
speakers = list(sorted(speakers, key=len))
|
||||||
sample_sets = [[], []]
|
sample_sets = [[], []]
|
||||||
while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
|
while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
|
||||||
for sample_set in sample_sets:
|
for sample_set in sample_sets:
|
||||||
@ -441,8 +439,8 @@ def handle_args():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CLI_ARGS = handle_args()
|
CLI_ARGS = handle_args()
|
||||||
if CLI_ARGS.language == 'all':
|
if CLI_ARGS.language == 'all':
|
||||||
for language in LANGUAGES:
|
for lang in LANGUAGES:
|
||||||
prepare_language(language)
|
prepare_language(lang)
|
||||||
elif CLI_ARGS.language in LANGUAGES:
|
elif CLI_ARGS.language in LANGUAGES:
|
||||||
prepare_language(CLI_ARGS.language)
|
prepare_language(CLI_ARGS.language)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user