Generate and save alphabet automatically if dataset is fully specified

This commit is contained in:
Reuben Morais 2021-08-23 18:24:47 +02:00
parent 2b5a844c05
commit 02adea2d50
4 changed files with 66 additions and 12 deletions

View File

@ -69,6 +69,21 @@ Alphabet::init(const char *config_file)
return 0; return 0;
} }
void
Alphabet::InitFromLabels(const std::vector<std::string>& labels)
{
space_label_ = -2;
size_ = labels.size();
for (int i = 0; i < size_; ++i) {
const std::string& label = labels[i];
if (label == " ") {
space_label_ = i;
}
label_to_str_[i] = label;
str_to_label_[label] = i;
}
}
std::string std::string
Alphabet::SerializeText() Alphabet::SerializeText()
{ {

View File

@ -19,6 +19,9 @@ public:
virtual int init(const char *config_file); virtual int init(const char *config_file);
// Initialize directly from sequence of labels.
void InitFromLabels(const std::vector<std::string>& labels);
// Serialize alphabet into a binary buffer. // Serialize alphabet into a binary buffer.
std::string Serialize(); std::string Serialize();

View File

@ -45,14 +45,18 @@ class Scorer(swigwrapper.Scorer):
class Alphabet(swigwrapper.Alphabet): class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor""" """Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self, config_path): def __init__(self, config_path=None):
super(Alphabet, self).__init__() super(Alphabet, self).__init__()
if config_path:
err = self.init(config_path.encode("utf-8")) err = self.init(config_path.encode("utf-8"))
if err != 0: if err != 0:
raise ValueError( raise ValueError(
"Alphabet initialization failed with error code 0x{:X}".format(err) "Alphabet initialization failed with error code 0x{:X}".format(err)
) )
def InitFromLabels(self, data):
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
def CanEncodeSingle(self, input): def CanEncodeSingle(self, input):
""" """
Returns true if the single character/output class has a corresponding label Returns true if the single character/output class has a corresponding label

View File

@ -11,11 +11,13 @@ from attrdict import AttrDict
from coqpit import MISSING, Coqpit, check_argument from coqpit import MISSING, Coqpit, check_argument
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
from xdg import BaseDirectory as xdg from xdg import BaseDirectory as xdg
from tqdm import tqdm
from .augmentations import NormalizeSampleRate, parse_augmentations from .augmentations import NormalizeSampleRate, parse_augmentations
from .gpu import get_available_gpus from .gpu import get_available_gpus
from .helpers import parse_file_size from .helpers import parse_file_size
from .io import path_exists_remote from .io import path_exists_remote
from .sample_collections import samples_from_sources
class _ConfigSingleton: class _ConfigSingleton:
@ -120,9 +122,12 @@ class _SttConfig(Coqpit):
# If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified, # If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified,
# look for alphabet file alongside loaded checkpoint. # look for alphabet file alongside loaded checkpoint.
checkpoint_alphabet_file = os.path.join( loaded_checkpoint_alphabet_file = os.path.join(
self.load_checkpoint_dir, "alphabet.txt" self.load_checkpoint_dir, "alphabet.txt"
) )
saved_checkpoint_alphabet_file = os.path.join(
self.save_checkpoint_dir, "alphabet.txt"
)
if self.bytes_output_mode and self.alphabet_config_path: if self.bytes_output_mode and self.alphabet_config_path:
raise RuntimeError( raise RuntimeError(
@ -132,15 +137,42 @@ class _SttConfig(Coqpit):
self.alphabet = UTF8Alphabet() self.alphabet = UTF8Alphabet()
elif self.alphabet_config_path: elif self.alphabet_config_path:
self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path)) self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
elif os.path.exists(checkpoint_alphabet_file): elif os.path.exists(loaded_checkpoint_alphabet_file):
print( print(
"I --alphabet_config_path not specified, but found an alphabet file " "I --alphabet_config_path not specified, but found an alphabet file "
f"alongside specified checkpoint ({checkpoint_alphabet_file}).\n" f"alongside specified checkpoint ({loaded_checkpoint_alphabet_file}). "
"I Will use this alphabet file for this run." "Will use this alphabet file for this run."
) )
self.alphabet = Alphabet(checkpoint_alphabet_file) self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
elif self.train_files and self.dev_files and self.test_files:
# Generate alphabet automatically from input dataset, but only if
# fully specified, to avoid confusion in case a missing set has extra
# characters.
print(
"I --alphabet_config_path not specified, but all input datasets are "
"present (--train_files, --dev_files, --test_files). An alphabet "
"will be generated automatically from the data and placed alongside "
f"the checkpoint ({saved_checkpoint_alphabet_file})."
)
characters = set()
for sample in tqdm(
samples_from_sources(
self.train_files + self.dev_files + self.test_files
)
):
characters |= set(sample.transcript)
characters = list(sorted(characters))
print(f"I Generated alphabet characters: {characters}.")
self.alphabet = Alphabet()
self.alphabet.InitFromLabels(characters)
else: else:
raise RuntimeError("Missing --alphabet_config_path flag, can't continue") raise RuntimeError(
"Missing --alphabet_config_path flag. Couldn't find an alphabet file\n"
"alongside checkpoint, and input datasets are not fully specified\n"
"(--train_files, --dev_files, --test_files), so can't generate an alphabet.\n"
"Either specify an alphabet file or fully specify the dataset, so one will\n"
"be generated automatically."
)
# Geometric Constants # Geometric Constants
# =================== # ===================