Merge pull request #1945 from coqui-ai/alphabet-loading-generation

Convenience features for alphabet loading/saving/generation
This commit is contained in:
Reuben Morais 2021-08-25 20:35:09 +02:00 committed by GitHub
commit 66b8a56454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 111 additions and 27 deletions

View File

@ -69,6 +69,39 @@ Alphabet::init(const char *config_file)
return 0;
}
void
Alphabet::InitFromLabels(const std::vector<std::string>& labels)
{
space_label_ = -2;
size_ = labels.size();
for (int i = 0; i < size_; ++i) {
const std::string& label = labels[i];
if (label == " ") {
space_label_ = i;
}
label_to_str_[i] = label;
str_to_label_[label] = i;
}
}
std::string
Alphabet::SerializeText()
{
std::stringstream out;
out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n"
<< "# associated with a numeric label.\n"
<< "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
<< "# to use '#' as a label.\n";
for (int label = 0; label < size_; ++label) {
out << label_to_str_[label] << "\n";
}
out << "# The last (non-comment) line needs to end with a newline.\n";
return out.str();
}
std::string
Alphabet::Serialize()
{

View File

@ -19,9 +19,15 @@ public:
virtual int init(const char *config_file);
// Initialize directly from sequence of labels.
void InitFromLabels(const std::vector<std::string>& labels);
// Serialize alphabet into a binary buffer.
std::string Serialize();
// Serialize alphabet into a text representation (ie. config file read by `init`)
std::string SerializeText();
// Deserialize alphabet from a binary buffer.
int Deserialize(const char* buffer, const int buffer_size);

View File

@ -45,14 +45,18 @@ class Scorer(swigwrapper.Scorer):
class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self, config_path):
def __init__(self, config_path=None):
super(Alphabet, self).__init__()
if config_path:
err = self.init(config_path.encode("utf-8"))
if err != 0:
raise ValueError(
"Alphabet initialization failed with error code 0x{:X}".format(err)
)
def InitFromLabels(self, data):
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
def CanEncodeSingle(self, input):
"""
Returns true if the single character/output class has a corresponding label

View File

@ -429,6 +429,11 @@ def train():
with open_remote(flags_file, "w") as fout:
json.dump(Config.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
with open_remote(preserved_alphabet_file, "wb") as fout:
fout.write(Config.alphabet.SerializeText())
with tfv1.Session(config=Config.session_config) as session:
log_debug("Session opened.")
@ -683,9 +688,6 @@ def early_training_checks():
"for loading and saving."
)
if not Config.alphabet_config_path and not Config.bytes_output_mode:
raise RuntimeError("Missing --alphabet_config_path flag, can't continue")
def main():
initialize_globals_from_cli()

View File

@ -11,11 +11,13 @@ from attrdict import AttrDict
from coqpit import MISSING, Coqpit, check_argument
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
from xdg import BaseDirectory as xdg
from tqdm import tqdm
from .augmentations import NormalizeSampleRate, parse_augmentations
from .gpu import get_available_gpus
from .helpers import parse_file_size
from .io import path_exists_remote
from .sample_collections import samples_from_sources
class _ConfigSingleton:
@ -118,6 +120,15 @@ class _SttConfig(Coqpit):
if not self.available_devices:
self.available_devices = [self.cpu_device]
# If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified,
# look for alphabet file alongside loaded checkpoint.
loaded_checkpoint_alphabet_file = os.path.join(
self.load_checkpoint_dir, "alphabet.txt"
)
saved_checkpoint_alphabet_file = os.path.join(
self.save_checkpoint_dir, "alphabet.txt"
)
if self.bytes_output_mode and self.alphabet_config_path:
raise RuntimeError(
"You cannot set --alphabet_config_path *and* --bytes_output_mode"
@ -126,6 +137,42 @@ class _SttConfig(Coqpit):
self.alphabet = UTF8Alphabet()
elif self.alphabet_config_path:
self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
elif os.path.exists(loaded_checkpoint_alphabet_file):
print(
"I --alphabet_config_path not specified, but found an alphabet file "
f"alongside specified checkpoint ({loaded_checkpoint_alphabet_file}). "
"Will use this alphabet file for this run."
)
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
elif self.train_files and self.dev_files and self.test_files:
# Generate alphabet automatically from input dataset, but only if
# fully specified, to avoid confusion in case a missing set has extra
# characters.
print(
"I --alphabet_config_path not specified, but all input datasets are "
"present (--train_files, --dev_files, --test_files). An alphabet "
"will be generated automatically from the data and placed alongside "
f"the checkpoint ({saved_checkpoint_alphabet_file})."
)
characters = set()
for sample in tqdm(
samples_from_sources(
self.train_files + self.dev_files + self.test_files
)
):
characters |= set(sample.transcript)
characters = list(sorted(characters))
print(f"I Generated alphabet characters: {characters}.")
self.alphabet = Alphabet()
self.alphabet.InitFromLabels(characters)
else:
raise RuntimeError(
"Missing --alphabet_config_path flag. Couldn't find an alphabet file\n"
"alongside checkpoint, and input datasets are not fully specified\n"
"(--train_files, --dev_files, --test_files), so can't generate an alphabet.\n"
"Either specify an alphabet file or fully specify the dataset, so one will\n"
"be generated automatically."
)
# Geometric Constants
# ===================
@ -157,15 +204,12 @@ class _SttConfig(Coqpit):
self.n_hidden_3 = self.n_cell_dim
# Dims in last layer = number of characters in alphabet plus one
try:
# +1 for CTC blank label
self.n_hidden_6 = self.alphabet.GetSize() + 1
except:
AttributeError
# Size of audio window in samples
if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0:
log_error(
raise RuntimeError(
"--feature_win_len value ({}) in milliseconds ({}) multiplied "
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
"your --feature_win_len value or resample your audio accordingly."
@ -175,7 +219,6 @@ class _SttConfig(Coqpit):
self.audio_sample_rate,
)
)
sys.exit(1)
self.audio_window_samples = self.audio_sample_rate * (
self.feature_win_len / 1000
@ -183,7 +226,7 @@ class _SttConfig(Coqpit):
# Stride for feature computations in samples
if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
log_error(
raise RuntimeError(
"--feature_win_step value ({}) in milliseconds ({}) multiplied "
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
"your --feature_win_step value or resample your audio accordingly."
@ -193,19 +236,18 @@ class _SttConfig(Coqpit):
self.audio_sample_rate,
)
)
sys.exit(1)
self.audio_step_samples = self.audio_sample_rate * (
self.feature_win_step / 1000
)
if self.one_shot_infer:
if not path_exists_remote(self.one_shot_infer):
log_error("Path specified in --one_shot_infer is not a valid file.")
sys.exit(1)
if self.one_shot_infer and not path_exists_remote(self.one_shot_infer):
raise RuntimeError(
"Path specified in --one_shot_infer is not a valid file."
)
if self.train_cudnn and self.load_cudnn:
log_error(
raise RuntimeError(
"Trying to use --train_cudnn, but --load_cudnn "
"was also specified. The --load_cudnn flag is only "
"needed when converting a CuDNN RNN checkpoint to "
@ -213,7 +255,6 @@ class _SttConfig(Coqpit):
"using CuDNN RNN, you can just specify the CuDNN RNN "
"checkpoint normally with --save_checkpoint_dir."
)
sys.exit(1)
# sphinx-doc: training_ref_flags_start
train_files: List[str] = field(
@ -727,9 +768,7 @@ class _SttConfig(Coqpit):
def initialize_globals_from_cli():
c = _SttConfig()
c.parse_args(arg_prefix="")
c.__post_init__()
c = _SttConfig.init_from_argparse(arg_prefix="")
_ConfigSingleton._config = c # pylint: disable=protected-access