Remove some duplicated code

This commit is contained in:
Reuben Morais 2019-09-09 12:20:16 +02:00
parent d051d4fd0e
commit b6af8c5dc7
1 changed files with 5 additions and 10 deletions

View File

@ -81,14 +81,9 @@ def audiofile_to_features(wav_filename, train_phase=False):
return features, features_len
def entry_to_features_not_augmented(wav_filename, transcript):
def entry_to_features(wav_filename, transcript, train_phase):
# https://bugs.python.org/issue32117
features, features_len = audiofile_to_features(wav_filename, train_phase=False)
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
def entry_to_features_augmented(wav_filename, transcript):
# https://bugs.python.org/issue32117
features, features_len = audiofile_to_features(wav_filename, train_phase=True)
features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase)
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
@ -101,7 +96,7 @@ def to_sparse_tuple(sequence):
return indices, sequence, shape
def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
df = read_csvs(csvs)
df.sort_values(by='wav_filesize', inplace=True)
@ -113,7 +108,6 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
exit(1)
entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented
def generate_values():
for _, row in df.iterrows():
yield row.wav_filename, to_sparse_tuple(row.transcript)
@ -134,10 +128,11 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
return tf.data.Dataset.zip((wav_filenames, features, transcripts))
num_gpus = len(Config.available_devices)
process_fn = partial(entry_to_features, train_phase=train_phase)
dataset = (tf.data.Dataset.from_generator(generate_values,
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
.map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.cache(cache_path)
.window(batch_size, drop_remainder=True).flat_map(batch_fn)
.prefetch(num_gpus))