From e3b1b5fd42a843bff7308c39b1c7122b915ec198 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Thu, 28 Nov 2019 13:51:33 +0100
Subject: [PATCH] Disable caching features to memory

---
 DeepSpeech.py   | 3 ++-
 util/feeding.py | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 42c1d782..85569ebb 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -433,7 +433,8 @@ def train():
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
-                               cache_path=FLAGS.feature_cache if do_cache_dataset else None,
+                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
+                               cache_path=FLAGS.feature_cache,
                                train_phase=True)
 
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
diff --git a/util/feeding.py b/util/feeding.py
index e7704b01..16d0e312 100644
--- a/util/feeding.py
+++ b/util/feeding.py
@@ -94,7 +94,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape
 
 
-def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
+def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)
 
@@ -126,7 +126,7 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
                       output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
                .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
 
-    if cache_path is not None:
+    if enable_cache:
         dataset = dataset.cache(cache_path)
 
     dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
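
Note (not part of the patch): below is a minimal sketch of the tf.data caching
behaviour this change gates. Dataset.cache() with no filename keeps elements in
memory after the first pass, which is the implicit behaviour being disabled here;
cache(path) writes the cache to files at that path, which is what a non-empty
--feature_cache value is meant to select explicitly. The toy dataset and path
below are illustrative only.

    import tensorflow as tf

    # Toy stand-in for the feature pipeline built in create_dataset().
    ds = tf.data.Dataset.range(5).map(lambda x: x * 2)

    # No filename: elements are cached in RAM after the first full iteration.
    in_memory = ds.cache()

    # Filename given: elements are cached to files at that path instead.
    on_disk = ds.cache('/tmp/feature_cache_example')

    for element in in_memory:   # TF 2.x eager iteration
        print(int(element))     # prints 0, 2, 4, 6, 8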