From 4dc565beca8b0ef7d1d54115dff2f73c298f9897 Mon Sep 17 00:00:00 2001
From: Josh Meyer
Date: Wed, 21 Jul 2021 05:00:05 -0400
Subject: [PATCH] Move checking logic into __post_init__()

---
 lm_optimizer.py                            |   5 +-
 training/coqui_stt_training/evaluate.py    |   6 +-
 training/coqui_stt_training/train.py       |   8 +-
 training/coqui_stt_training/util/config.py | 323 +++++++++++----------
 transcribe.py                              |   6 +-
 5 files changed, 174 insertions(+), 174 deletions(-)

diff --git a/lm_optimizer.py b/lm_optimizer.py
index d4787137..85ca1fd5 100644
--- a/lm_optimizer.py
+++ b/lm_optimizer.py
@@ -10,7 +10,7 @@ import tensorflow.compat.v1 as tfv1
 from coqui_stt_ctcdecoder import Scorer
 from coqui_stt_training.evaluate import evaluate
 from coqui_stt_training.train import create_model
-from coqui_stt_training.util.config import Config, _SttConfig, initialize_config_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.evaluate_tools import wer_cer_batch
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import log_error
@@ -52,8 +52,7 @@ def objective(trial):
 
 
 def main(_):
-    Config = _SttConfig()
-    initialize_config_globals(Config)
+    initialize_globals_from_cli()
 
     if not FLAGS.test_files:
         log_error(
diff --git a/training/coqui_stt_training/evaluate.py b/training/coqui_stt_training/evaluate.py
index cff3dc1e..4425582b 100755
--- a/training/coqui_stt_training/evaluate.py
+++ b/training/coqui_stt_training/evaluate.py
@@ -17,9 +17,8 @@ from .util.augmentations import NormalizeSampleRate
 from .util.checkpoints import load_graph_for_evaluation
 from .util.config import (
     Config,
-    _SttConfig,
     create_progressbar,
-    initialize_config_globals,
+    initialize_globals_from_cli,
     log_error,
     log_progress,
 )
@@ -170,8 +169,7 @@ def evaluate(test_csvs, create_model):
 
 
 def main():
-    Config = _SttConfig()
-    initialize_config_globals(Config)
+    initialize_globals_from_cli()
 
     if not Config.test_files:
         log_error(
diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py
index 619b2d1d..b6cfaec9 100644
--- a/training/coqui_stt_training/train.py
+++ b/training/coqui_stt_training/train.py
@@ -44,9 +44,8 @@ from .util.checkpoints import (
 )
 from .util.config import (
     Config,
-    _SttConfig,
     create_progressbar,
-    initialize_config_globals,
+    initialize_globals_from_cli,
     log_debug,
     log_error,
     log_info,
@@ -1250,10 +1249,7 @@ def early_training_checks():
 
 
 def main():
-    Config = _SttConfig()
-    Config.parse_args(arg_prefix="")  # parse CLI args
-    initialize_config_globals(Config)
-
+    initialize_globals_from_cli()
     early_training_checks()
 
     if Config.train_files:
diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py
index 1896eb7c..0a04ad77 100755
--- a/training/coqui_stt_training/util/config.py
+++ b/training/coqui_stt_training/util/config.py
@@ -36,6 +36,163 @@ Config = _ConfigSingleton()  # pylint: disable=invalid-name
 
 @dataclass
 class _SttConfig(Coqpit):
+    def __post_init__(self):
+        # Augmentations
+        self.augmentations = parse_augmentations(self.augment)
+        if self.augmentations:
+            print(f"Parsed augmentations: {self.augmentations}")
+        if self.augmentations and self.feature_cache and self.cache_for_epochs == 0:
+            print(
+                "Due to current feature-cache settings the exact same sample augmentations of the first "
+                "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
" + "You could use --cache_for_epochs to invalidate the cache after a given number of epochs." + ) + + if self.normalize_sample_rate: + self.augmentations = [NormalizeSampleRate(self.audio_sample_rate)] + self[ + "augmentations" + ] + + # Caching + if self.cache_for_epochs == 1: + print( + "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it." + "You can either set --cache_for_epochs > 1, or not use feature caching at all." + ) + + # Read-buffer + self.read_buffer = parse_file_size(self.read_buffer) + + # Set default dropout rates + if self.dropout_rate2 < 0: + self.dropout_rate2 = self.dropout_rate + if self.dropout_rate3 < 0: + self.dropout_rate3 = self.dropout_rate + if self.dropout_rate6 < 0: + self.dropout_rate6 = self.dropout_rate + + # Set default checkpoint dir + if not self.checkpoint_dir: + self.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints")) + + if self.load_train not in ["last", "best", "init", "auto"]: + self.load_train = "auto" + + if self.load_evaluate not in ["last", "best", "auto"]: + self.load_evaluate = "auto" + + # Set default summary dir + if not self.summary_dir: + self.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries")) + + # Standard session configuration that'll be used for all new sessions. + self.session_config = tfv1.ConfigProto( + allow_soft_placement=True, + log_device_placement=self.log_placement, + inter_op_parallelism_threads=self.inter_op_parallelism_threads, + intra_op_parallelism_threads=self.intra_op_parallelism_threads, + gpu_options=tfv1.GPUOptions(allow_growth=self.use_allow_growth), + ) + + # CPU device + self.cpu_device = "/cpu:0" + + # Available GPU devices + self.available_devices = get_available_gpus(self.session_config) + + # If there is no GPU available, we fall back to CPU based operation + if not self.available_devices: + self.available_devices = [self.cpu_device] + + if self.bytes_output_mode: + self.alphabet = UTF8Alphabet() + elif self.alphabet_config_path: + self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path)) + + # Geometric Constants + # =================== + + # For an explanation of the meaning of the geometric constants + # please refer to doc/Geometry.md + + # Number of MFCC features + self.n_input = 26 # TODO: Determine this programmatically from the sample rate + + # The number of frames in the context + self.n_context = 9 # TODO: Determine the optimal value using a validation data set + + # Number of units in hidden layers + self.n_hidden = self.n_hidden + + self.n_hidden_1 = self.n_hidden + + self.n_hidden_2 = self.n_hidden + + self.n_hidden_5 = self.n_hidden + + # LSTM cell state dimension + self.n_cell_dim = self.n_hidden + + # The number of units in the third layer, which feeds in to the LSTM + self.n_hidden_3 = self.n_cell_dim + + # Dims in last layer = number of characters in alphabet plus one + try: + # +1 for CTC blank label + self.n_hidden_6 = self.alphabet.GetSize() + 1 + except: + AttributeError + + # Size of audio window in samples + if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0: + log_error( + "--feature_win_len value ({}) in milliseconds ({}) multiplied " + "by --audio_sample_rate value ({}) must be an integer value. Adjust " + "your --feature_win_len value or resample your audio accordingly." 
+                "".format(self.feature_win_len, self.feature_win_len / 1000, self.audio_sample_rate)
+            )
+            sys.exit(1)
+
+        self.audio_window_samples = self.audio_sample_rate * (self.feature_win_len / 1000)
+
+        # Stride for feature computations in samples
+        if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
+            log_error(
+                "--feature_win_step value ({}) in milliseconds ({}) multiplied "
+                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
+                "your --feature_win_step value or resample your audio accordingly."
+                "".format(
+                    self.feature_win_step, self.feature_win_step / 1000, self.audio_sample_rate
+                )
+            )
+            sys.exit(1)
+
+        self.audio_step_samples = self.audio_sample_rate * (self.feature_win_step / 1000)
+
+        if self.one_shot_infer:
+            if not path_exists_remote(self.one_shot_infer):
+                log_error("Path specified in --one_shot_infer is not a valid file.")
+                sys.exit(1)
+
+        if self.train_cudnn and self.load_cudnn:
+            log_error(
+                "Trying to use --train_cudnn, but --load_cudnn "
+                "was also specified. The --load_cudnn flag is only "
+                "needed when converting a CuDNN RNN checkpoint to "
+                "a CPU-capable graph. If your system is capable of "
+                "using CuDNN RNN, you can just specify the CuDNN RNN "
+                "checkpoint normally with --save_checkpoint_dir."
+            )
+            sys.exit(1)
+
+        # If separate save and load flags were not specified, default to load and save
+        # from the same dir.
+        if not self.save_checkpoint_dir:
+            self.save_checkpoint_dir = self.checkpoint_dir
+
+        if not self.load_checkpoint_dir:
+            self.load_checkpoint_dir = self.checkpoint_dir
+
     train_files: List[str] = field(
         default_factory=list,
         metadata=dict(
@@ -541,165 +698,15 @@ class _SttConfig(Coqpit):
     )
 
 
-def initialize_config_globals(c):
-    """
-    input: config class object (i.e. coqpit.Coqpit)
-    """
-
-    # Augmentations
-    c.augmentations = parse_augmentations(c.augment)
-    print(f"Parsed augmentations from flags: {c.augmentations}")
-    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
-        print(
-            "Due to current feature-cache settings the exact same sample augmentations of the first "
-            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
-            "You could use --cache_for_epochs to invalidate the cache after a given number of epochs."
-        )
-
-    if c.normalize_sample_rate:
-        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c[
-            "augmentations"
-        ]
-
-    # Caching
-    if c.cache_for_epochs == 1:
-        print(
-            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
-        )
-
-    # Read-buffer
-    c.read_buffer = parse_file_size(c.read_buffer)
-
-    # Set default dropout rates
-    if c.dropout_rate2 < 0:
-        c.dropout_rate2 = c.dropout_rate
-    if c.dropout_rate3 < 0:
-        c.dropout_rate3 = c.dropout_rate
-    if c.dropout_rate6 < 0:
-        c.dropout_rate6 = c.dropout_rate
-
-    # Set default checkpoint dir
-    if not c.checkpoint_dir:
-        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
-
-    if c.load_train not in ["last", "best", "init", "auto"]:
-        c.load_train = "auto"
-
-    if c.load_evaluate not in ["last", "best", "auto"]:
-        c.load_evaluate = "auto"
-
-    # Set default summary dir
-    if not c.summary_dir:
-        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
-
-    # Standard session configuration that'll be used for all new sessions.
-    c.session_config = tfv1.ConfigProto(
-        allow_soft_placement=True,
-        log_device_placement=c.log_placement,
-        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
-        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
-        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
-    )
-
-    # CPU device
-    c.cpu_device = "/cpu:0"
-
-    # Available GPU devices
-    c.available_devices = get_available_gpus(c.session_config)
-
-    # If there is no GPU available, we fall back to CPU based operation
-    if not c.available_devices:
-        c.available_devices = [c.cpu_device]
-
-    if c.bytes_output_mode:
-        c.alphabet = UTF8Alphabet()
-    elif c.alphabet_config_path:
-        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))
-
-    # Geometric Constants
-    # ===================
-
-    # For an explanation of the meaning of the geometric constants, please refer to
-    # doc/Geometry.md
-
-    # Number of MFCC features
-    c.n_input = 26  # TODO: Determine this programmatically from the sample rate
-
-    # The number of frames in the context
-    c.n_context = 9  # TODO: Determine the optimal value using a validation data set
-
-    # Number of units in hidden layers
-    c.n_hidden = c.n_hidden
-
-    c.n_hidden_1 = c.n_hidden
-
-    c.n_hidden_2 = c.n_hidden
-
-    c.n_hidden_5 = c.n_hidden
-
-    # LSTM cell state dimension
-    c.n_cell_dim = c.n_hidden
-
-    # The number of units in the third layer, which feeds in to the LSTM
-    c.n_hidden_3 = c.n_cell_dim
-
-    # Units in the last layer = number of characters in the alphabet plus one
-    try:
-        # +1 for CTC blank label
-        c.n_hidden_6 = c.alphabet.GetSize() + 1
-    except:
-        AttributeError
-
-    # Size of audio window in samples
-    if (c.feature_win_len * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_len value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_len value or resample your audio accordingly."
-            "".format(c.feature_win_len, c.feature_win_len / 1000, c.audio_sample_rate)
-        )
-        sys.exit(1)
-
-    c.audio_window_samples = c.audio_sample_rate * (c.feature_win_len / 1000)
-
-    # Stride for feature computations in samples
-    if (c.feature_win_step * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_step value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_step value or resample your audio accordingly."
-            "".format(
-                c.feature_win_step, c.feature_win_step / 1000, c.audio_sample_rate
-            )
-        )
-        sys.exit(1)
-
-    c.audio_step_samples = c.audio_sample_rate * (c.feature_win_step / 1000)
-
-    if c.one_shot_infer:
-        if not path_exists_remote(c.one_shot_infer):
-            log_error("Path specified in --one_shot_infer is not a valid file.")
-            sys.exit(1)
-
-    if c.train_cudnn and c.load_cudnn:
-        log_error(
-            "Trying to use --train_cudnn, but --load_cudnn "
-            "was also specified. The --load_cudnn flag is only "
-            "needed when converting a CuDNN RNN checkpoint to "
-            "a CPU-capable graph. If your system is capable of "
-            "using CuDNN RNN, you can just specify the CuDNN RNN "
-            "checkpoint normally with --save_checkpoint_dir."
-        )
-        sys.exit(1)
-
-    # If separate save and load flags were not specified, default to load and save
-    # from the same dir.
-    if not c.save_checkpoint_dir:
-        c.save_checkpoint_dir = c.checkpoint_dir
-
-    if not c.load_checkpoint_dir:
-        c.load_checkpoint_dir = c.checkpoint_dir
+def initialize_globals_from_cli():
+    c = _SttConfig()
+    c.parse_args(arg_prefix="")
+    c.__post_init__()
+    _ConfigSingleton._config = c  # pylint: disable=protected-access
+
+
+def initialize_globals_from_args(**override_args):
+    # Update Config with new args
+    c = _SttConfig(**override_args)
     _ConfigSingleton._config = c  # pylint: disable=protected-access
diff --git a/transcribe.py b/transcribe.py
index b6f5af9a..2792ae2f 100755
--- a/transcribe.py
+++ b/transcribe.py
@@ -20,7 +20,7 @@ from multiprocessing import Process, cpu_count
 
 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from coqui_stt_training.util.audio import AudioFile
-from coqui_stt_training.util.config import Config, _SttConfig, initialize_config_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.feeding import split_audio_file
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import (
@@ -42,8 +42,8 @@ def transcribe_file(audio_path, tlog_path):
     )
     from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
 
-    Config = _SttConfig()
-    initialize_config_globals(Config)
+    initialize_globals_from_cli()
+
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
     try:
         num_processes = cpu_count()
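
Usage sketch (an editor's illustration, not code from the commit): with the
checking logic in _SttConfig.__post_init__(), the two new entry points differ
only in where field values come from. This assumes the coqui_stt_training
package is importable; the CSV path and n_hidden value are hypothetical.

    # Run as e.g.: python example.py --train_files train.csv
    from coqui_stt_training.util.config import (
        Config,
        initialize_globals_from_args,
        initialize_globals_from_cli,
    )

    # CLI path: parse flags from sys.argv, run the __post_init__ checks,
    # and publish the result through the Config singleton.
    initialize_globals_from_cli()
    print(Config.audio_window_samples)  # derived in __post_init__

    # Programmatic path (tests, notebooks): construct the config directly.
    # __post_init__ runs automatically during dataclass construction, so the
    # same checks apply without touching sys.argv; keyword names must match
    # _SttConfig fields.
    initialize_globals_from_args(train_files=["train.csv"], n_hidden=64)
    print(Config.n_hidden_3)  # equals n_hidden; derived in __post_init__

Note the asymmetry the sketch relies on: initialize_globals_from_cli() re-runs
__post_init__() by hand because Coqpit's parse_args() mutates fields after the
dataclass has already been constructed, while initialize_globals_from_args()
passes overrides to the constructor and gets the __post_init__ checks for free.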