From 87f0a371b1c0c903f2cae54039753196bf70cce4 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 23 Aug 2021 16:02:13 +0200 Subject: [PATCH 1/3] Serialize alphabet alongside checkpoint --- native_client/alphabet.cc | 18 ++++++++++++++++++ native_client/alphabet.h | 3 +++ training/coqui_stt_training/train.py | 5 +++++ 3 files changed, 26 insertions(+) diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc index 9abc65a5..bd231a45 100644 --- a/native_client/alphabet.cc +++ b/native_client/alphabet.cc @@ -69,6 +69,24 @@ Alphabet::init(const char *config_file) return 0; } +std::string +Alphabet::SerializeText() +{ + std::stringstream out; + + out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n" + << "# associated with a numeric label.\n" + << "# A line that starts with # is a comment. You can escape it with \\# if you wish\n" + << "# to use '#' as a label.\n"; + + for (int label = 0; label < size_; ++label) { + out << label_to_str_[label] << "\n"; + } + + out << "# The last (non-comment) line needs to end with a newline.\n"; + return out.str(); +} + std::string Alphabet::Serialize() { diff --git a/native_client/alphabet.h b/native_client/alphabet.h index f402cc0d..3eb26790 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -22,6 +22,9 @@ public: // Serialize alphabet into a binary buffer. std::string Serialize(); + // Serialize alphabet into a text representation (ie. config file read by `init`) + std::string SerializeText(); + // Deserialize alphabet from a binary buffer. int Deserialize(const char* buffer, const int buffer_size); diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 6032184c..67079474 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -429,6 +429,11 @@ def train(): with open_remote(flags_file, "w") as fout: json.dump(Config.serialize(), fout, indent=2) + # Serialize alphabet alongside checkpoint + preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt") + with open_remote(preserved_alphabet_file, "wb") as fout: + fout.write(Config.alphabet.SerializeText()) + with tfv1.Session(config=Config.session_config) as session: log_debug("Session opened.") From 2b5a844c0503bb0fdb46794f9631fb810ce0750b Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 23 Aug 2021 17:33:02 +0200 Subject: [PATCH 2/3] Load alphabet alongside checkpoint if present, some config fixes/cleanup --- training/coqui_stt_training/train.py | 3 -- training/coqui_stt_training/util/config.py | 43 +++++++++++++--------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 67079474..b62dbb8a 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -688,9 +688,6 @@ def early_training_checks(): "for loading and saving." ) - if not Config.alphabet_config_path and not Config.bytes_output_mode: - raise RuntimeError("Missing --alphabet_config_path flag, can't continue") - def main(): initialize_globals_from_cli() diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index 96114e25..f954ff0c 100755 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -118,6 +118,12 @@ class _SttConfig(Coqpit): if not self.available_devices: self.available_devices = [self.cpu_device] + # If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified, + # look for alphabet file alongside loaded checkpoint. + checkpoint_alphabet_file = os.path.join( + self.load_checkpoint_dir, "alphabet.txt" + ) + if self.bytes_output_mode and self.alphabet_config_path: raise RuntimeError( "You cannot set --alphabet_config_path *and* --bytes_output_mode" @@ -126,6 +132,15 @@ class _SttConfig(Coqpit): self.alphabet = UTF8Alphabet() elif self.alphabet_config_path: self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path)) + elif os.path.exists(checkpoint_alphabet_file): + print( + "I --alphabet_config_path not specified, but found an alphabet file " + f"alongside specified checkpoint ({checkpoint_alphabet_file}).\n" + "I Will use this alphabet file for this run." + ) + self.alphabet = Alphabet(checkpoint_alphabet_file) + else: + raise RuntimeError("Missing --alphabet_config_path flag, can't continue") # Geometric Constants # =================== @@ -157,15 +172,12 @@ class _SttConfig(Coqpit): self.n_hidden_3 = self.n_cell_dim # Dims in last layer = number of characters in alphabet plus one - try: - # +1 for CTC blank label - self.n_hidden_6 = self.alphabet.GetSize() + 1 - except: - AttributeError + # +1 for CTC blank label + self.n_hidden_6 = self.alphabet.GetSize() + 1 # Size of audio window in samples if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0: - log_error( + raise RuntimeError( "--feature_win_len value ({}) in milliseconds ({}) multiplied " "by --audio_sample_rate value ({}) must be an integer value. Adjust " "your --feature_win_len value or resample your audio accordingly." @@ -175,7 +187,6 @@ class _SttConfig(Coqpit): self.audio_sample_rate, ) ) - sys.exit(1) self.audio_window_samples = self.audio_sample_rate * ( self.feature_win_len / 1000 @@ -183,7 +194,7 @@ class _SttConfig(Coqpit): # Stride for feature computations in samples if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0: - log_error( + raise RuntimeError( "--feature_win_step value ({}) in milliseconds ({}) multiplied " "by --audio_sample_rate value ({}) must be an integer value. Adjust " "your --feature_win_step value or resample your audio accordingly." @@ -193,19 +204,18 @@ class _SttConfig(Coqpit): self.audio_sample_rate, ) ) - sys.exit(1) self.audio_step_samples = self.audio_sample_rate * ( self.feature_win_step / 1000 ) - if self.one_shot_infer: - if not path_exists_remote(self.one_shot_infer): - log_error("Path specified in --one_shot_infer is not a valid file.") - sys.exit(1) + if self.one_shot_infer and not path_exists_remote(self.one_shot_infer): + raise RuntimeError( + "Path specified in --one_shot_infer is not a valid file." + ) if self.train_cudnn and self.load_cudnn: - log_error( + raise RuntimeError( "Trying to use --train_cudnn, but --load_cudnn " "was also specified. The --load_cudnn flag is only " "needed when converting a CuDNN RNN checkpoint to " @@ -213,7 +223,6 @@ class _SttConfig(Coqpit): "using CuDNN RNN, you can just specify the CuDNN RNN " "checkpoint normally with --save_checkpoint_dir." ) - sys.exit(1) # sphinx-doc: training_ref_flags_start train_files: List[str] = field( @@ -727,9 +736,7 @@ class _SttConfig(Coqpit): def initialize_globals_from_cli(): - c = _SttConfig() - c.parse_args(arg_prefix="") - c.__post_init__() + c = _SttConfig.init_from_argparse(arg_prefix="") _ConfigSingleton._config = c # pylint: disable=protected-access From 02adea2d50556e0784dbc1a628ce9de009eb0855 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 23 Aug 2021 18:24:47 +0200 Subject: [PATCH 3/3] Generate and save alphabet automatically if dataset is fully specified --- native_client/alphabet.cc | 15 ++++++++ native_client/alphabet.h | 3 ++ native_client/ctcdecode/__init__.py | 16 +++++--- training/coqui_stt_training/util/config.py | 44 +++++++++++++++++++--- 4 files changed, 66 insertions(+), 12 deletions(-) diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc index bd231a45..a2a00dc0 100644 --- a/native_client/alphabet.cc +++ b/native_client/alphabet.cc @@ -69,6 +69,21 @@ Alphabet::init(const char *config_file) return 0; } +void +Alphabet::InitFromLabels(const std::vector& labels) +{ + space_label_ = -2; + size_ = labels.size(); + for (int i = 0; i < size_; ++i) { + const std::string& label = labels[i]; + if (label == " ") { + space_label_ = i; + } + label_to_str_[i] = label; + str_to_label_[label] = i; + } +} + std::string Alphabet::SerializeText() { diff --git a/native_client/alphabet.h b/native_client/alphabet.h index 3eb26790..ad75dfc1 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -19,6 +19,9 @@ public: virtual int init(const char *config_file); + // Initialize directly from sequence of labels. + void InitFromLabels(const std::vector& labels); + // Serialize alphabet into a binary buffer. std::string Serialize(); diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index 93365c80..82cdd308 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -45,13 +45,17 @@ class Scorer(swigwrapper.Scorer): class Alphabet(swigwrapper.Alphabet): """Convenience wrapper for Alphabet which calls init in the constructor""" - def __init__(self, config_path): + def __init__(self, config_path=None): super(Alphabet, self).__init__() - err = self.init(config_path.encode("utf-8")) - if err != 0: - raise ValueError( - "Alphabet initialization failed with error code 0x{:X}".format(err) - ) + if config_path: + err = self.init(config_path.encode("utf-8")) + if err != 0: + raise ValueError( + "Alphabet initialization failed with error code 0x{:X}".format(err) + ) + + def InitFromLabels(self, data): + return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data]) def CanEncodeSingle(self, input): """ diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index f954ff0c..52619d35 100755 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -11,11 +11,13 @@ from attrdict import AttrDict from coqpit import MISSING, Coqpit, check_argument from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet from xdg import BaseDirectory as xdg +from tqdm import tqdm from .augmentations import NormalizeSampleRate, parse_augmentations from .gpu import get_available_gpus from .helpers import parse_file_size from .io import path_exists_remote +from .sample_collections import samples_from_sources class _ConfigSingleton: @@ -120,9 +122,12 @@ class _SttConfig(Coqpit): # If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified, # look for alphabet file alongside loaded checkpoint. - checkpoint_alphabet_file = os.path.join( + loaded_checkpoint_alphabet_file = os.path.join( self.load_checkpoint_dir, "alphabet.txt" ) + saved_checkpoint_alphabet_file = os.path.join( + self.save_checkpoint_dir, "alphabet.txt" + ) if self.bytes_output_mode and self.alphabet_config_path: raise RuntimeError( @@ -132,15 +137,42 @@ class _SttConfig(Coqpit): self.alphabet = UTF8Alphabet() elif self.alphabet_config_path: self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path)) - elif os.path.exists(checkpoint_alphabet_file): + elif os.path.exists(loaded_checkpoint_alphabet_file): print( "I --alphabet_config_path not specified, but found an alphabet file " - f"alongside specified checkpoint ({checkpoint_alphabet_file}).\n" - "I Will use this alphabet file for this run." + f"alongside specified checkpoint ({loaded_checkpoint_alphabet_file}). " + "Will use this alphabet file for this run." ) - self.alphabet = Alphabet(checkpoint_alphabet_file) + self.alphabet = Alphabet(loaded_checkpoint_alphabet_file) + elif self.train_files and self.dev_files and self.test_files: + # Generate alphabet automatically from input dataset, but only if + # fully specified, to avoid confusion in case a missing set has extra + # characters. + print( + "I --alphabet_config_path not specified, but all input datasets are " + "present (--train_files, --dev_files, --test_files). An alphabet " + "will be generated automatically from the data and placed alongside " + f"the checkpoint ({saved_checkpoint_alphabet_file})." + ) + characters = set() + for sample in tqdm( + samples_from_sources( + self.train_files + self.dev_files + self.test_files + ) + ): + characters |= set(sample.transcript) + characters = list(sorted(characters)) + print(f"I Generated alphabet characters: {characters}.") + self.alphabet = Alphabet() + self.alphabet.InitFromLabels(characters) else: - raise RuntimeError("Missing --alphabet_config_path flag, can't continue") + raise RuntimeError( + "Missing --alphabet_config_path flag. Couldn't find an alphabet file\n" + "alongside checkpoint, and input datasets are not fully specified\n" + "(--train_files, --dev_files, --test_files), so can't generate an alphabet.\n" + "Either specify an alphabet file or fully specify the dataset, so one will\n" + "be generated automatically." + ) # Geometric Constants # ===================