diff --git a/util/feeding.py b/util/feeding.py index ba11ebb0..fd8c400d 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -15,7 +15,7 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio from util.config import Config from util.logging import log_error from util.text import text_to_char_array - +from util.flags import FLAGS def read_csvs(csv_files): source_data = None @@ -47,6 +47,14 @@ def audiofile_to_features(wav_filename): decoded = contrib_audio.decode_wav(samples, desired_channels=1) features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate) + + if FLAGS.data_aug_features_multiplicative > 0: + features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features)) + + if FLAGS.data_aug_features_additive > 0: + features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features)) + + return features, features_len diff --git a/util/flags.py b/util/flags.py index e1eb1788..3dc58fdd 100644 --- a/util/flags.py +++ b/util/flags.py @@ -21,6 +21,13 @@ def create_flags(): f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds') f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model') + # Data Augmentation + # ================ + + f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise') + f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise') + + # Global Constants # ================