# Download LDC data

import os
import sys

# NOTE(review): `download_and_preprocess_data` is defined between the imports
# above and the call below. That definition lies outside the hunks visible in
# this patch, so it is not reproduced here -- keep it unchanged.

download_and_preprocess_data('/home/STT/data')

# Train

from STT.training.coqui_stt_training.util.config import _SttConfig, _ConfigSingleton
from STT.training.coqui_stt_training.util.augmentations import parse_augmentations, NormalizeSampleRate
from STT.training.coqui_stt_training.util.helpers import parse_file_size
from STT.training.coqui_stt_training.util.gpu import get_available_gpus
# NOTE(review): the next two imports supply names the function below always
# used but never imported (`path_exists_remote`, `log_error`). The module
# paths are assumed from the package layout of the imports above -- confirm
# against the installed STT checkout.
from STT.training.coqui_stt_training.util.io import path_exists_remote
from STT.training.coqui_stt_training.util.logging import log_error
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
from xdg import BaseDirectory as xdg
import tensorflow.compat.v1 as tfv1


def _window_in_samples(c, flag_name, milliseconds):
    """Convert a feature-window duration in milliseconds to a sample count.

    Logs an error and exits with status 1 when ``milliseconds`` multiplied by
    ``c.audio_sample_rate`` is not a multiple of 1000. ``flag_name`` is the
    command-line flag name (without dashes) quoted in the error message.
    """
    if (milliseconds * c.audio_sample_rate) % 1000 != 0:
        log_error(
            "--{flag} value ({ms}) in milliseconds ({sec}) multiplied "
            "by --audio_sample_rate value ({rate}) must be an integer value. Adjust "
            "your --{flag} value or resample your audio accordingly."
            "".format(
                flag=flag_name,
                ms=milliseconds,
                sec=milliseconds / 1000,
                rate=c.audio_sample_rate,
            )
        )
        sys.exit(1)
    return c.audio_sample_rate * (milliseconds / 1000)


def initialize_globals(c):
    """Fill in the derived settings on ``c`` and publish it as the global config.

    Reads the flag-style attributes of ``c`` and sets, in place: parsed
    augmentations, the parsed read-buffer size, default dropout rates,
    checkpoint/summary directories, the TensorFlow session config, the device
    list, the alphabet, layer sizes and feature-window sizes. Finally stores
    ``c`` in ``_ConfigSingleton._config``.

    Args:
        c: Configuration object; this script passes an ``_SttConfig`` instance.

    Raises:
        SystemExit: with status 1, after logging an error, when a feature
            window does not map to a whole number of samples, when
            ``c.one_shot_infer`` is set but the path does not exist, or when
            both ``c.train_cudnn`` and ``c.load_cudnn`` are set.
    """
    # Augmentations
    c.augmentations = parse_augmentations(c.augment)
    print(f"Parsed augmentations from flags: {c.augmentations}")
    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
        print(
            "Due to current feature-cache settings the exact same sample augmentations of the first "
            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
            "You could use --cache_for_epochs to invalidate the cache after a given number of epochs."
        )
    if c.normalize_sample_rate:
        # Sample-rate normalization is placed ahead of every other augmentation.
        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c.augmentations

    # Caching
    if c.cache_for_epochs == 1:
        print(
            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
        )

    # Read-buffer
    c.read_buffer = parse_file_size(c.read_buffer)

    # A negative per-layer dropout rate means "inherit the global dropout rate".
    for rate_name in ("dropout_rate2", "dropout_rate3", "dropout_rate6"):
        if getattr(c, rate_name) < 0:
            setattr(c, rate_name, c.dropout_rate)

    # Set default checkpoint dir
    if not c.checkpoint_dir:
        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))

    # Unrecognized load modes fall back to "auto".
    if c.load_train not in ["last", "best", "init", "auto"]:
        c.load_train = "auto"

    if c.load_evaluate not in ["last", "best", "auto"]:
        c.load_evaluate = "auto"

    # Set default summary dir
    if not c.summary_dir:
        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tfv1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=c.log_placement,
        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
    )

    # CPU device
    c.cpu_device = "/cpu:0"

    # Available GPU devices
    c.available_devices = get_available_gpus(c.session_config)

    # If there is no GPU available, we fall back to CPU based operation
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    # Alphabet: bytes-output mode always wins. Otherwise an alphabet the
    # caller already attached to `c` is kept, and the alphabet file is loaded
    # only when none was attached. (Previously `alphabet_config_path` was
    # force-reset to "" right before being tested, so the file branch could
    # never run and the caller's setting was clobbered.)
    if c.bytes_output_mode:
        c.alphabet = UTF8Alphabet()
    elif getattr(c, "alphabet", None) is None and c.alphabet_config_path:
        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26  # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers (all taken from c.n_hidden)
    c.n_hidden_1 = c.n_hidden
    c.n_hidden_2 = c.n_hidden
    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one.
    # Skipped when no alphabet is attached. Real failures inside GetSize() now
    # propagate instead of being swallowed by a bare `except`.
    alphabet = getattr(c, "alphabet", None)
    if alphabet is not None:
        c.n_hidden_6 = alphabet.GetSize() + 1  # +1 for CTC blank label

    # Size of audio window in samples
    c.audio_window_samples = _window_in_samples(c, "feature_win_len", c.feature_win_len)

    # Stride for feature computations in samples
    c.audio_step_samples = _window_in_samples(c, "feature_win_step", c.feature_win_step)

    if c.one_shot_infer:
        if not path_exists_remote(c.one_shot_infer):
            log_error("Path specified in --one_shot_infer is not a valid file.")
            sys.exit(1)

    if c.train_cudnn and c.load_cudnn:
        log_error(
            "Trying to use --train_cudnn, but --load_cudnn "
            "was also specified. The --load_cudnn flag is only "
            "needed when converting a CuDNN RNN checkpoint to "
            "a CPU-capable graph. If your system is capable of "
            "using CuDNN RNN, you can just specify the CuDNN RNN "
            "checkpoint normally with --save_checkpoint_dir."
        )
        sys.exit(1)

    # If separate save and load flags were not specified, default to load and save
    # from the same dir.
    if not c.save_checkpoint_dir:
        c.save_checkpoint_dir = c.checkpoint_dir

    if not c.load_checkpoint_dir:
        c.load_checkpoint_dir = c.checkpoint_dir

    _ConfigSingleton._config = c  # pylint: disable=protected-access


# Kept below the function, as in the original script, so import order is unchanged.
from STT.training.coqui_stt_training.train import train, test, early_training_checks

Config = _SttConfig()

# The alphabet is attached directly; initialize_globals() keeps it as-is.
Config.alphabet = Alphabet('/home/STT/data/alphabet.txt')
Config.train_files = ['/home/STT/data/ldc93s1.csv']
Config.dev_files = ['/home/STT/data/ldc93s1.csv']
Config.test_files = ['/home/STT/data/ldc93s1.csv']
Config.n_hidden = 100
Config.epochs = 200

initialize_globals(Config)

# Debugging aid: uncomment to dump the resolved configuration.
#print(Config.to_json())

early_training_checks()
train()
# Start from a fresh default graph before running the test pass.
tfv1.reset_default_graph()
test()