diff --git a/bin/run-ldc93s1.py b/bin/run-ldc93s1.py index 610d20ba..e266b7a4 100755 --- a/bin/run-ldc93s1.py +++ b/bin/run-ldc93s1.py @@ -6,7 +6,7 @@ from coqui_stt_training.train import train, test, early_training_checks import tensorflow.compat.v1 as tfv1 # only one GPU for only one training sample -os.environ['CUDA_VISIBLE_DEVICES']='0' +os.environ["CUDA_VISIBLE_DEVICES"] = "0" download_ldc("data/ldc93s1") @@ -18,7 +18,7 @@ initialize_globals_from_args( test_files=["data/ldc93s1/ldc93s1.csv"], augment=["time_mask"], n_hidden=100, - epochs=200 + epochs=200, ) early_training_checks() diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index d7a19ffa..9427c382 100755 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -79,16 +79,24 @@ class _SttConfig(Coqpit): # from the same dir. # if save_checkpoint_dir hasn't been set, or checkpoint_dir is new - if (not self.save_checkpoint_dir) or (self.save_checkpoint_dir is not self.checkpoint_dir): + if (not self.save_checkpoint_dir) or ( + self.save_checkpoint_dir is not self.checkpoint_dir + ): if not self.checkpoint_dir: - self.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints")) + self.checkpoint_dir = xdg.save_data_path( + os.path.join("stt", "checkpoints") + ) self.save_checkpoint_dir = self.checkpoint_dir else: self.save_checkpoint_dir = self.checkpoint_dir # if load_checkpoint_dir hasn't been set, or checkpoint_dir is new - if (not self.load_checkpoint_dir) or (self.load_checkpoint_dir is not self.checkpoint_dir): + if (not self.load_checkpoint_dir) or ( + self.load_checkpoint_dir is not self.checkpoint_dir + ): if not self.checkpoint_dir: - self.checkpoint_dir = xdg.load_data_path(os.path.join("stt", "checkpoints")) + self.checkpoint_dir = xdg.load_data_path( + os.path.join("stt", "checkpoints") + ) self.load_checkpoint_dir = self.checkpoint_dir else: self.load_checkpoint_dir = self.checkpoint_dir @@ -138,10 +146,12 @@ class _SttConfig(Coqpit): # please refer to doc/Geometry.md # Number of MFCC features - self.n_input = 26 # TODO: Determine this programmatically from the sample rate + self.n_input = 26 # TODO: Determine this programmatically from the sample rate # The number of frames in the context - self.n_context = 9 # TODO: Determine the optimal value using a validation data set + self.n_context = ( + 9 # TODO: Determine the optimal value using a validation data set + ) # Number of units in hidden layers self.n_hidden = self.n_hidden @@ -171,11 +181,17 @@ class _SttConfig(Coqpit): "--feature_win_len value ({}) in milliseconds ({}) multiplied " "by --audio_sample_rate value ({}) must be an integer value. Adjust " "your --feature_win_len value or resample your audio accordingly." - "".format(self.feature_win_len, self.feature_win_len / 1000, self.audio_sample_rate) + "".format( + self.feature_win_len, + self.feature_win_len / 1000, + self.audio_sample_rate, + ) ) sys.exit(1) - self.audio_window_samples = self.audio_sample_rate * (self.feature_win_len / 1000) + self.audio_window_samples = self.audio_sample_rate * ( + self.feature_win_len / 1000 + ) # Stride for feature computations in samples if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0: @@ -184,12 +200,16 @@ class _SttConfig(Coqpit): "by --audio_sample_rate value ({}) must be an integer value. Adjust " "your --feature_win_step value or resample your audio accordingly." "".format( - self.feature_win_step, self.feature_win_step / 1000, self.audio_sample_rate + self.feature_win_step, + self.feature_win_step / 1000, + self.audio_sample_rate, ) ) sys.exit(1) - self.audio_step_samples = self.audio_sample_rate * (self.feature_win_step / 1000) + self.audio_step_samples = self.audio_sample_rate * ( + self.feature_win_step / 1000 + ) if self.one_shot_infer: if not path_exists_remote(self.one_shot_infer): @@ -718,6 +738,7 @@ def initialize_globals_from_cli(): c.__post_init__() _ConfigSingleton._config = c # pylint: disable=protected-access + def initialize_globals_from_args(**override_args): # Update Config with new args c = _SttConfig(**override_args)