Remove some duplicated code
parent d051d4fd0e
commit b6af8c5dc7
@@ -81,14 +81,9 @@ def audiofile_to_features(wav_filename, train_phase=False):
     return features, features_len
 
 
-def entry_to_features_not_augmented(wav_filename, transcript):
+def entry_to_features(wav_filename, transcript, train_phase):
     # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename, train_phase=False)
-    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
-
-def entry_to_features_augmented(wav_filename, transcript):
-    # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename, train_phase=True)
+    features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase)
     return wav_filename, features, features_len, tf.SparseTensor(*transcript)
 
 
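Note: the two removed helpers differed only in the hard-coded train_phase flag they forwarded to audiofile_to_features, so the merged entry_to_features takes the flag as an argument. For illustration only, the removed wrappers could be recovered by binding that argument (a sketch, not part of this change):

from functools import partial

# hypothetical equivalents of the removed functions, built on the new entry_to_features
entry_to_features_not_augmented = partial(entry_to_features, train_phase=False)
entry_to_features_augmented = partial(entry_to_features, train_phase=True)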
@@ -101,7 +96,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape
 
 
-def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
+def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)
 
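Note: the default for train_phase flips from True to False, so feature augmentation becomes opt-in. Callers building the training set now have to pass the flag explicitly; a hypothetical call site (CSV names and batch size are illustrative, not from this change):

train_set = create_dataset(['train.csv'], batch_size=32, cache_path='cache', train_phase=True)   # augmented
dev_set = create_dataset(['dev.csv'], batch_size=32)  # train_phase defaults to False, no augmentation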
@@ -113,7 +108,6 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
             log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
             exit(1)
 
-    entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented
     def generate_values():
         for _, row in df.iterrows():
             yield row.wav_filename, to_sparse_tuple(row.transcript)
@@ -134,10 +128,11 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
         return tf.data.Dataset.zip((wav_filenames, features, transcripts))
 
     num_gpus = len(Config.available_devices)
+    process_fn = partial(entry_to_features, train_phase=train_phase)
 
     dataset = (tf.data.Dataset.from_generator(generate_values,
                                               output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
-                              .map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+                              .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                               .cache(cache_path)
                               .window(batch_size, drop_remainder=True).flat_map(batch_fn)
                               .prefetch(num_gpus))
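Note: tf.data.Dataset.map only passes the element components (here wav_filename and the sparse transcript tuple) to the mapping function, so the extra train_phase argument is bound up front with functools.partial; this is what replaces the removed entry_to_features_augmented / entry_to_features_not_augmented dispatch. A self-contained sketch of the same pattern (illustrative names, TF 2.x eager mode assumed for the print):

from functools import partial
import tensorflow as tf

def add_offset(x, offset):  # mapping function with an extra, non-element argument
    return x + offset

process_fn = partial(add_offset, offset=10)  # bind the extra argument before map()
ds = tf.data.Dataset.range(5).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
print(list(ds.as_numpy_iterator()))  # [10, 11, 12, 13, 14]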