Add a 'train_phase' argument to create_dataset so that data augmentation is applied only to the training set.
This commit is contained in:
parent
0cc5ff230f
commit
49c6a9c973
|
@ -415,7 +415,8 @@ def train():
|
||||||
# Create training and validation datasets
|
# Create training and validation datasets
|
||||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||||
batch_size=FLAGS.train_batch_size,
|
batch_size=FLAGS.train_batch_size,
|
||||||
cache_path=FLAGS.feature_cache)
|
cache_path=FLAGS.feature_cache,
|
||||||
|
train_phase=True)
|
||||||
|
|
||||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||||
tfv1.data.get_output_shapes(train_set),
|
tfv1.data.get_output_shapes(train_set),
|
||||||
|
@ -426,7 +427,7 @@ def train():
|
||||||
|
|
||||||
if FLAGS.dev_files:
|
if FLAGS.dev_files:
|
||||||
dev_csvs = FLAGS.dev_files.split(',')
|
dev_csvs = FLAGS.dev_files.split(',')
|
||||||
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
|
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
|
||||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||||
|
|
||||||
# Dropout
|
# Dropout
|
||||||
|
|
|
@ -47,7 +47,7 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||||
Config.alphabet)
|
Config.alphabet)
|
||||||
|
|
||||||
test_csvs = FLAGS.test_files.split(',')
|
test_csvs = FLAGS.test_files.split(',')
|
||||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
|
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||||
tfv1.data.get_output_shapes(test_sets[0]),
|
tfv1.data.get_output_shapes(test_sets[0]),
|
||||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||||
|
|
|
@ -32,36 +32,38 @@ def read_csvs(csv_files):
|
||||||
return source_data
|
return source_data
|
||||||
|
|
||||||
|
|
||||||
def samples_to_mfccs(samples, sample_rate):
|
def samples_to_mfccs(samples, sample_rate, train_phase=False):
|
||||||
spectrogram = contrib_audio.audio_spectrogram(samples,
|
spectrogram = contrib_audio.audio_spectrogram(samples,
|
||||||
window_size=Config.audio_window_samples,
|
window_size=Config.audio_window_samples,
|
||||||
stride=Config.audio_step_samples,
|
stride=Config.audio_step_samples,
|
||||||
magnitude_squared=True)
|
magnitude_squared=True)
|
||||||
|
|
||||||
if FLAGS.augmention_sparse_deform:
|
# Data Augmentations
|
||||||
spectrogram = augment_sparse_deform(spectrogram,
|
if train_phase:
|
||||||
time_warping_para=FLAGS.augmentation_time_warp_max_warping,
|
if FLAGS.augmention_sparse_deform:
|
||||||
normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp)
|
spectrogram = augment_sparse_deform(spectrogram,
|
||||||
|
time_warping_para=FLAGS.augmentation_time_warp_max_warping,
|
||||||
|
normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp)
|
||||||
|
|
||||||
if FLAGS.augmentation_spec_dropout_keeprate < 1:
|
if FLAGS.augmentation_spec_dropout_keeprate < 1:
|
||||||
spectrogram = augment_dropout(spectrogram,
|
spectrogram = augment_dropout(spectrogram,
|
||||||
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
|
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
|
||||||
|
|
||||||
if FLAGS.augmentation_freq_and_time_masking:
|
if FLAGS.augmentation_freq_and_time_masking:
|
||||||
spectrogram = augment_freq_time_mask(spectrogram,
|
spectrogram = augment_freq_time_mask(spectrogram,
|
||||||
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
|
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
|
||||||
time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
|
time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
|
||||||
frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
|
frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
|
||||||
time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
|
time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
|
||||||
|
|
||||||
if FLAGS.augmentation_pitch_and_tempo_scaling:
|
if FLAGS.augmentation_pitch_and_tempo_scaling:
|
||||||
spectrogram = augment_pitch_and_tempo(spectrogram,
|
spectrogram = augment_pitch_and_tempo(spectrogram,
|
||||||
max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
|
max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
|
||||||
max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
|
max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
|
||||||
min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
|
min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
|
||||||
|
|
||||||
if FLAGS.augmentation_speed_up_std > 0:
|
if FLAGS.augmentation_speed_up_std > 0:
|
||||||
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
|
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
|
||||||
|
|
||||||
mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
|
mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
|
||||||
mfccs = tf.reshape(mfccs, [-1, Config.n_input])
|
mfccs = tf.reshape(mfccs, [-1, Config.n_input])
|
||||||
|
@ -69,25 +71,29 @@ def samples_to_mfccs(samples, sample_rate):
|
||||||
return mfccs, tf.shape(input=mfccs)[0]
|
return mfccs, tf.shape(input=mfccs)[0]
|
||||||
|
|
||||||
|
|
||||||
def audiofile_to_features(wav_filename):
|
def audiofile_to_features(wav_filename, train_phase=False):
|
||||||
samples = tf.io.read_file(wav_filename)
|
samples = tf.io.read_file(wav_filename)
|
||||||
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
|
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
|
||||||
features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate)
|
features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase)
|
||||||
|
|
||||||
|
if train_phase:
|
||||||
|
if FLAGS.data_aug_features_multiplicative > 0:
|
||||||
|
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
|
||||||
|
|
||||||
if FLAGS.data_aug_features_multiplicative > 0:
|
if FLAGS.data_aug_features_additive > 0:
|
||||||
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
|
features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
|
||||||
|
|
||||||
if FLAGS.data_aug_features_additive > 0:
|
|
||||||
features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
|
|
||||||
|
|
||||||
|
|
||||||
return features, features_len
|
return features, features_len
|
||||||
|
|
||||||
|
|
||||||
def entry_to_features(wav_filename, transcript):
|
def entry_to_features_not_augmented(wav_filename, transcript):
|
||||||
# https://bugs.python.org/issue32117
|
# https://bugs.python.org/issue32117
|
||||||
features, features_len = audiofile_to_features(wav_filename)
|
features, features_len = audiofile_to_features(wav_filename, train_phase=False)
|
||||||
|
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
|
||||||
|
|
||||||
|
def entry_to_features_augmented(wav_filename, transcript):
|
||||||
|
# https://bugs.python.org/issue32117
|
||||||
|
features, features_len = audiofile_to_features(wav_filename, train_phase=True)
|
||||||
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
|
return wav_filename, features, features_len, tf.SparseTensor(*transcript)
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,7 +106,7 @@ def to_sparse_tuple(sequence):
|
||||||
return indices, sequence, shape
|
return indices, sequence, shape
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(csvs, batch_size, cache_path=''):
|
def create_dataset(csvs, batch_size, cache_path='', train_phase=True):
|
||||||
df = read_csvs(csvs)
|
df = read_csvs(csvs)
|
||||||
df.sort_values(by='wav_filesize', inplace=True)
|
df.sort_values(by='wav_filesize', inplace=True)
|
||||||
|
|
||||||
|
@ -112,6 +118,7 @@ def create_dataset(csvs, batch_size, cache_path=''):
|
||||||
log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
|
log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented
|
||||||
def generate_values():
|
def generate_values():
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
yield row.wav_filename, to_sparse_tuple(row.transcript)
|
yield row.wav_filename, to_sparse_tuple(row.transcript)
|
||||||
|
|
|
@ -39,7 +39,7 @@ def create_flags():
|
||||||
f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
|
f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
|
||||||
f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')
|
f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')
|
||||||
|
|
||||||
f.DEFINE_float('augmentation_speed_up_std', 0.5, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
|
f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
|
||||||
|
|
||||||
f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling')
|
f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling')
|
||||||
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
|
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
|
||||||
|
|
|
@ -37,7 +37,6 @@ def augment_freq_time_mask(mel_spectrogram,
|
||||||
freq_max = tf.shape(mel_spectrogram)[1]
|
freq_max = tf.shape(mel_spectrogram)[1]
|
||||||
time_max = tf.shape(mel_spectrogram)[2]
|
time_max = tf.shape(mel_spectrogram)[2]
|
||||||
# Frequency masking
|
# Frequency masking
|
||||||
# Testing without loop
|
|
||||||
for _ in range(frequency_mask_num):
|
for _ in range(frequency_mask_num):
|
||||||
f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32)
|
f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32)
|
||||||
f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32)
|
f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32)
|
||||||
|
@ -50,7 +49,6 @@ def augment_freq_time_mask(mel_spectrogram,
|
||||||
mel_spectrogram = mel_spectrogram*freq_mask
|
mel_spectrogram = mel_spectrogram*freq_mask
|
||||||
|
|
||||||
# Time masking
|
# Time masking
|
||||||
# Testing without loop
|
|
||||||
for _ in range(time_mask_num):
|
for _ in range(time_mask_num):
|
||||||
t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32)
|
t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32)
|
||||||
t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)
|
t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)
|
||||||
|
|
Loading…
Reference in New Issue