Merge pull request #1967 from coqui-ai/batch-shuffling
Add support for shuffling batches after N epochs (Fixes #1901)
This commit is contained in:
commit
835d657648
|
@ -20,7 +20,8 @@ fi
|
||||||
# and when trying to run on multiple devices (like GPUs), this will break
|
# and when trying to run on multiple devices (like GPUs), this will break
|
||||||
export CUDA_VISIBLE_DEVICES=0
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
python -u train.py --alphabet_config_path "data/alphabet.txt" \
|
python -m coqui_stt_training.train \
|
||||||
|
--alphabet_config_path "data/alphabet.txt" \
|
||||||
--show_progressbar false \
|
--show_progressbar false \
|
||||||
--train_files data/ldc93s1/ldc93s1.csv \
|
--train_files data/ldc93s1/ldc93s1.csv \
|
||||||
--test_files data/ldc93s1/ldc93s1.csv \
|
--test_files data/ldc93s1/ldc93s1.csv \
|
||||||
|
|
|
@ -266,11 +266,9 @@ def early_training_checks():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_training_datasets() -> (
|
def create_training_datasets(
|
||||||
tf.data.Dataset,
|
epoch_ph: tf.Tensor = None,
|
||||||
[tf.data.Dataset],
|
) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
|
||||||
[tf.data.Dataset],
|
|
||||||
):
|
|
||||||
"""Creates training datasets from input flags.
|
"""Creates training datasets from input flags.
|
||||||
|
|
||||||
Returns a single training dataset and two lists of datasets for validation
|
Returns a single training dataset and two lists of datasets for validation
|
||||||
|
@ -288,6 +286,7 @@ def create_training_datasets() -> (
|
||||||
reverse=Config.reverse_train,
|
reverse=Config.reverse_train,
|
||||||
limit=Config.limit_train,
|
limit=Config.limit_train,
|
||||||
buffering=Config.read_buffer,
|
buffering=Config.read_buffer,
|
||||||
|
epoch_ph=epoch_ph,
|
||||||
)
|
)
|
||||||
|
|
||||||
dev_sets = []
|
dev_sets = []
|
||||||
|
@ -331,7 +330,8 @@ def train():
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
tfv1.set_random_seed(Config.random_seed)
|
tfv1.set_random_seed(Config.random_seed)
|
||||||
|
|
||||||
train_set, dev_sets, metrics_sets = create_training_datasets()
|
epoch_ph = tf.placeholder(tf.int64, name="epoch_ph")
|
||||||
|
train_set, dev_sets, metrics_sets = create_training_datasets(epoch_ph)
|
||||||
|
|
||||||
iterator = tfv1.data.Iterator.from_structure(
|
iterator = tfv1.data.Iterator.from_structure(
|
||||||
tfv1.data.get_output_types(train_set),
|
tfv1.data.get_output_types(train_set),
|
||||||
|
@ -488,7 +488,7 @@ def train():
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
# Initialize iterator to the appropriate dataset
|
# Initialize iterator to the appropriate dataset
|
||||||
session.run(init_op)
|
session.run(init_op, {epoch_ph: epoch})
|
||||||
|
|
||||||
# Batch loop
|
# Batch loop
|
||||||
while True:
|
while True:
|
||||||
|
@ -507,7 +507,7 @@ def train():
|
||||||
non_finite_files,
|
non_finite_files,
|
||||||
step_summaries_op,
|
step_summaries_op,
|
||||||
],
|
],
|
||||||
feed_dict=feed_dict,
|
feed_dict={**feed_dict, **{epoch_ph: epoch}},
|
||||||
)
|
)
|
||||||
except tf.errors.OutOfRangeError:
|
except tf.errors.OutOfRangeError:
|
||||||
break
|
break
|
||||||
|
|
|
@ -340,6 +340,22 @@ class _SttConfig(Coqpit):
|
||||||
help='after how many epochs the feature cache is invalidated again - 0 for "never"'
|
help='after how many epochs the feature cache is invalidated again - 0 for "never"'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
shuffle_batches: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata=dict(
|
||||||
|
help="reshuffle batches every epoch, starting after N epochs, where N is set by the shuffle_start flag."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
shuffle_start: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata=dict(help="epoch to start shuffling batches from (zero-based)."),
|
||||||
|
)
|
||||||
|
shuffle_buffer: int = field(
|
||||||
|
default=1000,
|
||||||
|
metadata=dict(
|
||||||
|
help="how many batches to keep in shuffle buffer when shuffling batches."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
feature_win_len: int = field(
|
feature_win_len: int = field(
|
||||||
default=32,
|
default=32,
|
||||||
|
|
|
@ -140,6 +140,7 @@ def create_dataset(
|
||||||
limit=0,
|
limit=0,
|
||||||
process_ahead=None,
|
process_ahead=None,
|
||||||
buffering=1 * MEGABYTE,
|
buffering=1 * MEGABYTE,
|
||||||
|
epoch_ph=None,
|
||||||
):
|
):
|
||||||
epoch_counter = Counter() # survives restarts of the dataset and its generator
|
epoch_counter = Counter() # survives restarts of the dataset and its generator
|
||||||
|
|
||||||
|
@ -207,11 +208,18 @@ def create_dataset(
|
||||||
).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||||
if cache_path:
|
if cache_path:
|
||||||
dataset = dataset.cache(cache_path)
|
dataset = dataset.cache(cache_path)
|
||||||
dataset = (
|
dataset = dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
|
||||||
dataset.window(batch_size, drop_remainder=train_phase)
|
|
||||||
.flat_map(batch_fn)
|
if Config.shuffle_batches and epoch_ph is not None:
|
||||||
.prefetch(len(Config.available_devices))
|
with tf.control_dependencies([tf.print("epoch:", epoch_ph)]):
|
||||||
|
epoch_buffer_size = tf.cond(
|
||||||
|
tf.less(epoch_ph, Config.shuffle_start),
|
||||||
|
lambda: tf.constant(1, tf.int64),
|
||||||
|
lambda: tf.constant(Config.shuffle_buffer, tf.int64),
|
||||||
)
|
)
|
||||||
|
dataset = dataset.shuffle(epoch_buffer_size, seed=epoch_ph)
|
||||||
|
|
||||||
|
dataset = dataset.prefetch(len(Config.available_devices))
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue