Merge pull request #1967 from coqui-ai/batch-shuffling

Add support for shuffling batches after N epochs (Fixes #1901)

commit 835d657648
Reuben Morais, 2021-09-16 11:05:06 +02:00, committed by GitHub
4 changed files with 39 additions and 14 deletions

@@ -20,7 +20,8 @@ fi
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
-python -u train.py --alphabet_config_path "data/alphabet.txt" \
+python -m coqui_stt_training.train \
+  --alphabet_config_path "data/alphabet.txt" \
   --show_progressbar false \
   --train_files data/ldc93s1/ldc93s1.csv \
   --test_files data/ldc93s1/ldc93s1.csv \
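
Note: the sample training script now invokes the packaged entry point (python -m coqui_stt_training.train) instead of the old top-level train.py it called before. The batch shuffling added in the hunks below is opt-in from this same command line: --shuffle_batches true enables it, while --shuffle_start and --shuffle_buffer control when it kicks in and how large the shuffle buffer is.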

@@ -266,11 +266,9 @@ def early_training_checks():
         )


-def create_training_datasets() -> (
-    tf.data.Dataset,
-    [tf.data.Dataset],
-    [tf.data.Dataset],
-):
+def create_training_datasets(
+    epoch_ph: tf.Tensor = None,
+) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
     """Creates training datasets from input flags.

     Returns a single training dataset and two lists of datasets for validation
@@ -288,6 +286,7 @@ def create_training_datasets() -> (
         reverse=Config.reverse_train,
         limit=Config.limit_train,
         buffering=Config.read_buffer,
+        epoch_ph=epoch_ph,
     )

     dev_sets = []
@@ -331,7 +330,8 @@ def train():
     tfv1.reset_default_graph()
     tfv1.set_random_seed(Config.random_seed)

-    train_set, dev_sets, metrics_sets = create_training_datasets()
+    epoch_ph = tf.placeholder(tf.int64, name="epoch_ph")
+    train_set, dev_sets, metrics_sets = create_training_datasets(epoch_ph)

     iterator = tfv1.data.Iterator.from_structure(
         tfv1.data.get_output_types(train_set),
@@ -488,7 +488,7 @@ def train():
             ).start()

             # Initialize iterator to the appropriate dataset
-            session.run(init_op)
+            session.run(init_op, {epoch_ph: epoch})

             # Batch loop
             while True:
@@ -507,7 +507,7 @@ def train():
                         non_finite_files,
                         step_summaries_op,
                     ],
-                    feed_dict=feed_dict,
+                    feed_dict={**feed_dict, **{epoch_ph: epoch}},
                 )
             except tf.errors.OutOfRangeError:
                 break
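
Taken together, the three hunks above thread a single epoch placeholder through the whole graph: it is created once in train(), handed to create_training_datasets() so the input pipeline can depend on it, and then fed with the current epoch both when the iterator is initialized and on every training step. A minimal standalone sketch of that pattern (TF1 graph mode; the toy range dataset and buffer size are invented for illustration, not code from this PR):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

epoch_ph = tf.placeholder(tf.int64, name="epoch_ph")

# Batch first, then shuffle, so whole batches are reordered; seeding the
# shuffle with the epoch gives a different order each epoch.
dataset = tf.data.Dataset.range(10).batch(2).shuffle(5, seed=epoch_ph)
iterator = tf.data.make_initializable_iterator(dataset)
next_batch = iterator.get_next()

with tf.Session() as session:
    for epoch in range(2):
        # Feed the epoch wherever the graph depends on the placeholder:
        # at initialization here, and on each step as in the diff above.
        session.run(iterator.initializer, {epoch_ph: epoch})
        while True:
            try:
                print("epoch", epoch, "->", session.run(next_batch, {epoch_ph: epoch}))
            except tf.errors.OutOfRangeError:
                break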

@@ -340,6 +340,22 @@ class _SttConfig(Coqpit):
             help='after how many epochs the feature cache is invalidated again - 0 for "never"'
         ),
     )
+    shuffle_batches: bool = field(
+        default=False,
+        metadata=dict(
+            help="reshuffle batches every epoch, starting after N epochs, where N is set by the shuffle_start flag."
+        ),
+    )
+    shuffle_start: int = field(
+        default=1,
+        metadata=dict(help="epoch to start shuffling batches from (zero-based)."),
+    )
+    shuffle_buffer: int = field(
+        default=1000,
+        metadata=dict(
+            help="how many batches to keep in shuffle buffer when shuffling batches."
+        ),
+    )
     feature_win_len: int = field(
         default=32,
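
With the defaults above, epoch 0 keeps the original batch order and every later epoch is shuffled through a 1000-batch buffer. A plain-Python illustration of how the three flags interact (this mirrors the tf.cond in create_dataset below; it is not code from this PR):

def effective_buffer_size(epoch, shuffle_start=1, shuffle_buffer=1000):
    # A buffer of size 1 makes Dataset.shuffle a no-op, so epochs before
    # shuffle_start pass batches through in their original order.
    return 1 if epoch < shuffle_start else shuffle_buffer

assert effective_buffer_size(0) == 1     # epoch 0: original order
assert effective_buffer_size(1) == 1000  # epoch >= shuffle_start: shuffled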

@@ -140,6 +140,7 @@ def create_dataset(
     limit=0,
     process_ahead=None,
     buffering=1 * MEGABYTE,
+    epoch_ph=None,
 ):
     epoch_counter = Counter()  # survives restarts of the dataset and its generator
@@ -207,11 +208,18 @@ def create_dataset(
     ).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

     if cache_path:
         dataset = dataset.cache(cache_path)
-    dataset = (
-        dataset.window(batch_size, drop_remainder=train_phase)
-        .flat_map(batch_fn)
-        .prefetch(len(Config.available_devices))
-    )
+    dataset = dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
+    if Config.shuffle_batches and epoch_ph is not None:
+        with tf.control_dependencies([tf.print("epoch:", epoch_ph)]):
+            epoch_buffer_size = tf.cond(
+                tf.less(epoch_ph, Config.shuffle_start),
+                lambda: tf.constant(1, tf.int64),
+                lambda: tf.constant(Config.shuffle_buffer, tf.int64),
+            )
+        dataset = dataset.shuffle(epoch_buffer_size, seed=epoch_ph)
+    dataset = dataset.prefetch(len(Config.available_devices))
     return dataset
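
The buffer size is chosen in-graph with tf.cond, so a single dataset definition serves every epoch, and the epoch placeholder doubles as the shuffle seed, making each epoch's permutation different but reproducible. A standalone sketch of just that conditional shuffle (TF1 graph mode; the function name and defaults are invented stand-ins for the Config flags, and the tf.control_dependencies([tf.print(...)]) wrapper from the hunk, which just forces an "epoch:" log line, is omitted):

import tensorflow.compat.v1 as tf

def shuffle_after_epoch(dataset, epoch_ph, shuffle_start=1, shuffle_buffer=1000):
    # Before shuffle_start: buffer of 1, i.e. batches pass through unshuffled.
    # From shuffle_start on: a shuffle_buffer-sized buffer reorders batches.
    buffer_size = tf.cond(
        tf.less(epoch_ph, shuffle_start),
        lambda: tf.constant(1, tf.int64),
        lambda: tf.constant(shuffle_buffer, tf.int64),
    )
    # Seeding with the epoch keeps runs reproducible while still producing
    # a new batch order each epoch.
    return dataset.shuffle(buffer_size, seed=epoch_ph)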