Disable cache when data augmentation is set

Fixes #2396
This commit is contained in:
Alexandre Lissy 2019-10-08 06:19:10 +02:00
parent fb611efd00
commit c35068f880
2 changed files with 19 additions and 5 deletions

View File

@ -412,10 +412,21 @@ def try_loading(session, saver, checkpoint_filename, caption):
def train(): def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0):
do_cache_dataset = False
# Create training and validation datasets # Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','), train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size, batch_size=FLAGS.train_batch_size,
cache_path=FLAGS.feature_cache, cache_path=FLAGS.feature_cache if do_cache_dataset else None,
train_phase=True) train_phase=True)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),

View File

@ -132,9 +132,12 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
dataset = (tf.data.Dataset.from_generator(generate_values, dataset = (tf.data.Dataset.from_generator(generate_values,
output_types=(tf.string, (tf.int64, tf.int32, tf.int64))) output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
.cache(cache_path)
.window(batch_size, drop_remainder=True).flat_map(batch_fn) if cache_path is not None:
dataset = dataset.cache(cache_path)
dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
.prefetch(num_gpus)) .prefetch(num_gpus))
return dataset return dataset