Adding 'train_phase' to create_dataset. Now we can augment only the training set.
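In short, callers now declare whether they are building the training pipeline, and only that pipeline gets augmented. A condensed sketch of the three call sites this commit touches (file names and batch sizes are placeholders, not the real FLAGS values; create_dataset is assumed to be imported from the feeding module, as in the existing code):

    # Condensed view of the call sites changed below (placeholder values, not the real FLAGS).
    train_set = create_dataset(['train.csv'], batch_size=24, cache_path='', train_phase=True)   # augmentations applied
    dev_set = create_dataset(['dev.csv'], batch_size=48, train_phase=False)                     # no augmentation
    test_set = create_dataset(['test.csv'], batch_size=48, train_phase=False)                   # no augmentation

The per-file diffs below show the actual changes.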

Bernardo Henz 2019-08-01 22:09:06 -03:00 committed by Reuben Morais
parent 0cc5ff230f
commit 49c6a9c973
5 changed files with 44 additions and 38 deletions

View File

@@ -415,7 +415,8 @@ def train():
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
-                               cache_path=FLAGS.feature_cache)
+                               cache_path=FLAGS.feature_cache,
+                               train_phase=True)
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
@@ -426,7 +427,7 @@ def train():
     if FLAGS.dev_files:
         dev_csvs = FLAGS.dev_files.split(',')
-        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
+        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

     # Dropout

View File

@@ -47,7 +47,7 @@ def evaluate(test_csvs, create_model, try_loading):
                     Config.alphabet)

     test_csvs = FLAGS.test_files.split(',')
-    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
+    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                  tfv1.data.get_output_shapes(test_sets[0]),
                                                  output_classes=tfv1.data.get_output_classes(test_sets[0]))

View File

@@ -32,36 +32,38 @@ def read_csvs(csv_files):
     return source_data


-def samples_to_mfccs(samples, sample_rate):
+def samples_to_mfccs(samples, sample_rate, train_phase=False):
     spectrogram = contrib_audio.audio_spectrogram(samples,
                                                   window_size=Config.audio_window_samples,
                                                   stride=Config.audio_step_samples,
                                                   magnitude_squared=True)

-    if FLAGS.augmention_sparse_deform:
-        spectrogram = augment_sparse_deform(spectrogram,
-                                            time_warping_para=FLAGS.augmentation_time_warp_max_warping,
-                                            normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp)
+    # Data Augmentations
+    if train_phase:
+        if FLAGS.augmention_sparse_deform:
+            spectrogram = augment_sparse_deform(spectrogram,
+                                                time_warping_para=FLAGS.augmentation_time_warp_max_warping,
+                                                normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp)

-    if FLAGS.augmentation_spec_dropout_keeprate < 1:
-        spectrogram = augment_dropout(spectrogram,
-                                      keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
+        if FLAGS.augmentation_spec_dropout_keeprate < 1:
+            spectrogram = augment_dropout(spectrogram,
+                                          keep_prob=FLAGS.augmentation_spec_dropout_keeprate)

-    if FLAGS.augmentation_freq_and_time_masking:
-        spectrogram = augment_freq_time_mask(spectrogram,
-                                             frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
-                                             time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
-                                             frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
-                                             time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
+        if FLAGS.augmentation_freq_and_time_masking:
+            spectrogram = augment_freq_time_mask(spectrogram,
+                                                 frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
+                                                 time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
+                                                 frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
+                                                 time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)

-    if FLAGS.augmentation_pitch_and_tempo_scaling:
-        spectrogram = augment_pitch_and_tempo(spectrogram,
-                                              max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
-                                              max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
-                                              min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
+        if FLAGS.augmentation_pitch_and_tempo_scaling:
+            spectrogram = augment_pitch_and_tempo(spectrogram,
+                                                  max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
+                                                  max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
+                                                  min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)

-    if FLAGS.augmentation_speed_up_std > 0:
-        spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
+        if FLAGS.augmentation_speed_up_std > 0:
+            spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)

     mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
     mfccs = tf.reshape(mfccs, [-1, Config.n_input])
@@ -69,25 +71,29 @@ def samples_to_mfccs(samples, sample_rate):
     return mfccs, tf.shape(input=mfccs)[0]


-def audiofile_to_features(wav_filename):
+def audiofile_to_features(wav_filename, train_phase=False):
     samples = tf.io.read_file(wav_filename)
     decoded = contrib_audio.decode_wav(samples, desired_channels=1)
-    features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate)
+    features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase)

-    if FLAGS.data_aug_features_multiplicative > 0:
-        features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
+    if train_phase:
+        if FLAGS.data_aug_features_multiplicative > 0:
+            features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))

-    if FLAGS.data_aug_features_additive > 0:
-        features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
+        if FLAGS.data_aug_features_additive > 0:
+            features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))

     return features, features_len


-def entry_to_features(wav_filename, transcript):
+def entry_to_features_not_augmented(wav_filename, transcript):
     # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename)
+    features, features_len = audiofile_to_features(wav_filename, train_phase=False)
     return wav_filename, features, features_len, tf.SparseTensor(*transcript)


+def entry_to_features_augmented(wav_filename, transcript):
+    # https://bugs.python.org/issue32117
+    features, features_len = audiofile_to_features(wav_filename, train_phase=True)
+    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
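The multiplicative and additive feature noise is the simplest of the augmentations being gated here. A self-contained sketch of that gating, with the FLAGS lookups replaced by plain arguments (the function name and the mult_std/add_std parameters are illustrative, not part of the codebase):

    import tensorflow as tf

    def add_feature_noise(features, train_phase, mult_std=0.0, add_std=0.0):
        # Mirrors the gated block in audiofile_to_features above: noise is injected
        # only for training data, and only when the corresponding std is positive.
        if train_phase:
            if mult_std > 0:
                features = features * tf.random.normal(mean=1.0, stddev=mult_std, shape=tf.shape(features))
            if add_std > 0:
                features = features + tf.random.normal(mean=0.0, stddev=add_std, shape=tf.shape(features))
        return features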
@@ -100,7 +106,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape


-def create_dataset(csvs, batch_size, cache_path=''):
+def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)

@@ -112,6 +118,7 @@ def create_dataset(csvs, batch_size, cache_path=''):
         log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
         exit(1)

+    entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented

     def generate_values():
         for _, row in df.iterrows():
             yield row.wav_filename, to_sparse_tuple(row.transcript)
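Because Dataset.map() only passes the dataset elements (wav_filename, transcript) to the mapped function, train_phase cannot travel through the pipeline itself; it has to be fixed when the dataset is constructed, which is why there are two wrapper functions and a plain Python selection between them. A sketch of how the selected wrapper would then feed into the dataset (the map call shown here is illustrative, not copied verbatim from the file):

    # Chosen once at dataset-construction time, not per element.
    entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented
    dataset = dataset.map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)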

View File

@@ -39,7 +39,7 @@ def create_flags():
     f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
     f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')

-    f.DEFINE_float('augmentation_speed_up_std', 0.5, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
+    f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')

     f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling')
     f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
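With the default for augmentation_speed_up_std dropped from 0.5 to 0, the speed-up augmentation becomes opt-in. A hedged example of re-enabling it for a training run (the DeepSpeech.py entry point and the data paths follow the project's usual conventions and are not part of this diff):

    python DeepSpeech.py --train_files data/train.csv \
                         --dev_files data/dev.csv \
                         --augmentation_speed_up_std 0.5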

View File

@@ -37,7 +37,6 @@ def augment_freq_time_mask(mel_spectrogram,
     freq_max = tf.shape(mel_spectrogram)[1]
     time_max = tf.shape(mel_spectrogram)[2]
     # Frequency masking
-    # Testing without loop
     for _ in range(frequency_mask_num):
         f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32)
         f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32)
@@ -50,7 +49,6 @@ def augment_freq_time_mask(mel_spectrogram,
         mel_spectrogram = mel_spectrogram*freq_mask

     # Time masking
-    # Testing without loop
     for _ in range(time_mask_num):
         t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32)
         t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)
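For reference, the frequency-masking half of augment_freq_time_mask can be written as a small standalone function. This is a sketch assuming a [batch, freq, time] spectrogram layout (matching the shape indexing above); the function and parameter names are illustrative, and the time-masking half is symmetric over axis 2:

    import tensorflow as tf

    def mask_random_freq_bands(spectrogram, max_band_width, n_masks=1):
        # Zero out n_masks randomly placed frequency bands, each strictly narrower
        # than max_band_width bins, in the SpecAugment style used above.
        freq_max = tf.shape(spectrogram)[1]
        for _ in range(n_masks):
            f = tf.random.uniform(shape=(), minval=0, maxval=max_band_width, dtype=tf.int32)
            f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.int32)
            bins = tf.range(freq_max)
            keep = tf.cast(tf.logical_or(bins < f0, bins >= f0 + f), spectrogram.dtype)
            spectrogram = spectrogram * tf.reshape(keep, [1, -1, 1])  # broadcast over batch and time
        return spectrogram

    # Example: mask two random bands (each under 5 bins wide) in a dummy [1, 80, 100] spectrogram.
    masked = mask_random_freq_bands(tf.ones([1, 80, 100]), max_band_width=5, n_masks=2)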