Merge pull request #1967 from coqui-ai/batch-shuffling

Add support for shuffling batches after N epochs (Fixes #1901)
Reuben Morais 2021-09-16 11:05:06 +02:00 committed by GitHub
commit 835d657648
4 changed files with 39 additions and 14 deletions


@@ -20,7 +20,8 @@ fi
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
-python -u train.py --alphabet_config_path "data/alphabet.txt" \
+python -m coqui_stt_training.train \
+  --alphabet_config_path "data/alphabet.txt" \
   --show_progressbar false \
   --train_files data/ldc93s1/ldc93s1.csv \
   --test_files data/ldc93s1/ldc93s1.csv \


@@ -266,11 +266,9 @@ def early_training_checks():
     )
 
 
-def create_training_datasets() -> (
-    tf.data.Dataset,
-    [tf.data.Dataset],
-    [tf.data.Dataset],
-):
+def create_training_datasets(
+    epoch_ph: tf.Tensor = None,
+) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
     """Creates training datasets from input flags.
 
     Returns a single training dataset and two lists of datasets for validation
@@ -288,6 +286,7 @@ def create_training_datasets() -> (
         reverse=Config.reverse_train,
         limit=Config.limit_train,
         buffering=Config.read_buffer,
+        epoch_ph=epoch_ph,
     )
 
     dev_sets = []
@@ -331,7 +330,8 @@ def train():
     tfv1.reset_default_graph()
     tfv1.set_random_seed(Config.random_seed)
 
-    train_set, dev_sets, metrics_sets = create_training_datasets()
+    epoch_ph = tf.placeholder(tf.int64, name="epoch_ph")
+    train_set, dev_sets, metrics_sets = create_training_datasets(epoch_ph)
 
     iterator = tfv1.data.Iterator.from_structure(
         tfv1.data.get_output_types(train_set),
@@ -488,7 +488,7 @@ def train():
             ).start()
 
         # Initialize iterator to the appropriate dataset
-        session.run(init_op)
+        session.run(init_op, {epoch_ph: epoch})
 
         # Batch loop
         while True:
@@ -507,7 +507,7 @@ def train():
                         non_finite_files,
                         step_summaries_op,
                     ],
-                    feed_dict=feed_dict,
+                    feed_dict={**feed_dict, **{epoch_ph: epoch}},
                 )
             except tf.errors.OutOfRangeError:
                 break
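For reference, the feeding pattern above in isolation: a minimal sketch assuming TF1-style graph execution through tf.compat.v1, with a stand-in pipeline rather than STT code. The dataset captures the epoch placeholder, so the placeholder must be fed each time the reinitializable iterator is initialized for a new epoch.

    import tensorflow as tf
    import tensorflow.compat.v1 as tfv1

    tfv1.disable_eager_execution()

    epoch_ph = tfv1.placeholder(tf.int64, name="epoch_ph")

    # Stand-in for create_training_datasets(epoch_ph): a pipeline that
    # captures the placeholder, so it must be fed when the iterator is
    # initialized for each epoch.
    train_set = tf.data.Dataset.from_tensor_slices(
        tf.stack([epoch_ph, epoch_ph + 1, epoch_ph + 2])
    )

    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(train_set),
        tfv1.data.get_output_shapes(train_set),
    )
    init_op = iterator.make_initializer(train_set)
    next_element = iterator.get_next()

    with tfv1.Session() as session:
        for epoch in range(2):
            # The epoch index travels into the dataset graph via the feed dict.
            session.run(init_op, {epoch_ph: epoch})
            while True:
                try:
                    print(session.run(next_element))
                except tf.errors.OutOfRangeError:
                    break

Feeding epoch_ph again on the per-step session.run, as the last hunk does, is legal even for ops that no longer need it, so it acts as a defensive catch-all for anything in the graph that reads the placeholder during the batch loop.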


@@ -340,6 +340,22 @@ class _SttConfig(Coqpit):
             help='after how many epochs the feature cache is invalidated again - 0 for "never"'
         ),
     )
+    shuffle_batches: bool = field(
+        default=False,
+        metadata=dict(
+            help="reshuffle batches every epoch, starting after N epochs, where N is set by the shuffle_start flag."
+        ),
+    )
+    shuffle_start: int = field(
+        default=1,
+        metadata=dict(help="epoch to start shuffling batches from (zero-based)."),
+    )
+    shuffle_buffer: int = field(
+        default=1000,
+        metadata=dict(
+            help="how many batches to keep in shuffle buffer when shuffling batches."
+        ),
+    )
     feature_win_len: int = field(
         default=32,
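For context on the three new flags, a minimal standalone sketch of how Coqpit dataclass fields surface as CLI flags; DemoConfig, its help strings, and the example values are hypothetical, and it assumes coqpit's parse_args() helper rather than anything STT-specific.

    from dataclasses import dataclass, field

    from coqpit import Coqpit

    @dataclass
    class DemoConfig(Coqpit):
        # Hypothetical mirror of the three new fields above.
        shuffle_batches: bool = field(
            default=False,
            metadata=dict(help="reshuffle batches every epoch after shuffle_start"),
        )
        shuffle_start: int = field(
            default=1,
            metadata=dict(help="epoch to start shuffling batches from"),
        )
        shuffle_buffer: int = field(
            default=1000,
            metadata=dict(help="batches kept in the shuffle buffer"),
        )

    if __name__ == "__main__":
        config = DemoConfig()
        # e.g. python demo.py --shuffle_batches true --shuffle_start 5
        config.parse_args()
        print(config.shuffle_batches, config.shuffle_start, config.shuffle_buffer)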


@@ -140,6 +140,7 @@ def create_dataset(
     limit=0,
     process_ahead=None,
     buffering=1 * MEGABYTE,
+    epoch_ph=None,
 ):
 
     epoch_counter = Counter()  # survives restarts of the dataset and its generator
@@ -207,11 +208,18 @@ def create_dataset(
     ).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     if cache_path:
         dataset = dataset.cache(cache_path)
-    dataset = (
-        dataset.window(batch_size, drop_remainder=train_phase)
-        .flat_map(batch_fn)
-        .prefetch(len(Config.available_devices))
-    )
+    dataset = dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
+
+    if Config.shuffle_batches and epoch_ph is not None:
+        with tf.control_dependencies([tf.print("epoch:", epoch_ph)]):
+            epoch_buffer_size = tf.cond(
+                tf.less(epoch_ph, Config.shuffle_start),
+                lambda: tf.constant(1, tf.int64),
+                lambda: tf.constant(Config.shuffle_buffer, tf.int64),
+            )
+        dataset = dataset.shuffle(epoch_buffer_size, seed=epoch_ph)
+
+    dataset = dataset.prefetch(len(Config.available_devices))
     return dataset