diff --git a/DeepSpeech.py b/DeepSpeech.py
index cea43e2a..871e395c 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -413,10 +413,21 @@ def try_loading(session, saver, checkpoint_filename, caption):
 
 
 def train():
+    do_cache_dataset = True
+
+    # pylint: disable=too-many-boolean-expressions
+    if (FLAGS.data_aug_features_multiplicative > 0 or
+            FLAGS.data_aug_features_additive > 0 or
+            FLAGS.augmentation_spec_dropout_keeprate < 1 or
+            FLAGS.augmentation_freq_and_time_masking or
+            FLAGS.augmentation_pitch_and_tempo_scaling or
+            FLAGS.augmentation_speed_up_std > 0):
+        do_cache_dataset = False
+
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
-                               cache_path=FLAGS.feature_cache,
+                               cache_path=FLAGS.feature_cache if do_cache_dataset else None,
                                train_phase=True)
 
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
diff --git a/util/feeding.py b/util/feeding.py
index a041503a..c5c50a1b 100644
--- a/util/feeding.py
+++ b/util/feeding.py
@@ -132,10 +132,13 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     num_gpus = len(Config.available_devices)
 
     dataset = (tf.data.Dataset.from_generator(generate_values,
                                               output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
-                              .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-                              .cache(cache_path)
-                              .window(batch_size, drop_remainder=True).flat_map(batch_fn)
-                              .prefetch(num_gpus))
+                              .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
+
+    if cache_path is not None:
+        dataset = dataset.cache(cache_path)
+
+    dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
+                      .prefetch(num_gpus))
 
     return dataset
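
Note on the pattern (not part of the patch): tf.data.Dataset.cache() materializes the pipeline's elements on the first epoch and replays them verbatim afterwards, so caching downstream of random augmentation would freeze the "random" noise after epoch one. The guard above therefore skips the cache stage whenever any augmentation flag is active. A minimal, self-contained sketch of the same idea, assuming hypothetical helper names (make_pipeline, augment) that are not DeepSpeech APIs:

    import tensorflow as tf

    def augment(x):
        # Random additive noise, loosely analogous to --data_aug_features_additive.
        return x + tf.random.normal(tf.shape(x), stddev=0.1)

    def make_pipeline(augmenting, cache_path=''):
        ds = tf.data.Dataset.range(4).map(lambda i: tf.cast(i, tf.float32))
        if augmenting:
            # No .cache() here: caching would replay epoch one's noise forever.
            ds = ds.map(augment)
        else:
            # Deterministic features are safe to cache
            # ('' caches in memory, a file path caches on disk).
            ds = ds.cache(cache_path)
        return ds.batch(2).prefetch(tf.data.experimental.AUTOTUNE)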