Merge pull request #2406 from lissyx/disable-cache-dataaug
Disable cache when data augmentation is set
This commit is contained in:
commit
031479d88b
|
@ -413,10 +413,21 @@ def try_loading(session, saver, checkpoint_filename, caption):
|
|||
|
||||
|
||||
def train():
|
||||
do_cache_dataset = True
|
||||
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (FLAGS.data_aug_features_multiplicative > 0 or
|
||||
FLAGS.data_aug_features_additive > 0 or
|
||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||
FLAGS.augmentation_freq_and_time_masking or
|
||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||
FLAGS.augmentation_speed_up_std > 0):
|
||||
do_cache_dataset = False
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
cache_path=FLAGS.feature_cache if do_cache_dataset else None,
|
||||
train_phase=True)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
|
|
|
@ -132,10 +132,13 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
|
|||
|
||||
dataset = (tf.data.Dataset.from_generator(generate_values,
|
||||
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
|
||||
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
.cache(cache_path)
|
||||
.window(batch_size, drop_remainder=True).flat_map(batch_fn)
|
||||
.prefetch(num_gpus))
|
||||
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
|
||||
|
||||
if cache_path is not None:
|
||||
dataset = dataset.cache(cache_path)
|
||||
|
||||
dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
|
||||
.prefetch(num_gpus))
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
Loading…
Reference in New Issue