Merge pull request #2554 from mozilla/disable-cache-to-memory

Disable caching features to memory
This commit is contained in:
Reuben Morais 2019-11-29 12:17:21 +00:00 committed by GitHub
commit b6bc46f3fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@ -433,7 +433,8 @@ def train():
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
cache_path=FLAGS.feature_cache if do_cache_dataset else None,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),

View File

@ -94,7 +94,7 @@ def to_sparse_tuple(sequence):
return indices, sequence, shape
def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False):
df = read_csvs(csvs)
df.sort_values(by='wav_filesize', inplace=True)
@ -126,7 +126,7 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))
if cache_path is not None:
if enable_cache:
dataset = dataset.cache(cache_path)
dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)