Add support for shuffling batches after N epochs
This commit is contained in:
parent
feeb2a222d
commit
72599be9d4
|
@ -20,7 +20,8 @@ fi
|
|||
# and when trying to run on multiple devices (like GPUs), this will break
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python -u train.py --alphabet_config_path "data/alphabet.txt" \
|
||||
python -m coqui_stt_training.train \
|
||||
--alphabet_config_path "data/alphabet.txt" \
|
||||
--show_progressbar false \
|
||||
--train_files data/ldc93s1/ldc93s1.csv \
|
||||
--test_files data/ldc93s1/ldc93s1.csv \
|
||||
|
|
|
@ -266,11 +266,9 @@ def early_training_checks():
|
|||
)
|
||||
|
||||
|
||||
def create_training_datasets() -> (
|
||||
tf.data.Dataset,
|
||||
[tf.data.Dataset],
|
||||
[tf.data.Dataset],
|
||||
):
|
||||
def create_training_datasets(
|
||||
epoch_ph: tf.Tensor = None,
|
||||
) -> (tf.data.Dataset, [tf.data.Dataset], [tf.data.Dataset],):
|
||||
"""Creates training datasets from input flags.
|
||||
|
||||
Returns a single training dataset and two lists of datasets for validation
|
||||
|
@ -288,6 +286,7 @@ def create_training_datasets() -> (
|
|||
reverse=Config.reverse_train,
|
||||
limit=Config.limit_train,
|
||||
buffering=Config.read_buffer,
|
||||
epoch_ph=epoch_ph,
|
||||
)
|
||||
|
||||
dev_sets = []
|
||||
|
@ -331,7 +330,8 @@ def train():
|
|||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(Config.random_seed)
|
||||
|
||||
train_set, dev_sets, metrics_sets = create_training_datasets()
|
||||
epoch_ph = tf.placeholder(tf.int64, name="epoch_ph")
|
||||
train_set, dev_sets, metrics_sets = create_training_datasets(epoch_ph)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(
|
||||
tfv1.data.get_output_types(train_set),
|
||||
|
@ -488,7 +488,7 @@ def train():
|
|||
).start()
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
session.run(init_op, {epoch_ph: epoch})
|
||||
|
||||
# Batch loop
|
||||
while True:
|
||||
|
@ -507,7 +507,7 @@ def train():
|
|||
non_finite_files,
|
||||
step_summaries_op,
|
||||
],
|
||||
feed_dict=feed_dict,
|
||||
feed_dict={**feed_dict, **{epoch_ph: epoch}},
|
||||
)
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
|
|
@ -340,6 +340,22 @@ class _SttConfig(Coqpit):
|
|||
help='after how many epochs the feature cache is invalidated again - 0 for "never"'
|
||||
),
|
||||
)
|
||||
shuffle_batches: bool = field(
|
||||
default=False,
|
||||
metadata=dict(
|
||||
help="reshuffle batches every epoch, starting after N epochs, where N is set by the shuffle_start flag."
|
||||
),
|
||||
)
|
||||
shuffle_start: int = field(
|
||||
default=1,
|
||||
metadata=dict(help="epoch to start shuffling batches from (zero-based)."),
|
||||
)
|
||||
shuffle_buffer: int = field(
|
||||
default=1000,
|
||||
metadata=dict(
|
||||
help="how many batches to keep in shuffle buffer when shuffling batches."
|
||||
),
|
||||
)
|
||||
|
||||
feature_win_len: int = field(
|
||||
default=32,
|
||||
|
|
|
@ -140,6 +140,7 @@ def create_dataset(
|
|||
limit=0,
|
||||
process_ahead=None,
|
||||
buffering=1 * MEGABYTE,
|
||||
epoch_ph=None,
|
||||
):
|
||||
epoch_counter = Counter() # survives restarts of the dataset and its generator
|
||||
|
||||
|
@ -207,11 +208,18 @@ def create_dataset(
|
|||
).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
if cache_path:
|
||||
dataset = dataset.cache(cache_path)
|
||||
dataset = (
|
||||
dataset.window(batch_size, drop_remainder=train_phase)
|
||||
.flat_map(batch_fn)
|
||||
.prefetch(len(Config.available_devices))
|
||||
)
|
||||
dataset = dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
|
||||
|
||||
if Config.shuffle_batches and epoch_ph is not None:
|
||||
with tf.control_dependencies([tf.print("epoch:", epoch_ph)]):
|
||||
epoch_buffer_size = tf.cond(
|
||||
tf.less(epoch_ph, Config.shuffle_start),
|
||||
lambda: tf.constant(1, tf.int64),
|
||||
lambda: tf.constant(Config.shuffle_buffer, tf.int64),
|
||||
)
|
||||
dataset = dataset.shuffle(epoch_buffer_size, seed=epoch_ph)
|
||||
|
||||
dataset = dataset.prefetch(len(Config.available_devices))
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue