diff --git a/bin/run-ldc93s1.sh b/bin/run-ldc93s1.sh index 8fe87e87..2bd80c59 100755 --- a/bin/run-ldc93s1.sh +++ b/bin/run-ldc93s1.sh @@ -20,7 +20,8 @@ fi # and when trying to run on multiple devices (like GPUs), this will break export CUDA_VISIBLE_DEVICES=0 -python -u train.py --alphabet_config_path "data/alphabet.txt" \ +python -m coqui_stt_training.train \ + --alphabet_config_path "data/alphabet.txt" \ --show_progressbar false \ --train_files data/ldc93s1/ldc93s1.csv \ --test_files data/ldc93s1/ldc93s1.csv \ diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 98f9e407..d396fd7f 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -266,11 +266,9 @@ def early_training_checks(): ) -def create_training_datasets() -> ( - tf.data.Dataset, - [tf.data.Dataset], - [tf.data.Dataset], -): +def create_training_datasets( + epoch_ph: tf.Tensor = None, +) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],): """Creates training datasets from input flags. Returns a single training dataset and two lists of datasets for validation @@ -288,6 +286,7 @@ def create_training_datasets() -> ( reverse=Config.reverse_train, limit=Config.limit_train, buffering=Config.read_buffer, + epoch_ph=epoch_ph, ) dev_sets = [] @@ -331,7 +330,8 @@ def train(): tfv1.reset_default_graph() tfv1.set_random_seed(Config.random_seed) - train_set, dev_sets, metrics_sets = create_training_datasets() + epoch_ph = tf.placeholder(tf.int64, name="epoch_ph") + train_set, dev_sets, metrics_sets = create_training_datasets(epoch_ph) iterator = tfv1.data.Iterator.from_structure( tfv1.data.get_output_types(train_set), @@ -488,7 +488,7 @@ def train(): ).start() # Initialize iterator to the appropriate dataset - session.run(init_op) + session.run(init_op, {epoch_ph: epoch}) # Batch loop while True: @@ -507,7 +507,7 @@ def train(): non_finite_files, step_summaries_op, ], - feed_dict=feed_dict, + feed_dict={**feed_dict, **{epoch_ph: epoch}}, ) except tf.errors.OutOfRangeError: break diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index 85493d4b..4ddc62b2 100644 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -340,6 +340,22 @@ class _SttConfig(Coqpit): help='after how many epochs the feature cache is invalidated again - 0 for "never"' ), ) + shuffle_batches: bool = field( + default=False, + metadata=dict( + help="reshuffle batches every epoch, starting after N epochs, where N is set by the shuffle_start flag." + ), + ) + shuffle_start: int = field( + default=1, + metadata=dict(help="epoch to start shuffling batches from (zero-based)."), + ) + shuffle_buffer: int = field( + default=1000, + metadata=dict( + help="how many batches to keep in shuffle buffer when shuffling batches." + ), + ) feature_win_len: int = field( default=32, diff --git a/training/coqui_stt_training/util/feeding.py b/training/coqui_stt_training/util/feeding.py index 80ff0c20..bf506375 100644 --- a/training/coqui_stt_training/util/feeding.py +++ b/training/coqui_stt_training/util/feeding.py @@ -140,6 +140,7 @@ def create_dataset( limit=0, process_ahead=None, buffering=1 * MEGABYTE, + epoch_ph=None, ): epoch_counter = Counter() # survives restarts of the dataset and its generator @@ -207,11 +208,18 @@ def create_dataset( ).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) if cache_path: dataset = dataset.cache(cache_path) - dataset = ( - dataset.window(batch_size, drop_remainder=train_phase) - .flat_map(batch_fn) - .prefetch(len(Config.available_devices)) - ) + dataset = dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn) + + if Config.shuffle_batches and epoch_ph is not None: + with tf.control_dependencies([tf.print("epoch:", epoch_ph)]): + epoch_buffer_size = tf.cond( + tf.less(epoch_ph, Config.shuffle_start), + lambda: tf.constant(1, tf.int64), + lambda: tf.constant(Config.shuffle_buffer, tf.int64), + ) + dataset = dataset.shuffle(epoch_buffer_size, seed=epoch_ph) + + dataset = dataset.prefetch(len(Config.available_devices)) return dataset