diff --git a/bin/run-ci-graph_augmentations.sh b/bin/run-ci-graph_augmentations.sh
index ed01ccb7..85b5661c 100755
--- a/bin/run-ci-graph_augmentations.sh
+++ b/bin/run-ci-graph_augmentations.sh
@@ -14,7 +14,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --scorer "" \
   --augment dropout \
diff --git a/bin/run-ci-ldc93s1_checkpoint.sh b/bin/run-ci-ldc93s1_checkpoint.sh
index 68ebc8bd..a4591215 100755
--- a/bin/run-ci-ldc93s1_checkpoint.sh
+++ b/bin/run-ci-ldc93s1_checkpoint.sh
@@ -14,7 +14,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
diff --git a/bin/run-ci-ldc93s1_checkpoint_sdb.sh b/bin/run-ci-ldc93s1_checkpoint_sdb.sh
index 30708451..81126575 100755
--- a/bin/run-ci-ldc93s1_checkpoint_sdb.sh
+++ b/bin/run-ci-ldc93s1_checkpoint_sdb.sh
@@ -20,7 +20,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} --train_batch_size 1 \
   --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
   --test_files ${ldc93s1_sdb} --test_batch_size 1 \
diff --git a/bin/run-ci-ldc93s1_new.sh b/bin/run-ci-ldc93s1_new.sh
index f67f2765..a0261257 100755
--- a/bin/run-ci-ldc93s1_new.sh
+++ b/bin/run-ci-ldc93s1_new.sh
@@ -17,7 +17,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --feature_cache '/tmp/ldc93s1_cache' \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
diff --git a/bin/run-ci-ldc93s1_new_metrics.sh b/bin/run-ci-ldc93s1_new_metrics.sh
index ee87c6b5..cf31bf22 100755
--- a/bin/run-ci-ldc93s1_new_metrics.sh
+++ b/bin/run-ci-ldc93s1_new_metrics.sh
@@ -17,7 +17,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
diff --git a/bin/run-ci-ldc93s1_new_sdb.sh b/bin/run-ci-ldc93s1_new_sdb.sh
index 0e9a5293..aa26e2c8 100755
--- a/bin/run-ci-ldc93s1_new_sdb.sh
+++ b/bin/run-ci-ldc93s1_new_sdb.sh
@@ -23,7 +23,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} --train_batch_size 1 \
   --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
   --test_files ${ldc93s1_sdb} --test_batch_size 1 \
diff --git a/bin/run-ci-ldc93s1_new_sdb_csv.sh b/bin/run-ci-ldc93s1_new_sdb_csv.sh
index ca8cd388..9f9a185b 100755
--- a/bin/run-ci-ldc93s1_new_sdb_csv.sh
+++ b/bin/run-ci-ldc93s1_new_sdb_csv.sh
@@ -23,7 +23,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} ${ldc93s1_csv} --train_batch_size 1 \
   --feature_cache '/tmp/ldc93s1_cache_sdb_csv' \
   --dev_files ${ldc93s1_sdb} ${ldc93s1_csv} --dev_batch_size 1 \
diff --git a/bin/run-ci-ldc93s1_singleshotinference.sh b/bin/run-ci-ldc93s1_singleshotinference.sh
index 8aaced54..699b09cb 100755
--- a/bin/run-ci-ldc93s1_singleshotinference.sh
+++ b/bin/run-ci-ldc93s1_singleshotinference.sh
@@ -14,7 +14,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
@@ -23,7 +24,7 @@ python -u train.py --show_progressbar false --early_stop false \
   --learning_rate 0.001 --dropout_rate 0.05 \
   --scorer_path 'data/smoke_test/pruned_lm.scorer'

-python -u train.py \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --scorer_path 'data/smoke_test/pruned_lm.scorer' \
diff --git a/bin/run-ci-ldc93s1_tflite.sh b/bin/run-ci-ldc93s1_tflite.sh
index 0156d969..66342472 100755
--- a/bin/run-ci-ldc93s1_tflite.sh
+++ b/bin/run-ci-ldc93s1_tflite.sh
@@ -16,7 +16,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --export_dir '/tmp/train_tflite' \
@@ -26,7 +27,8 @@ python -u train.py --show_progressbar false \

 mkdir /tmp/train_tflite/en-us

-python -u train.py --show_progressbar false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --export_dir '/tmp/train_tflite/en-us' \
diff --git a/bin/run-ldc93s1.py b/bin/run-ldc93s1.py
new file mode 100755
index 00000000..e266b7a4
--- /dev/null
+++ b/bin/run-ldc93s1.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+import os
+from import_ldc93s1 import _download_and_preprocess_data as download_ldc
+from coqui_stt_training.util.config import initialize_globals_from_args
+from coqui_stt_training.train import train, test, early_training_checks
+import tensorflow.compat.v1 as tfv1
+
+# only one GPU for only one training sample
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+download_ldc("data/ldc93s1")
+
+initialize_globals_from_args(
+    load_train="init",
+    alphabet_config_path="data/alphabet.txt",
+    train_files=["data/ldc93s1/ldc93s1.csv"],
+    dev_files=["data/ldc93s1/ldc93s1.csv"],
+    test_files=["data/ldc93s1/ldc93s1.csv"],
+    augment=["time_mask"],
+    n_hidden=100,
+    epochs=200,
+)
+
+early_training_checks()
+
+train()
+tfv1.reset_default_graph()
+test()
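The new `bin/run-ldc93s1.py` above doubles as the reference example for the programmatic API: every keyword accepted by `initialize_globals_from_args()` is simply a `_SttConfig` field. As a further illustration, a minimal evaluation-only sketch along the same lines (hypothetical and not part of this changeset; it assumes `/tmp/ckpt` already holds a trained checkpoint with matching geometry):

```python
from coqui_stt_training.train import test
from coqui_stt_training.util.config import initialize_globals_from_args

initialize_globals_from_args(
    alphabet_config_path="data/alphabet.txt",  # no longer defaulted, must be set
    checkpoint_dir="/tmp/ckpt",                # assumed existing checkpoint
    test_files=["data/ldc93s1/ldc93s1.csv"],
    n_hidden=100,                              # must match the checkpoint
)
test()
```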
diff --git a/bin/run-ldc93s1.sh b/bin/run-ldc93s1.sh
index fdf34609..8fe87e87 100755
--- a/bin/run-ldc93s1.sh
+++ b/bin/run-ldc93s1.sh
@@ -20,7 +20,8 @@ fi
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0

-python -u train.py --show_progressbar false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false \
   --train_files data/ldc93s1/ldc93s1.csv \
   --test_files data/ldc93s1/ldc93s1.csv \
   --train_batch_size 1 \
diff --git a/ci_scripts/train-extra-tests.sh b/ci_scripts/train-extra-tests.sh
index 1f76a0ed..50265afc 100755
--- a/ci_scripts/train-extra-tests.sh
+++ b/ci_scripts/train-extra-tests.sh
@@ -66,3 +66,6 @@ time ./bin/run-ci-ldc93s1_checkpoint_sdb.sh

 # Bytes output mode, resuming from checkpoint
 time ./bin/run-ci-ldc93s1_checkpoint_bytes.sh
+
+# Training with args set via initialize_globals_from_args()
+time python ./bin/run-ldc93s1.py
diff --git a/lm_optimizer.py b/lm_optimizer.py
index 86100ec4..85ca1fd5 100644
--- a/lm_optimizer.py
+++ b/lm_optimizer.py
@@ -10,7 +10,7 @@ import tensorflow.compat.v1 as tfv1
 from coqui_stt_ctcdecoder import Scorer
 from coqui_stt_training.evaluate import evaluate
 from coqui_stt_training.train import create_model
-from coqui_stt_training.util.config import Config, initialize_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.evaluate_tools import wer_cer_batch
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import log_error
@@ -52,7 +52,7 @@ def objective(trial):


 def main(_):
-    initialize_globals()
+    initialize_globals_from_cli()

     if not FLAGS.test_files:
         log_error(
diff --git a/training/coqui_stt_training/evaluate.py b/training/coqui_stt_training/evaluate.py
index ecff4502..4425582b 100755
--- a/training/coqui_stt_training/evaluate.py
+++ b/training/coqui_stt_training/evaluate.py
@@ -18,7 +18,7 @@ from .util.checkpoints import load_graph_for_evaluation
 from .util.config import (
     Config,
     create_progressbar,
-    initialize_globals,
+    initialize_globals_from_cli,
     log_error,
     log_progress,
 )
@@ -169,7 +169,7 @@ def evaluate(test_csvs, create_model):


 def main():
-    initialize_globals()
+    initialize_globals_from_cli()

     if not Config.test_files:
         log_error(
diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py
index 2cf02397..b6cfaec9 100644
--- a/training/coqui_stt_training/train.py
+++ b/training/coqui_stt_training/train.py
@@ -45,7 +45,7 @@ from .util.checkpoints import (
 from .util.config import (
     Config,
     create_progressbar,
-    initialize_globals,
+    initialize_globals_from_cli,
     log_debug,
     log_error,
     log_info,
@@ -1249,7 +1249,7 @@ def early_training_checks():


 def main():
-    initialize_globals()
+    initialize_globals_from_cli()
     early_training_checks()

     if Config.train_files:
diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py
index e0c925b9..9427c382 100755
--- a/training/coqui_stt_training/util/config.py
+++ b/training/coqui_stt_training/util/config.py
@@ -36,10 +36,201 @@ Config = _ConfigSingleton()  # pylint: disable=invalid-name

 @dataclass
 class _SttConfig(Coqpit):
+    def __post_init__(self):
+        # Augmentations
+        self.augmentations = parse_augmentations(self.augment)
+        if self.augmentations:
+            print(f"Parsed augmentations: {self.augmentations}")
+        if self.augmentations and self.feature_cache and self.cache_for_epochs == 0:
+            print(
+                "Due to your feature-cache settings, augmentations of "
+                "the first epoch will be repeated on all following epochs. "
+                "This may lead to unintended over-fitting. "
+                "You can use --cache_for_epochs to invalidate "
+                "the cache after a given number of epochs."
+            )
+
+        if self.normalize_sample_rate:
+            self.augmentations = [
+                NormalizeSampleRate(self.audio_sample_rate)
+            ] + self.augmentations
+
+        # Caching
+        if self.cache_for_epochs == 1:
+            print(
+                "--cache_for_epochs == 1 is (re-)creating the feature cache "
+                "on every epoch but will never use it. You can either set "
+                "--cache_for_epochs > 1, or not use feature caching at all."
+            )
+
+        # Read-buffer
+        self.read_buffer = parse_file_size(self.read_buffer)
+
+        # Set default dropout rates
+        if self.dropout_rate2 < 0:
+            self.dropout_rate2 = self.dropout_rate
+        if self.dropout_rate3 < 0:
+            self.dropout_rate3 = self.dropout_rate
+        if self.dropout_rate6 < 0:
+            self.dropout_rate6 = self.dropout_rate
+
+        # Set default checkpoint dir
+        # If separate save and load dirs were not specified, default both to
+        # checkpoint_dir, falling back to the XDG data dir if that is unset too.
+        if not self.save_checkpoint_dir:
+            if not self.checkpoint_dir:
+                self.checkpoint_dir = xdg.save_data_path(
+                    os.path.join("stt", "checkpoints")
+                )
+            self.save_checkpoint_dir = self.checkpoint_dir
+
+        if not self.load_checkpoint_dir:
+            if not self.checkpoint_dir:
+                self.checkpoint_dir = xdg.save_data_path(
+                    os.path.join("stt", "checkpoints")
+                )
+            self.load_checkpoint_dir = self.checkpoint_dir
+
+        if self.load_train not in ["last", "best", "init", "auto"]:
+            self.load_train = "auto"
+
+        if self.load_evaluate not in ["last", "best", "auto"]:
+            self.load_evaluate = "auto"
+
+        # Set default summary dir
+        if not self.summary_dir:
+            self.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
+
+        # Standard session configuration that'll be used for all new sessions.
+        self.session_config = tfv1.ConfigProto(
+            allow_soft_placement=True,
+            log_device_placement=self.log_placement,
+            inter_op_parallelism_threads=self.inter_op_parallelism_threads,
+            intra_op_parallelism_threads=self.intra_op_parallelism_threads,
+            gpu_options=tfv1.GPUOptions(allow_growth=self.use_allow_growth),
+        )
+
+        # CPU device
+        self.cpu_device = "/cpu:0"
+
+        # Available GPU devices
+        self.available_devices = get_available_gpus(self.session_config)
+
+        # If there is no GPU available, we fall back to CPU based operation
+        if not self.available_devices:
+            self.available_devices = [self.cpu_device]
+
+        if self.bytes_output_mode and self.alphabet_config_path:
+            raise RuntimeError(
+                "You cannot set --alphabet_config_path *and* --bytes_output_mode"
+            )
+        elif self.bytes_output_mode:
+            self.alphabet = UTF8Alphabet()
+        elif self.alphabet_config_path:
+            self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
+
+        # Geometric Constants
+        # ===================
+
+        # For an explanation of the meaning of the geometric constants
+        # please refer to doc/Geometry.md
+
+        # Number of MFCC features
+        self.n_input = 26  # TODO: Determine this programmatically from the sample rate
+
+        # The number of frames in the context
+        self.n_context = 9  # TODO: Determine the optimal value using a validation data set
+
+        # Number of units in hidden layers
+        self.n_hidden_1 = self.n_hidden
+
+        self.n_hidden_2 = self.n_hidden
+
+        self.n_hidden_5 = self.n_hidden
+
+        # LSTM cell state dimension
+        self.n_cell_dim = self.n_hidden
+
+        # The number of units in the third layer, which feeds in to the LSTM
+        self.n_hidden_3 = self.n_cell_dim
+
+        # Dims in last layer = number of characters in alphabet plus one
+        try:
+            # +1 for CTC blank label
+            self.n_hidden_6 = self.alphabet.GetSize() + 1
+        except AttributeError:
+            # No alphabet is configured (no --alphabet_config_path and no
+            # --bytes_output_mode); n_hidden_6 is set once one is.
+            pass
+
+        # Size of audio window in samples
+        if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0:
+            log_error(
+                "--feature_win_len value ({}) in milliseconds ({}) multiplied "
+                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
+                "your --feature_win_len value or resample your audio accordingly."
+                "".format(
+                    self.feature_win_len,
+                    self.feature_win_len / 1000,
+                    self.audio_sample_rate,
+                )
+            )
+            sys.exit(1)
+
+        self.audio_window_samples = self.audio_sample_rate * (
+            self.feature_win_len / 1000
+        )
+
+        # Stride for feature computations in samples
+        if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
+            log_error(
+                "--feature_win_step value ({}) in milliseconds ({}) multiplied "
+                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
+                "your --feature_win_step value or resample your audio accordingly."
+                "".format(
+                    self.feature_win_step,
+                    self.feature_win_step / 1000,
+                    self.audio_sample_rate,
+                )
+            )
+            sys.exit(1)
+
+        self.audio_step_samples = self.audio_sample_rate * (
+            self.feature_win_step / 1000
+        )
+
+        if self.one_shot_infer:
+            if not path_exists_remote(self.one_shot_infer):
+                log_error("Path specified in --one_shot_infer is not a valid file.")
+                sys.exit(1)
+
+        if self.train_cudnn and self.load_cudnn:
+            log_error(
+                "Trying to use --train_cudnn, but --load_cudnn "
+                "was also specified. The --load_cudnn flag is only "
+                "needed when converting a CuDNN RNN checkpoint to "
+                "a CPU-capable graph. If your system is capable of "
+                "using CuDNN RNN, you can just specify the CuDNN RNN "
+                "checkpoint normally with --save_checkpoint_dir."
+            )
+            sys.exit(1)
+
     train_files: List[str] = field(
         default_factory=list,
         metadata=dict(
-            help="space-separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run."
+            help="space-separated list of files specifying the datasets used for training. Multiple files will get merged. If empty, training will not be run."
         ),
     )
     dev_files: List[str] = field(
@@ -472,7 +663,7 @@ class _SttConfig(Coqpit):
         ),
     )
     alphabet_config_path: str = field(
-        default="data/alphabet.txt",
+        default="",
         metadata=dict(
             help="path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format."
         ),
     )
@@ -540,166 +731,17 @@ class _SttConfig(Coqpit):
         ),
     )

-    def check_values(self):
-        c = asdict(self)
-        check_argument("alphabet_config_path", c, is_path=True)
-        check_argument("one_shot_infer", c, is_path=True)
-

-def initialize_globals():
+def initialize_globals_from_cli():
     c = _SttConfig()
     c.parse_args(arg_prefix="")
+    c.__post_init__()
+    _ConfigSingleton._config = c  # pylint: disable=protected-access

-    # Augmentations
-    c.augmentations = parse_augmentations(c.augment)
-    print(f"Parsed augmentations from flags: {c.augmentations}")
-    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
-        print(
-            "Due to current feature-cache settings the exact same sample augmentations of the first "
-            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
-            "You could use --cache_for_epochs to invalidate the cache after a given number of epochs."
-        )
-
-    if c.normalize_sample_rate:
-        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c[
-            "augmentations"
-        ]
-
-    # Caching
-    if c.cache_for_epochs == 1:
-        print(
-            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
-        )
-
-    # Read-buffer
-    c.read_buffer = parse_file_size(c.read_buffer)
-
-    # Set default dropout rates
-    if c.dropout_rate2 < 0:
-        c.dropout_rate2 = c.dropout_rate
-    if c.dropout_rate3 < 0:
-        c.dropout_rate3 = c.dropout_rate
-    if c.dropout_rate6 < 0:
-        c.dropout_rate6 = c.dropout_rate
-
-    # Set default checkpoint dir
-    if not c.checkpoint_dir:
-        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
-
-    if c.load_train not in ["last", "best", "init", "auto"]:
-        c.load_train = "auto"
-
-    if c.load_evaluate not in ["last", "best", "auto"]:
-        c.load_evaluate = "auto"
-
-    # Set default summary dir
-    if not c.summary_dir:
-        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
-
-    # Standard session configuration that'll be used for all new sessions.
-    c.session_config = tfv1.ConfigProto(
-        allow_soft_placement=True,
-        log_device_placement=c.log_placement,
-        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
-        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
-        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
-    )
-
-    # CPU device
-    c.cpu_device = "/cpu:0"
-
-    # Available GPU devices
-    c.available_devices = get_available_gpus(c.session_config)
-
-    # If there is no GPU available, we fall back to CPU based operation
-    if not c.available_devices:
-        c.available_devices = [c.cpu_device]
-
-    if c.bytes_output_mode:
-        c.alphabet = UTF8Alphabet()
-    else:
-        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))
-
-    # Geometric Constants
-    # ===================
-
-    # For an explanation of the meaning of the geometric constants, please refer to
-    # doc/Geometry.md
-
-    # Number of MFCC features
-    c.n_input = 26  # TODO: Determine this programmatically from the sample rate
-
-    # The number of frames in the context
-    c.n_context = 9  # TODO: Determine the optimal value using a validation data set
-
-    # Number of units in hidden layers
-    c.n_hidden = c.n_hidden
-
-    c.n_hidden_1 = c.n_hidden
-
-    c.n_hidden_2 = c.n_hidden
-
-    c.n_hidden_5 = c.n_hidden
-
-    # LSTM cell state dimension
-    c.n_cell_dim = c.n_hidden
-
-    # The number of units in the third layer, which feeds in to the LSTM
-    c.n_hidden_3 = c.n_cell_dim
-
-    # Units in the sixth layer = number of characters in the target language plus one
-    c.n_hidden_6 = c.alphabet.GetSize() + 1  # +1 for CTC blank label
-
-    # Size of audio window in samples
-    if (c.feature_win_len * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_len value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_len value or resample your audio accordingly."
-            "".format(c.feature_win_len, c.feature_win_len / 1000, c.audio_sample_rate)
-        )
-        sys.exit(1)
-
-    c.audio_window_samples = c.audio_sample_rate * (c.feature_win_len / 1000)
-
-    # Stride for feature computations in samples
-    if (c.feature_win_step * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_step value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_step value or resample your audio accordingly."
-            "".format(
-                c.feature_win_step, c.feature_win_step / 1000, c.audio_sample_rate
-            )
-        )
-        sys.exit(1)
-
-    c.audio_step_samples = c.audio_sample_rate * (c.feature_win_step / 1000)
-
-    if c.one_shot_infer:
-        if not path_exists_remote(c.one_shot_infer):
-            log_error("Path specified in --one_shot_infer is not a valid file.")
-            sys.exit(1)
-
-    if c.train_cudnn and c.load_cudnn:
-        log_error(
-            "Trying to use --train_cudnn, but --load_cudnn "
-            "was also specified. The --load_cudnn flag is only "
-            "needed when converting a CuDNN RNN checkpoint to "
-            "a CPU-capable graph. If your system is capable of "
-            "using CuDNN RNN, you can just specify the CuDNN RNN "
-            "checkpoint normally with --save_checkpoint_dir."
-        )
-        sys.exit(1)
-
-    # If separate save and load flags were not specified, default to load and save
-    # from the same dir.
-    if not c.save_checkpoint_dir:
-        c.save_checkpoint_dir = c.checkpoint_dir
-
-    if not c.load_checkpoint_dir:
-        c.load_checkpoint_dir = c.checkpoint_dir
+
+def initialize_globals_from_args(**override_args):
+    # Update Config with new args
+    c = _SttConfig(**override_args)
     _ConfigSingleton._config = c  # pylint: disable=protected-access
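Note the asymmetry between the two entry points: `_SttConfig` is a `@dataclass`, so constructing it as `_SttConfig(**override_args)` runs `__post_init__` automatically, whereas the CLI path fills in fields via coqpit's `parse_args()` *after* construction and therefore has to re-run `__post_init__()` by hand. A rough sketch of the two paths:

```python
# CLI path: parse_args() overwrites fields after __init__, so derived
# values (alphabet, session_config, checkpoint dirs, ...) must be
# recomputed explicitly via the second __post_init__() call:
initialize_globals_from_cli()  # reads flags from sys.argv

# Programmatic path: kwargs flow through the generated dataclass
# __init__, which already invokes __post_init__ on its own:
initialize_globals_from_args(alphabet_config_path="data/alphabet.txt")
```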
diff --git a/training/coqui_stt_training/util/helpers.py b/training/coqui_stt_training/util/helpers.py
index c8aa788a..b897e4a9 100644
--- a/training/coqui_stt_training/util/helpers.py
+++ b/training/coqui_stt_training/util/helpers.py
@@ -19,14 +19,19 @@ ValueRange = namedtuple("ValueRange", "start end r")


 def parse_file_size(file_size):
-    file_size = file_size.lower().strip()
-    if len(file_size) == 0:
-        return 0
-    n = int(keep_only_digits(file_size))
-    if file_size[-1] == "b":
-        file_size = file_size[:-1]
-    e = file_size[-1]
-    return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
+    if isinstance(file_size, str):
+        file_size = file_size.lower().strip()
+        if len(file_size) == 0:
+            return 0
+        n = int(keep_only_digits(file_size))
+        if file_size[-1] == "b":
+            file_size = file_size[:-1]
+        e = file_size[-1]
+        return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
+    elif isinstance(file_size, int):
+        return file_size
+    else:
+        raise ValueError("file_size not of type 'int' or 'str'")


 def keep_only_digits(txt):
diff --git a/transcribe.py b/transcribe.py
index b0492c87..2792ae2f 100755
--- a/transcribe.py
+++ b/transcribe.py
@@ -20,7 +20,7 @@ from multiprocessing import Process, cpu_count

 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from coqui_stt_training.util.audio import AudioFile
-from coqui_stt_training.util.config import Config, initialize_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.feeding import split_audio_file
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import (
@@ -42,7 +42,8 @@ def transcribe_file(audio_path, tlog_path):
     )
     from coqui_stt_training.util.checkpoints import load_graph_for_evaluation

-    initialize_globals()
+    initialize_globals_from_cli()
+
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
     try:
         num_processes = cpu_count()
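The `parse_file_size` change is what lets `read_buffer` be passed as a plain integer through `initialize_globals_from_args()`: `__post_init__` calls it unconditionally, and flag input arrives as a string while programmatic input may already be a number. A quick sketch of the resulting behavior:

```python
from coqui_stt_training.util.helpers import parse_file_size

parse_file_size("500kb")   # str: digits scaled by the "k" entry of SIZE_PREFIX_LOOKUP
parse_file_size("")        # str: empty input yields 0
parse_file_size(1048576)   # int: returned unchanged (the new branch)
parse_file_size(1.5)       # any other type raises ValueError
```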