diff --git a/util/feeding.py b/util/feeding.py index 829d2ffe..a041503a 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -81,14 +81,9 @@ def audiofile_to_features(wav_filename, train_phase=False): return features, features_len -def entry_to_features_not_augmented(wav_filename, transcript): +def entry_to_features(wav_filename, transcript, train_phase): # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename, train_phase=False) - return wav_filename, features, features_len, tf.SparseTensor(*transcript) - -def entry_to_features_augmented(wav_filename, transcript): - # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename, train_phase=True) + features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase) return wav_filename, features, features_len, tf.SparseTensor(*transcript) @@ -101,7 +96,7 @@ def to_sparse_tuple(sequence): return indices, sequence, shape -def create_dataset(csvs, batch_size, cache_path='', train_phase=True): +def create_dataset(csvs, batch_size, cache_path='', train_phase=False): df = read_csvs(csvs) df.sort_values(by='wav_filesize', inplace=True) @@ -113,7 +108,6 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True): log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message)) exit(1) - entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented def generate_values(): for _, row in df.iterrows(): yield row.wav_filename, to_sparse_tuple(row.transcript) @@ -134,10 +128,11 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True): return tf.data.Dataset.zip((wav_filenames, features, transcripts)) num_gpus = len(Config.available_devices) + process_fn = partial(entry_to_features, train_phase=train_phase) dataset = (tf.data.Dataset.from_generator(generate_values, output_types=(tf.string, (tf.int64, tf.int32, tf.int64))) - .map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE) + .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) .cache(cache_path) .window(batch_size, drop_remainder=True).flat_map(batch_fn) .prefetch(num_gpus))