Merge pull request #2406 from lissyx/disable-cache-dataaug
Disable cache when data augmentation is set
This commit is contained in:
commit
031479d88b
|
@ -413,10 +413,21 @@ def try_loading(session, saver, checkpoint_filename, caption):
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
do_cache_dataset = True
|
||||||
|
|
||||||
|
# pylint: disable=too-many-boolean-expressions
|
||||||
|
if (FLAGS.data_aug_features_multiplicative > 0 or
|
||||||
|
FLAGS.data_aug_features_additive > 0 or
|
||||||
|
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||||
|
FLAGS.augmentation_freq_and_time_masking or
|
||||||
|
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||||
|
FLAGS.augmentation_speed_up_std > 0):
|
||||||
|
do_cache_dataset = False
|
||||||
|
|
||||||
# Create training and validation datasets
|
# Create training and validation datasets
|
||||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||||
batch_size=FLAGS.train_batch_size,
|
batch_size=FLAGS.train_batch_size,
|
||||||
cache_path=FLAGS.feature_cache,
|
cache_path=FLAGS.feature_cache if do_cache_dataset else None,
|
||||||
train_phase=True)
|
train_phase=True)
|
||||||
|
|
||||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||||
|
|
|
@ -132,9 +132,12 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
|
||||||
|
|
||||||
dataset = (tf.data.Dataset.from_generator(generate_values,
|
dataset = (tf.data.Dataset.from_generator(generate_values,
|
||||||
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
|
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
|
||||||
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
|
||||||
.cache(cache_path)
|
|
||||||
.window(batch_size, drop_remainder=True).flat_map(batch_fn)
|
if cache_path is not None:
|
||||||
|
dataset = dataset.cache(cache_path)
|
||||||
|
|
||||||
|
dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
|
||||||
.prefetch(num_gpus))
|
.prefetch(num_gpus))
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
Loading…
Reference in New Issue