Merge pull request #2352 from mozilla/data-augmentation-pr

Data augmentation PR

commit 889f069b1c
@@ -415,7 +415,8 @@ def train():
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
-                               cache_path=FLAGS.feature_cache)
+                               cache_path=FLAGS.feature_cache,
+                               train_phase=True)
 
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
@@ -426,7 +427,7 @@ def train():
 
     if FLAGS.dev_files:
         dev_csvs = FLAGS.dev_files.split(',')
-        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
+        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 
     # Dropout
@@ -47,7 +47,7 @@ def evaluate(test_csvs, create_model, try_loading):
                               Config.alphabet)
 
     test_csvs = FLAGS.test_files.split(',')
-    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
+    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                  tfv1.data.get_output_shapes(test_sets[0]),
                                                  output_classes=tfv1.data.get_output_classes(test_sets[0]))
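Note that only the training dataset is created with train_phase=True; dev and test datasets explicitly pass train_phase=False, so evaluation features stay deterministic. A minimal, self-contained sketch of that gating pattern (extract_features is a toy stand-in, not the repository's code; TF 2.x eager):

import tensorflow as tf

def extract_features(audio, train_phase=False):
    # Toy stand-in for the repository's feature pipeline.
    features = tf.reshape(audio, [-1, 16])
    if train_phase:  # augmentation only ever runs on the training set
        features = features * tf.random.normal(tf.shape(features), mean=1.0, stddev=0.1)
    return features

audio = tf.sin(tf.linspace(0.0, 100.0, 1024))
train_feats = extract_features(audio, train_phase=True)    # randomized on every call
test_feats = extract_features(audio, train_phase=False)    # identical on every call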
@@ -15,7 +15,8 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio
 from util.config import Config
 from util.logging import log_error
 from util.text import text_to_char_array
 from util.flags import FLAGS
+from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
 
 def read_csvs(csv_files):
     source_data = None
@@ -31,28 +32,58 @@ def read_csvs(csv_files):
     return source_data
 
 
-def samples_to_mfccs(samples, sample_rate):
+def samples_to_mfccs(samples, sample_rate, train_phase=False):
     spectrogram = contrib_audio.audio_spectrogram(samples,
                                                   window_size=Config.audio_window_samples,
                                                   stride=Config.audio_step_samples,
                                                   magnitude_squared=True)
+
+    # Data Augmentations
+    if train_phase:
+        if FLAGS.augmentation_spec_dropout_keeprate < 1:
+            spectrogram = augment_dropout(spectrogram,
+                                          keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
+
+        if FLAGS.augmentation_freq_and_time_masking:
+            spectrogram = augment_freq_time_mask(spectrogram,
+                                                 frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
+                                                 time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range,
+                                                 frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks,
+                                                 time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks)
+
+        if FLAGS.augmentation_pitch_and_tempo_scaling:
+            spectrogram = augment_pitch_and_tempo(spectrogram,
+                                                  max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo,
+                                                  max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch,
+                                                  min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch)
+
+        if FLAGS.augmentation_speed_up_std > 0:
+            spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
+
     mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
     mfccs = tf.reshape(mfccs, [-1, Config.n_input])
 
     return mfccs, tf.shape(input=mfccs)[0]
 
 
-def audiofile_to_features(wav_filename):
+def audiofile_to_features(wav_filename, train_phase=False):
     samples = tf.io.read_file(wav_filename)
     decoded = contrib_audio.decode_wav(samples, desired_channels=1)
-    features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate)
+    features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase)
+
+    if train_phase:
+        if FLAGS.data_aug_features_multiplicative > 0:
+            features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
+
+        if FLAGS.data_aug_features_additive > 0:
+            features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features))
 
     return features, features_len
 
 
-def entry_to_features(wav_filename, transcript):
+def entry_to_features(wav_filename, transcript, train_phase):
     # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename)
+    features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase)
     return wav_filename, features, features_len, tf.SparseTensor(*transcript)
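For reference, a minimal, self-contained sketch of the two feature-level noise augmentations added in the hunk above, applied to a dummy MFCC matrix under TF 2.x eager execution; the std values are illustrative stand-ins for the data_aug_features_multiplicative and data_aug_features_additive flags:

import tensorflow as tf

features = tf.random.uniform([100, 26])   # stand-in for a [time, Config.n_input] MFCC matrix

multiplicative_std = 0.2                  # illustrative value for FLAGS.data_aug_features_multiplicative
additive_std = 0.1                        # illustrative value for FLAGS.data_aug_features_additive

# Multiplicative noise: rescale every coefficient by a factor drawn around 1.
features = features * tf.random.normal(shape=tf.shape(features), mean=1.0, stddev=multiplicative_std)

# Additive noise: superimpose zero-mean Gaussian noise.
features = features + tf.random.normal(shape=tf.shape(features), mean=0.0, stddev=additive_std)

print(features.shape)                     # (100, 26) when run eagerly; the shape is unchanged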
@@ -65,7 +96,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape
 
 
-def create_dataset(csvs, batch_size, cache_path=''):
+def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)
 
@@ -97,10 +128,11 @@ def create_dataset(csvs, batch_size, cache_path=''):
         return tf.data.Dataset.zip((wav_filenames, features, transcripts))
 
     num_gpus = len(Config.available_devices)
+    process_fn = partial(entry_to_features, train_phase=train_phase)
 
     dataset = (tf.data.Dataset.from_generator(generate_values,
                                               output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
-               .map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+               .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
               .cache(cache_path)
               .window(batch_size, drop_remainder=True).flat_map(batch_fn)
               .prefetch(num_gpus))
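The process_fn change above binds train_phase with functools.partial because Dataset.map only passes dataset elements to the mapped function. A hedged, self-contained TF 2.x sketch of the same pattern with a toy element function (element_fn is illustrative, not the repository's entry_to_features):

from functools import partial
import tensorflow as tf

def element_fn(x, train_phase):
    # Toy stand-in for entry_to_features: augment only when train_phase is bound to True.
    if train_phase:
        return x + tf.random.normal(tf.shape(x))
    return x

process_fn = partial(element_fn, train_phase=True)   # Dataset.map will now only pass x

dataset = (tf.data.Dataset.from_tensor_slices(tf.ones([4, 3]))
           .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))

for element in dataset:
    print(element.numpy())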
@@ -21,6 +21,28 @@ def create_flags():
     f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
     f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')
 
+    # Data Augmentation
+    # ================
+
+    f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise')
+    f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise')
+
+    f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
+
+    f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
+    f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
+    f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')
+    f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation')
+    f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation')
+
+    f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed')
+
+    f.DEFINE_boolean('augmentation_pitch_and_tempo_scaling', False, 'whether to use spectrogram speed and tempo scaling')
+    f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling')
+    f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
+    f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max value of tempo scaling')
+
     # Global Constants
     # ================
 
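All of these augmentations are off by default (keep rate 1, booleans False, stds 0), so existing training runs are unaffected unless the flags are set explicitly. A hedged sketch of inspecting a flag configuration, assuming the script is run from the repository root so util.flags is importable:

import sys
from util.flags import create_flags, FLAGS

create_flags()
FLAGS(sys.argv)  # e.g. run with --augmentation_freq_and_time_masking=true --augmentation_speed_up_std=0.1

# The same gates used in samples_to_mfccs()/audiofile_to_features():
print('spectrogram dropout :', FLAGS.augmentation_spec_dropout_keeprate < 1)
print('freq/time masking   :', FLAGS.augmentation_freq_and_time_masking)
print('pitch/tempo scaling :', FLAGS.augmentation_pitch_and_tempo_scaling)
print('speed-up            :', FLAGS.augmentation_speed_up_std > 0)
print('feature noise       :', FLAGS.data_aug_features_additive > 0 or FLAGS.data_aug_features_multiplicative > 0)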
@@ -0,0 +1,66 @@
+import tensorflow as tf
+
+def augment_freq_time_mask(mel_spectrogram,
+                           frequency_masking_para=30,
+                           time_masking_para=10,
+                           frequency_mask_num=3,
+                           time_mask_num=3):
+    freq_max = tf.shape(mel_spectrogram)[1]
+    time_max = tf.shape(mel_spectrogram)[2]
+    # Frequency masking
+    for _ in range(frequency_mask_num):
+        f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32)
+        f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32)
+        value_ones_freq_prev = tf.ones(shape=[1, f0, time_max])
+        value_zeros_freq = tf.zeros(shape=[1, f, time_max])
+        value_ones_freq_next = tf.ones(shape=[1, freq_max-(f0+f), time_max])
+        freq_mask = tf.concat([value_ones_freq_prev, value_zeros_freq, value_ones_freq_next], axis=1)
+        #mel_spectrogram[:, f0:f0 + f, :] = 0 #can't assign to tensor
+        #mel_spectrogram[:, f0:f0 + f, :] = value_zeros_freq #can't assign to tensor
+        mel_spectrogram = mel_spectrogram*freq_mask
+
+    # Time masking
+    for _ in range(time_mask_num):
+        t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32)
+        t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)
+        value_zeros_time_prev = tf.ones(shape=[1, freq_max, t0])
+        value_zeros_time = tf.zeros(shape=[1, freq_max, t])
+        value_zeros_time_next = tf.ones(shape=[1, freq_max, time_max-(t0+t)])
+        time_mask = tf.concat([value_zeros_time_prev, value_zeros_time, value_zeros_time_next], axis=2)
+        #mel_spectrogram[:, :, t0:t0 + t] = 0 #can't assign to tensor
+        #mel_spectrogram[:, :, t0:t0 + t] = value_zeros_time #can't assign to tensor
+        mel_spectrogram = mel_spectrogram*time_mask
+
+    return mel_spectrogram
+
+def augment_pitch_and_tempo(spectrogram,
+                            max_tempo=1.2,
+                            max_pitch=1.1,
+                            min_pitch=0.95):
+    original_shape = tf.shape(spectrogram)
+    choosen_pitch = tf.random.uniform(shape=(), minval=min_pitch, maxval=max_pitch)
+    choosen_tempo = tf.random.uniform(shape=(), minval=1, maxval=max_tempo)
+    new_height = tf.cast(tf.cast(original_shape[1], tf.float32)*choosen_pitch, tf.int32)
+    new_width = tf.cast(tf.cast(original_shape[2], tf.float32)/(choosen_tempo), tf.int32)
+    spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_height, new_width])
+    spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, target_height=tf.minimum(original_shape[1], new_height), target_width=tf.shape(spectrogram_aug)[2])
+    spectrogram_aug = tf.cond(choosen_pitch < 1,
+                              lambda: tf.image.pad_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0,
+                                                                   target_height=original_shape[1], target_width=tf.shape(spectrogram_aug)[2]),
+                              lambda: spectrogram_aug)
+    return spectrogram_aug[:, :, :, 0]
+
+
+def augment_speed_up(spectrogram,
+                     speed_std=0.1):
+    original_shape = tf.shape(spectrogram)
+    choosen_speed = tf.math.abs(tf.random.normal(shape=(), stddev=speed_std)) # abs makes sure the augmentation will only speed up
+    choosen_speed = 1 + choosen_speed
+    new_height = tf.cast(tf.cast(original_shape[1], tf.float32), tf.int32)
+    new_width = tf.cast(tf.cast(original_shape[2], tf.float32)/(choosen_speed), tf.int32)
+    spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(spectrogram, -1), [new_height, new_width])
+    return spectrogram_aug[:, :, :, 0]
+
+def augment_dropout(spectrogram,
+                    keep_prob=0.95):
+    return tf.nn.dropout(spectrogram, rate=1-keep_prob)
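A hedged usage sketch for the masking and dropout helpers in the new module above, assuming TF 2.x eager execution and that the script is run from the repository root so util is importable; the input tensor is random, and its [1, dim1, dim2] layout is purely illustrative (the helpers mask along axes 1 and 2):

import tensorflow as tf
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout

# Random stand-in for a spectrogram; real values come from contrib_audio.audio_spectrogram().
spectrogram = tf.random.uniform([1, 80, 400])

masked = augment_freq_time_mask(spectrogram,
                                frequency_masking_para=5,
                                time_masking_para=2,
                                frequency_mask_num=3,
                                time_mask_num=3)
dropped = augment_dropout(spectrogram, keep_prob=0.9)

# Both augmentations preserve the input shape; masking zeroes out random bands,
# dropout zeroes out roughly 10% of individual bins (and rescales the rest).
print(masked.shape, dropped.shape)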