Remove some duplicated code
parent d051d4fd0e
commit b6af8c5dc7

@@ -81,14 +81,9 @@ def audiofile_to_features(wav_filename, train_phase=False):
     return features, features_len
 
 
-def entry_to_features_not_augmented(wav_filename, transcript):
+def entry_to_features(wav_filename, transcript, train_phase):
     # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename, train_phase=False)
-    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
-
-
-def entry_to_features_augmented(wav_filename, transcript):
-    # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename, train_phase=True)
+    features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase)
     return wav_filename, features, features_len, tf.SparseTensor(*transcript)
 
 
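The two wrappers above differed only in the constant they passed for train_phase, so the hunk folds them into a single entry_to_features that takes the flag as a parameter. A minimal stand-alone sketch of that pattern follows; the stub load_features only mimics audiofile_to_features so the snippet runs without the rest of the module, it is not the real implementation.

# Stand-alone sketch: two wrappers that differ only in a hard-coded argument
# collapse into one function that accepts the argument explicitly.

def load_features(wav_filename, train_phase=False):
    # Stub standing in for audiofile_to_features: pretend that augmentation
    # just tags the returned "features".
    features = '{}{}'.format('augmented ' if train_phase else '', wav_filename)
    return features, len(features)

def entry_to_features(wav_filename, transcript, train_phase):
    # One body instead of the *_augmented / *_not_augmented copies.
    features, features_len = load_features(wav_filename, train_phase=train_phase)
    return wav_filename, features, features_len, transcript

print(entry_to_features('sample.wav', 'hello', train_phase=False))
print(entry_to_features('sample.wav', 'hello', train_phase=True))
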
@@ -101,7 +96,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape
 
 
-def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
+def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)
 
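With the default flipped to train_phase=False, augmentation now has to be requested explicitly instead of being on for every caller. A hedged sketch of the call-site consequence; the stub below only echoes its arguments so the example runs on its own, and the caller names are illustrative, not taken from this diff.

# Stub with the new signature; it only records its arguments.
def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
    return {'csvs': csvs, 'batch_size': batch_size, 'train_phase': train_phase}

train_set = create_dataset(['train.csv'], 16, train_phase=True)  # opt in to augmentation
dev_set = create_dataset(['dev.csv'], 16)                        # new default: no augmentation
print(train_set['train_phase'], dev_set['train_phase'])          # True False
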
@@ -113,7 +108,6 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
         log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
         exit(1)
 
-    entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented
     def generate_values():
         for _, row in df.iterrows():
             yield row.wav_filename, to_sparse_tuple(row.transcript)
@@ -134,10 +128,11 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
         return tf.data.Dataset.zip((wav_filenames, features, transcripts))
 
     num_gpus = len(Config.available_devices)
+    process_fn = partial(entry_to_features, train_phase=train_phase)
 
     dataset = (tf.data.Dataset.from_generator(generate_values,
                                               output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
-                              .map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+                              .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                               .cache(cache_path)
                               .window(batch_size, drop_remainder=True).flat_map(batch_fn)
                               .prefetch(num_gpus))
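Because Dataset.map calls its function with only the dataset element, the extra train_phase flag is bound ahead of time with functools.partial: process_fn keeps the (wav_filename, transcript) call signature while the flag is fixed as a plain Python value. A minimal runnable sketch of that binding, assuming TensorFlow 2.x in eager mode and using toy data in place of the module's generator.

from functools import partial

import tensorflow as tf

def entry_to_features(name, value, train_phase):
    # train_phase is a plain Python bool fixed by partial(), not a tensor.
    return name, value * (2 if train_phase else 1)

process_fn = partial(entry_to_features, train_phase=True)

# Dataset elements are (string, int) pairs; map() unpacks them into the
# two positional arguments that process_fn still expects.
dataset = (tf.data.Dataset.from_tensor_slices((tf.constant(['a.wav', 'b.wav']),
                                               tf.constant([1, 2])))
           .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))

for name, value in dataset:
    print(name.numpy().decode(), value.numpy())  # a.wav 2 / b.wav 4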