Introduce --auto_input_dataset flag for input formatting

Automatically split data into sets and generate alphabet.
This commit is contained in:
Reuben Morais 2021-08-26 16:29:38 +02:00
parent 8458352255
commit 412de47623
4 changed files with 296 additions and 48 deletions

View File

@ -15,6 +15,7 @@ import json
import shutil import shutil
import time import time
from datetime import datetime from datetime import datetime
from pathlib import Path
import numpy as np import numpy as np
import progressbar import progressbar
@ -265,14 +266,14 @@ def early_training_checks():
) )
def train(): def create_training_datasets(
early_training_checks() exception_box,
) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
tfv1.reset_default_graph() """Creates training datasets from input flags.
tfv1.set_random_seed(Config.random_seed)
exception_box = ExceptionBox()
Returns a single training dataset and two lists of datasets for validation
and metrics tracking.
"""
# Create training and validation datasets # Create training and validation datasets
train_set = create_dataset( train_set = create_dataset(
Config.train_files, Config.train_files,
@ -288,17 +289,8 @@ def train():
buffering=Config.read_buffer, buffering=Config.read_buffer,
) )
iterator = tfv1.data.Iterator.from_structure( dev_sets = []
tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set),
)
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if Config.dev_files: if Config.dev_files:
dev_sources = Config.dev_files
dev_sets = [ dev_sets = [
create_dataset( create_dataset(
[source], [source],
@ -311,12 +303,11 @@ def train():
limit=Config.limit_dev, limit=Config.limit_dev,
buffering=Config.read_buffer, buffering=Config.read_buffer,
) )
for source in dev_sources for source in Config.dev_files
] ]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
metrics_sets = []
if Config.metrics_files: if Config.metrics_files:
metrics_sources = Config.metrics_files
metrics_sets = [ metrics_sets = [
create_dataset( create_dataset(
[source], [source],
@ -329,12 +320,35 @@ def train():
limit=Config.limit_dev, limit=Config.limit_dev,
buffering=Config.read_buffer, buffering=Config.read_buffer,
) )
for source in metrics_sources for source in Config.metrics_files
]
metrics_init_ops = [
iterator.make_initializer(metrics_set) for metrics_set in metrics_sets
] ]
return train_set, dev_sets, metrics_sets
def train():
early_training_checks()
tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed)
exception_box = ExceptionBox()
train_set, dev_sets, metrics_sets = create_training_datasets(exception_box)
iterator = tfv1.data.Iterator.from_structure(
tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set),
)
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
metrics_init_ops = [
iterator.make_initializer(metrics_set) for metrics_set in metrics_sets
]
# Dropout # Dropout
dropout_rates = [ dropout_rates = [
tfv1.placeholder(tf.float32, name="dropout_{}".format(i)) for i in range(6) tfv1.placeholder(tf.float32, name="dropout_{}".format(i)) for i in range(6)
@ -550,7 +564,7 @@ def train():
# Validation # Validation
dev_loss = 0.0 dev_loss = 0.0
total_steps = 0 total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops): for source, init_op in zip(Config.dev_files, dev_init_ops):
log_progress("Validating epoch %d on %s..." % (epoch, source)) log_progress("Validating epoch %d on %s..." % (epoch, source))
set_loss, steps = run_set("dev", epoch, init_op, dataset=source) set_loss, steps = run_set("dev", epoch, init_op, dataset=source)
dev_loss += set_loss * steps dev_loss += set_loss * steps
@ -630,7 +644,7 @@ def train():
if Config.metrics_files: if Config.metrics_files:
# Read only metrics, not affecting best validation loss tracking # Read only metrics, not affecting best validation loss tracking
for source, init_op in zip(metrics_sources, metrics_init_ops): for source, init_op in zip(Config.metrics_files, metrics_init_ops):
log_progress("Metrics for epoch %d on %s..." % (epoch, source)) log_progress("Metrics for epoch %d on %s..." % (epoch, source))
set_loss, _ = run_set("metrics", epoch, init_op, dataset=source) set_loss, _ = run_set("metrics", epoch, init_op, dataset=source)
log_progress( log_progress(

View File

@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Optional
import pandas
from tqdm import tqdm
from .io import open_remote
from .sample_collections import samples_from_sources
from coqui_stt_ctcdecoder import Alphabet
def create_alphabet_from_sources(sources: [str]) -> ([str], Alphabet):
"""Generate an Alphabet from characters in given sources.
sources: List of paths to input sources (CSV, SDB).
Returns a 2-tuple with list of characters and Alphabet instance.
"""
characters = set()
for sample in tqdm(samples_from_sources(sources)):
characters |= set(sample.transcript)
characters = list(sorted(characters))
alphabet = Alphabet()
alphabet.InitFromLabels(characters)
return characters, alphabet
def _get_sample_size(population_size):
"""calculates the sample size for a 99% confidence and 1% margin of error"""
margin_of_error = 0.01
fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99%
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2
)
sample_size = 0
for train_size in range(population_size, 0, -1):
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2 * train_size
)
sample_size = int(numerator / denominator)
if 2 * sample_size + train_size <= population_size:
break
return sample_size
def _split_sets(samples: pandas.DataFrame, sample_size):
"""
randomply split the datasets into train, validation, and test sets where the size of the
validation and test sets are determined by the `get_sample_size` function.
"""
samples = samples.sample(frac=1).reset_index(drop=True)
train_beg = 0
train_end = len(samples) - 2 * sample_size
dev_beg = train_end
dev_end = train_end + sample_size
test_beg = dev_end
test_end = len(samples)
return (
samples[train_beg:train_end],
samples[dev_beg:dev_end],
samples[test_beg:test_end],
)
def create_datasets_from_auto_input(
auto_input_dataset: Path, alphabet_config_path: Optional[Path]
) -> (Path, Path, Path, Path):
"""Creates training datasets from --auto_input_dataset flag.
auto_input_dataset: Path to input CSV or folder containing CSV.
Returns paths to generated train set, dev set and test set, and the path
to the alphabet file, either generated from the data, existing alongside
data, or specified manually by the user.
"""
if auto_input_dataset.is_dir():
auto_input_dir = auto_input_dataset
all_csvs = list(auto_input_dataset.glob("*.csv"))
if not all_csvs:
raise RuntimeError(
"--auto_input_dataset is a directory but no CSV file was found "
"inside of it. Either make sure a CSV file is in the directory "
"or specify the file it directly."
)
non_subsets = [f for f in all_csvs if f.stem not in ("train", "dev", "test")]
if len(non_subsets) == 1:
auto_input_csv = non_subsets[0]
elif len(non_subsets) > 1:
non_subsets_fmt = ", ".join(str(s) for s in non_subsets)
raise RuntimeError(
"--auto_input_dataset is a directory but there are multiple CSV "
f"files not matching a subset name (train/dev/test): {non_subsets_fmt}. "
"Either remove extraneous CSV files or specify the correct file "
"to use for dataset formatting directly instead of the directory."
)
# else (empty) -> fall through, sets already present and get picked up below
else:
auto_input_dir = auto_input_dataset.parent
auto_input_csv = auto_input_dataset
train_set_path = auto_input_dir / "train.csv"
dev_set_path = auto_input_dir / "dev.csv"
test_set_path = auto_input_dir / "test.csv"
if train_set_path.exists() != dev_set_path.exists() != test_set_path.exists():
raise RuntimeError(
"Specifying --auto_input_dataset with some generated files present "
"and some missing. Either all three sets (train.csv, dev.csv, test.csv) "
"should exist alongside {auto_input_csv} (in which case they will be used), "
"or none of those files should exist (in which case they will be generated.)"
)
print(f"I Processing --auto_input_dataset input: {auto_input_csv}...")
df = pandas.read_csv(auto_input_csv)
if set(df.columns) < set(("wav_filename", "wav_filesize", "transcript")):
raise RuntimeError(
"Missing columns in --auto_input_dataset CSV. STT training inputs "
"require wav_filename, wav_filesize, and transcript columns."
)
dev_test_size = _get_sample_size(len(df))
if dev_test_size == 0:
if len(df) >= 2:
dev_test_size = 1
else:
raise RuntimeError(
"--auto_input_dataset dataset is too small for automatic splitting "
"into sets. Specify a larger input dataset or split it manually."
)
data_characters = sorted(list(set("".join(df["transcript"].values))))
alphabet_alongside_data_path = auto_input_dir / "alphabet.txt"
if alphabet_config_path:
alphabet = Alphabet(str(alphabet_config_path))
if not alphabet.CanEncode("".join(data_characters)):
raise RuntimeError(
"--alphabet_config_path was specified alongside --auto_input_dataset, "
"but alphabet contents don't match dataset transcripts. Make sure the "
"alphabet covers all transcripts or leave --alphabet_config_path "
"unspecified so that one will be generated automatically."
)
print(f"I Using specified --alphabet_config_path: {alphabet_config_path}")
generated_alphabet_path = alphabet_config_path
elif alphabet_alongside_data_path.exists():
alphabet = Alphabet(str(alphabet_alongside_data_path))
if not alphabet.CanEncode("".join(data_characters)):
raise RuntimeError(
"alphabet.txt exists alongside --auto_input_dataset file, but "
"alphabet contents don't match dataset transcripts. Make sure the "
"alphabet covers all transcripts or remove alphabet.txt file "
"from the data folderso that one will be generated automatically."
)
generated_alphabet_path = alphabet_alongside_data_path
print(f"I Using existing alphabet file: {alphabet_alongside_data_path}")
else:
alphabet = Alphabet()
alphabet.InitFromLabels(data_characters)
generated_alphabet_path = auto_input_dir / "alphabet.txt"
print(
f"I Saved generated alphabet with characters ({data_characters}) into {generated_alphabet_path}"
)
with open_remote(str(generated_alphabet_path), "wb") as fout:
fout.write(alphabet.SerializeText())
# If splits don't already exist, generate and save them.
# We check above that all three splits either exist or don't exist together,
# so we can check a single one for existence here.
if not train_set_path.exists():
train_set, dev_set, test_set = _split_sets(df, dev_test_size)
print(f"I Generated train set size: {len(train_set)} samples.")
print(f"I Generated validation set size: {len(dev_set)} samples.")
print(f"I Generated test set size: {len(test_set)} samples.")
print(f"I Writing train set to {train_set_path}")
train_set.to_csv(train_set_path, index=False)
print(f"I Writing dev set to {dev_set_path}")
dev_set.to_csv(dev_set_path, index=False)
print(f"I Writing test set to {test_set_path}")
test_set.to_csv(test_set_path, index=False)
else:
print("I Generated splits found alongside --auto_input_dataset, using them.")
return train_set_path, dev_set_path, test_set_path, generated_alphabet_path

84
training/coqui_stt_training/util/config.py Executable file → Normal file
View File

@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function
import os import os
import sys import sys
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import List from typing import List
import progressbar import progressbar
@ -11,13 +12,12 @@ from attrdict import AttrDict
from coqpit import MISSING, Coqpit, check_argument from coqpit import MISSING, Coqpit, check_argument
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
from xdg import BaseDirectory as xdg from xdg import BaseDirectory as xdg
from tqdm import tqdm
from .augmentations import NormalizeSampleRate, parse_augmentations from .augmentations import NormalizeSampleRate, parse_augmentations
from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input
from .gpu import get_available_gpus from .gpu import get_available_gpus
from .helpers import parse_file_size from .helpers import parse_file_size
from .io import path_exists_remote from .io import path_exists_remote
from .sample_collections import samples_from_sources
class _ConfigSingleton: class _ConfigSingleton:
@ -129,6 +129,30 @@ class _SttConfig(Coqpit):
self.save_checkpoint_dir, "alphabet.txt" self.save_checkpoint_dir, "alphabet.txt"
) )
if not (
bool(self.auto_input_dataset)
!= (self.train_files or self.dev_files or self.test_files)
):
raise RuntimeError(
"When using --auto_input_dataset, do not specify --train_files, "
"--dev_files, or --test_files."
)
if self.auto_input_dataset:
(
gen_train,
gen_dev,
gen_test,
gen_alphabet,
) = create_datasets_from_auto_input(
Path(self.auto_input_dataset),
Path(self.alphabet_config_path) if self.alphabet_config_path else None,
)
self.train_files = [str(gen_train)]
self.dev_files = [str(gen_dev)]
self.test_files = [str(gen_test)]
self.alphabet_config_path = str(gen_alphabet)
if self.bytes_output_mode and self.alphabet_config_path: if self.bytes_output_mode and self.alphabet_config_path:
raise RuntimeError( raise RuntimeError(
"You cannot set --alphabet_config_path *and* --bytes_output_mode" "You cannot set --alphabet_config_path *and* --bytes_output_mode"
@ -136,7 +160,7 @@ class _SttConfig(Coqpit):
elif self.bytes_output_mode: elif self.bytes_output_mode:
self.alphabet = UTF8Alphabet() self.alphabet = UTF8Alphabet()
elif self.alphabet_config_path: elif self.alphabet_config_path:
self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path)) self.alphabet = Alphabet(self.alphabet_config_path)
elif os.path.exists(loaded_checkpoint_alphabet_file): elif os.path.exists(loaded_checkpoint_alphabet_file):
print( print(
"I --alphabet_config_path not specified, but found an alphabet file " "I --alphabet_config_path not specified, but found an alphabet file "
@ -145,26 +169,36 @@ class _SttConfig(Coqpit):
) )
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file) self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
elif self.train_files and self.dev_files and self.test_files: elif self.train_files and self.dev_files and self.test_files:
# Generate alphabet automatically from input dataset, but only if # If all subsets are in the same folder and there's an alphabet file
# fully specified, to avoid confusion in case a missing set has extra # alongside them, use it.
# characters. self.alphabet = None
print( sources = self.train_files + self.dev_files + self.test_files
"I --alphabet_config_path not specified, but all input datasets are " parents = set(Path(p).parent for p in sources)
"present (--train_files, --dev_files, --test_files). An alphabet " if len(parents) == 1:
"will be generated automatically from the data and placed alongside " possible_alphabet = list(parents)[0] / "alphabet.txt"
f"the checkpoint ({saved_checkpoint_alphabet_file})." if possible_alphabet.exists():
) print(
characters = set() "I --alphabet_config_path not specified, but all input "
for sample in tqdm( "datasets are present and in the same folder (--train_files, "
samples_from_sources( "--dev_files and --test_files), and an alphabet.txt file "
self.train_files + self.dev_files + self.test_files f"was found alongside the sets ({possible_alphabet}). "
"Will use this alphabet file for this run."
)
self.alphabet = Alphabet(str(possible_alphabet))
if not self.alphabet:
# Generate alphabet automatically from input dataset, but only if
# fully specified, to avoid confusion in case a missing set has extra
# characters.
print(
"I --alphabet_config_path not specified, but all input datasets are "
"present (--train_files, --dev_files, --test_files). An alphabet "
"will be generated automatically from the data and placed alongside "
f"the checkpoint ({saved_checkpoint_alphabet_file})."
) )
): characters, alphabet = create_alphabet_from_sources(sources)
characters |= set(sample.transcript) print(f"I Generated alphabet characters: {characters}.")
characters = list(sorted(characters)) self.alphabet = alphabet
print(f"I Generated alphabet characters: {characters}.")
self.alphabet = Alphabet()
self.alphabet.InitFromLabels(characters)
else: else:
raise RuntimeError( raise RuntimeError(
"Missing --alphabet_config_path flag. Couldn't find an alphabet file\n" "Missing --alphabet_config_path flag. Couldn't find an alphabet file\n"
@ -281,6 +315,12 @@ class _SttConfig(Coqpit):
help="space-separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed." help="space-separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed."
), ),
) )
auto_input_dataset: str = field(
default="",
metadata=dict(
help="path to a single CSV file to use for training. Cannot be specified alongside --train_files, --dev_files, --test_files. Training/validation/testing subsets will be automatically generated from the input, alongside with an alphabet file, if not already present.",
),
)
read_buffer: str = field( read_buffer: str = field(
default="1MB", default="1MB",

0
training/coqui_stt_training/util/gpu.py Executable file → Normal file
View File