diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 67079474..b62dbb8a 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -688,9 +688,6 @@ def early_training_checks(): "for loading and saving." ) - if not Config.alphabet_config_path and not Config.bytes_output_mode: - raise RuntimeError("Missing --alphabet_config_path flag, can't continue") - def main(): initialize_globals_from_cli() diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index 96114e25..f954ff0c 100755 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -118,6 +118,12 @@ class _SttConfig(Coqpit): if not self.available_devices: self.available_devices = [self.cpu_device] + # If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified, + # look for alphabet file alongside loaded checkpoint. + checkpoint_alphabet_file = os.path.join( + self.load_checkpoint_dir, "alphabet.txt" + ) + if self.bytes_output_mode and self.alphabet_config_path: raise RuntimeError( "You cannot set --alphabet_config_path *and* --bytes_output_mode" @@ -126,6 +132,15 @@ class _SttConfig(Coqpit): self.alphabet = UTF8Alphabet() elif self.alphabet_config_path: self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path)) + elif os.path.exists(checkpoint_alphabet_file): + print( + "I --alphabet_config_path not specified, but found an alphabet file " + f"alongside specified checkpoint ({checkpoint_alphabet_file}).\n" + "I Will use this alphabet file for this run." + ) + self.alphabet = Alphabet(checkpoint_alphabet_file) + else: + raise RuntimeError("Missing --alphabet_config_path flag, can't continue") # Geometric Constants # =================== @@ -157,15 +172,12 @@ class _SttConfig(Coqpit): self.n_hidden_3 = self.n_cell_dim # Dims in last layer = number of characters in alphabet plus one - try: - # +1 for CTC blank label - self.n_hidden_6 = self.alphabet.GetSize() + 1 - except: - AttributeError + # +1 for CTC blank label + self.n_hidden_6 = self.alphabet.GetSize() + 1 # Size of audio window in samples if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0: - log_error( + raise RuntimeError( "--feature_win_len value ({}) in milliseconds ({}) multiplied " "by --audio_sample_rate value ({}) must be an integer value. Adjust " "your --feature_win_len value or resample your audio accordingly." @@ -175,7 +187,6 @@ class _SttConfig(Coqpit): self.audio_sample_rate, ) ) - sys.exit(1) self.audio_window_samples = self.audio_sample_rate * ( self.feature_win_len / 1000 @@ -183,7 +194,7 @@ class _SttConfig(Coqpit): # Stride for feature computations in samples if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0: - log_error( + raise RuntimeError( "--feature_win_step value ({}) in milliseconds ({}) multiplied " "by --audio_sample_rate value ({}) must be an integer value. Adjust " "your --feature_win_step value or resample your audio accordingly." @@ -193,19 +204,18 @@ class _SttConfig(Coqpit): self.audio_sample_rate, ) ) - sys.exit(1) self.audio_step_samples = self.audio_sample_rate * ( self.feature_win_step / 1000 ) - if self.one_shot_infer: - if not path_exists_remote(self.one_shot_infer): - log_error("Path specified in --one_shot_infer is not a valid file.") - sys.exit(1) + if self.one_shot_infer and not path_exists_remote(self.one_shot_infer): + raise RuntimeError( + "Path specified in --one_shot_infer is not a valid file." + ) if self.train_cudnn and self.load_cudnn: - log_error( + raise RuntimeError( "Trying to use --train_cudnn, but --load_cudnn " "was also specified. The --load_cudnn flag is only " "needed when converting a CuDNN RNN checkpoint to " @@ -213,7 +223,6 @@ class _SttConfig(Coqpit): "using CuDNN RNN, you can just specify the CuDNN RNN " "checkpoint normally with --save_checkpoint_dir." ) - sys.exit(1) # sphinx-doc: training_ref_flags_start train_files: List[str] = field( @@ -727,9 +736,7 @@ class _SttConfig(Coqpit): def initialize_globals_from_cli(): - c = _SttConfig() - c.parse_args(arg_prefix="") - c.__post_init__() + c = _SttConfig.init_from_argparse(arg_prefix="") _ConfigSingleton._config = c # pylint: disable=protected-access