Add 'train_phase' to create_dataset. Augmentation is now applied only to the training set.
This commit is contained in:
parent
0cc5ff230f
commit
49c6a9c973
|
@ -415,7 +415,8 @@ def train():
|
|||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
cache_path=FLAGS.feature_cache)
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
|
@ -426,7 +427,7 @@ def train():
|
|||
|
||||
if FLAGS.dev_files:
|
||||
dev_csvs = FLAGS.dev_files.split(',')
|
||||
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
|
||||
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
# Dropout
|
||||
|
|
|
@ -47,7 +47,7 @@ def evaluate(test_csvs, create_model, try_loading):
|
|||
Config.alphabet)
|
||||
|
||||
test_csvs = FLAGS.test_files.split(',')
|
||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
|
||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||
tfv1.data.get_output_shapes(test_sets[0]),
|
||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||
|
|
|
@ -32,36 +32,38 @@ def read_csvs(csv_files):
|
|||
return source_data
|
||||
|
||||
|
||||
def samples_to_mfccs(samples, sample_rate):
|
||||
def samples_to_mfccs(samples, sample_rate, train_phase=False):
|
||||
spectrogram = contrib_audio.audio_spectrogram(samples,
|
||||
window_size=Config.audio_window_samples,
|
||||
stride=Config.audio_step_samples,
|
||||
magnitude_squared=True)
|
||||
|
||||
if FLAGS.augmention_sparse_deform:
|
||||
spectrogram = augment_sparse_deform(spectrogram,
|
||||
time_warping_para=FLAGS.augmentation_time_warp_max_warping,
|
||||
normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp)
|
||||
# Data Augmentations
|
||||
if train_phase:
|
||||
if FLAGS.augmention_sparse_deform:
|
||||
spectrogram = augment_sparse_deform(spectrogram,
|
||||
time_warping_para=FLAGS.augmentation_time_warp_max_warping,
|
||||
normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp)
|
||||
|
||||
if FLAGS.augmentation_spec_dropout_keeprate < 1:
|
||||
spectrogram = augment_dropout(spectrogram,
|
||||
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
|
||||
if FLAGS.augmentation_spec_dropout_keeprate < 1:
|
||||
spectrogram = augment_dropout(spectrogram,
|
||||
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
|
||||
|
||||
if FLAGS.augmentation_freq_and_time_masking:
|
||||
spectrogram = augment_freq_time_mask(spectrogram,
|
||||
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
|
||||
time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
|
||||
frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
|
||||
time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
|
||||
if FLAGS.augmentation_freq_and_time_masking:
|
||||
spectrogram = augment_freq_time_mask(spectrogram,
|
||||
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
|
||||
time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
|
||||
frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
|
||||
time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
|
||||
|
||||
if FLAGS.augmentation_pitch_and_tempo_scaling:
|
||||
spectrogram = augment_pitch_and_tempo(spectrogram,
|
||||
max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
|
||||
max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
|
||||
min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
|
||||
if FLAGS.augmentation_pitch_and_tempo_scaling:
|
||||
spectrogram = augment_pitch_and_tempo(spectrogram,
|
||||
max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
|
||||
max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
|
||||
min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
|
||||
|
||||
if FLAGS.augmentation_speed_up_std > 0:
|
||||
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
|
||||
if FLAGS.augmentation_speed_up_std > 0:
|
||||
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
|
||||
|
||||
mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
|
||||
mfccs = tf.reshape(mfccs, [-1, Config.n_input])
|
||||
|
@ -69,25 +71,29 @@ def samples_to_mfccs(samples, sample_rate):
|
|||
return mfccs, tf.shape(input=mfccs)[0]
|
||||
|
||||
|
||||
def audiofile_to_features(wav_filename):
|
||||
def audiofile_to_features(wav_filename, train_phase=False):
|
||||
samples = tf.io.read_file(wav_filename)
|
||||
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
|
||||
features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate)
|
||||
features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase)
|
||||
|
||||
if train_phase:
|
||||
if FLAGS.data_aug_features_multiplicative > 0:
|
||||
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
|
||||
|
||||
if FLAGS.data_aug_features_multiplicative > 0:
|
||||
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
|
||||
|
||||
if FLAGS.data_aug_features_additive > 0:
|
||||
features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
|
||||
|
||||
if FLAGS.data_aug_features_additive > 0:
|
||||
features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
|
||||
|
||||
return features, features_len
|
||||
|
||||
|
||||
def entry_to_features(wav_filename, transcript):
|
||||
def entry_to_features_not_augmented(wav_filename, transcript):
|
||||
# https://bugs.python.org/issue32117
|
||||
features, features_len = audiofile_to_features(wav_filename)
|
||||
features, features_len = audiofile_to_features(wav_filename, train_phase=False)
|
||||
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
|
||||
|
||||
def entry_to_features_augmented(wav_filename, transcript):
|
||||
# https://bugs.python.org/issue32117
|
||||
features, features_len = audiofile_to_features(wav_filename, train_phase=True)
|
||||
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
|
||||
|
||||
|
||||
|
@ -100,7 +106,7 @@ def to_sparse_tuple(sequence):
|
|||
return indices, sequence, shape
|
||||
|
||||
|
||||
def create_dataset(csvs, batch_size, cache_path=''):
|
||||
def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
|
||||
df = read_csvs(csvs)
|
||||
df.sort_values(by='wav_filesize', inplace=True)
|
||||
|
||||
|
@ -112,6 +118,7 @@ def create_dataset(csvs, batch_size, cache_path=''):
|
|||
log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
|
||||
exit(1)
|
||||
|
||||
entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented
|
||||
def generate_values():
|
||||
for _, row in df.iterrows():
|
||||
yield row.wav_filename, to_sparse_tuple(row.transcript)
|
||||
|
|
|
@ -39,7 +39,7 @@ def create_flags():
|
|||
f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
|
||||
f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')
|
||||
|
||||
f.DEFINE_float('augmentation_speed_up_std', 0.5, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
|
||||
f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
|
||||
|
||||
f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling')
|
||||
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
|
||||
|
|
|
@ -37,7 +37,6 @@ def augment_freq_time_mask(mel_spectrogram,
|
|||
freq_max = tf.shape(mel_spectrogram)[1]
|
||||
time_max = tf.shape(mel_spectrogram)[2]
|
||||
# Frequency masking
|
||||
# Testing without loop
|
||||
for _ in range(frequency_mask_num):
|
||||
f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32)
|
||||
f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32)
|
||||
|
@ -50,7 +49,6 @@ def augment_freq_time_mask(mel_spectrogram,
|
|||
mel_spectrogram = mel_spectrogram*freq_mask
|
||||
|
||||
# Time masking
|
||||
# Testing without loop
|
||||
for _ in range(time_mask_num):
|
||||
t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32)
|
||||
t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)
|
||||
|
|
Loading…
Reference in New Issue