Merge pull request #2406 from lissyx/disable-cache-dataaug

Disable the feature cache when data augmentation is enabled
lissyx 2019-10-08 18:10:32 +02:00 committed by GitHub
commit 031479d88b
2 changed files with 19 additions and 5 deletions
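Why the change: tf.data's .cache() stores the elements produced by the stages before it, so with .cache() placed after the feature-augmentation .map(), the randomly augmented features computed on the first epoch are replayed unchanged on every later epoch, silently disabling the augmentation. A minimal TF2-style sketch (not part of the PR) showing the effect:

import tensorflow as tf

# Minimal sketch, not from the PR: a random "augmentation" map followed by
# .cache(). The cache replays the first epoch's outputs, so later epochs see
# the exact same "augmented" values instead of fresh random ones.
ds = (tf.data.Dataset.range(3)
      .map(lambda x: tf.cast(x, tf.float32) + tf.random.uniform([]))
      .cache())

for epoch in range(2):
    print('epoch', epoch, [round(float(v), 3) for v in ds])
# Both epochs print identical values; without .cache(), each epoch would get
# new random offsets.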

@@ -413,10 +413,21 @@ def try_loading(session, saver, checkpoint_filename, caption):
 def train():
+    do_cache_dataset = True
+    # pylint: disable=too-many-boolean-expressions
+    if (FLAGS.data_aug_features_multiplicative > 0 or
+            FLAGS.data_aug_features_additive > 0 or
+            FLAGS.augmentation_spec_dropout_keeprate < 1 or
+            FLAGS.augmentation_freq_and_time_masking or
+            FLAGS.augmentation_pitch_and_tempo_scaling or
+            FLAGS.augmentation_speed_up_std > 0):
+        do_cache_dataset = False
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
-                               cache_path=FLAGS.feature_cache,
+                               cache_path=FLAGS.feature_cache if do_cache_dataset else None,
                                train_phase=True)
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
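The pylint disable is needed because the flag chain trips too-many-boolean-expressions. Purely for illustration, the same gate can be read as a single predicate; a hypothetical sketch (the augmentation_enabled helper is not in the PR, only the flag names come from the diff above):

def augmentation_enabled(FLAGS):
    # Hypothetical helper, not part of the PR: True when any of the feature
    # augmentations checked in the diff above is active.
    return (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0)

do_cache_dataset = not augmentation_enabled(FLAGS)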

@@ -132,10 +132,13 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     dataset = (tf.data.Dataset.from_generator(generate_values,
                                               output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
-                              .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-                              .cache(cache_path)
-                              .window(batch_size, drop_remainder=True).flat_map(batch_fn)
-                              .prefetch(num_gpus))
+                              .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
+    if cache_path is not None:
+        dataset = dataset.cache(cache_path)
+    dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
+                      .prefetch(num_gpus))
     return dataset
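After this change, cache_path effectively has three modes when it reaches create_dataset. A hypothetical sketch of the resulting behaviour (the maybe_cache helper is only for illustration; the cache() semantics are standard tf.data behaviour):

import tensorflow as tf

def maybe_cache(dataset, cache_path):
    # Hypothetical helper mirroring the logic added above:
    #   None          -> no .cache() call at all; elements (and any augmentation)
    #                    are recomputed on every epoch
    #   ''            -> dataset.cache('') keeps the elements in memory
    #   '/some/path'  -> dataset.cache('/some/path') writes the cache to disk
    if cache_path is not None:
        dataset = dataset.cache(cache_path)
    return dataset

# Example: augmentation enabled, so the caller passes None and nothing is cached.
ds = tf.data.Dataset.range(5).map(lambda x: x * 2)
ds = maybe_cache(ds, None)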