diff --git a/DeepSpeech.py b/DeepSpeech.py index 5183407e..c9344ea1 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -415,7 +415,8 @@ def train(): # Create training and validation datasets train_set = create_dataset(FLAGS.train_files.split(','), batch_size=FLAGS.train_batch_size, - cache_path=FLAGS.feature_cache) + cache_path=FLAGS.feature_cache, + train_phase=True) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), tfv1.data.get_output_shapes(train_set), @@ -426,7 +427,7 @@ def train(): if FLAGS.dev_files: dev_csvs = FLAGS.dev_files.split(',') - dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs] + dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] # Dropout diff --git a/evaluate.py b/evaluate.py index c86ebc1e..32c45367 100755 --- a/evaluate.py +++ b/evaluate.py @@ -47,7 +47,7 @@ def evaluate(test_csvs, create_model, try_loading): Config.alphabet) test_csvs = FLAGS.test_files.split(',') - test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs] + test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs] iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]), tfv1.data.get_output_shapes(test_sets[0]), output_classes=tfv1.data.get_output_classes(test_sets[0])) diff --git a/util/feeding.py b/util/feeding.py index c65590ee..817338cf 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -32,36 +32,38 @@ def read_csvs(csv_files): return source_data -def samples_to_mfccs(samples, sample_rate): +def samples_to_mfccs(samples, sample_rate, train_phase=False): spectrogram = contrib_audio.audio_spectrogram(samples, window_size=Config.audio_window_samples, stride=Config.audio_step_samples, magnitude_squared=True) - if FLAGS.augmention_sparse_deform: - spectrogram = augment_sparse_deform(spectrogram, - time_warping_para=FLAGS.augmentation_time_warp_max_warping, - normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp) + # Data Augmentations + if train_phase: + if FLAGS.augmention_sparse_deform: + spectrogram = augment_sparse_deform(spectrogram, + time_warping_para=FLAGS.augmentation_time_warp_max_warping, + normal_around_warping_std=FLAGS.augmentation_sparse_deform_std_warp) - if FLAGS.augmentation_spec_dropout_keeprate < 1: - spectrogram = augment_dropout(spectrogram, - keep_prob=FLAGS.augmentation_spec_dropout_keeprate) + if FLAGS.augmentation_spec_dropout_keeprate < 1: + spectrogram = augment_dropout(spectrogram, + keep_prob=FLAGS.augmentation_spec_dropout_keeprate) - if FLAGS.augmentation_freq_and_time_masking: - spectrogram = augment_freq_time_mask(spectrogram, - frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range, - time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range, - frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks, - time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks) + if FLAGS.augmentation_freq_and_time_masking: + spectrogram = augment_freq_time_mask(spectrogram, + frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range, + time_masking_para=FLAGS.augmentation_freq_and_time_masking_time_mask_range, + frequency_mask_num=FLAGS.augmentation_freq_and_time_masking_number_freq_masks, + time_mask_num=FLAGS.augmentation_freq_and_time_masking_number_time_masks) - if FLAGS.augmentation_pitch_and_tempo_scaling: - spectrogram = augment_pitch_and_tempo(spectrogram, - max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo, - max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch, - min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch) + if FLAGS.augmentation_pitch_and_tempo_scaling: + spectrogram = augment_pitch_and_tempo(spectrogram, + max_tempo=FLAGS.augmentation_pitch_and_tempo_scaling_max_tempo, + max_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_max_pitch, + min_pitch=FLAGS.augmentation_pitch_and_tempo_scaling_min_pitch) - if FLAGS.augmentation_speed_up_std > 0: - spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std) + if FLAGS.augmentation_speed_up_std > 0: + spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std) mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input) mfccs = tf.reshape(mfccs, [-1, Config.n_input]) @@ -69,25 +71,29 @@ def samples_to_mfccs(samples, sample_rate): return mfccs, tf.shape(input=mfccs)[0] -def audiofile_to_features(wav_filename): +def audiofile_to_features(wav_filename, train_phase=False): samples = tf.io.read_file(wav_filename) decoded = contrib_audio.decode_wav(samples, desired_channels=1) - features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate) + features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase) + if train_phase: + if FLAGS.data_aug_features_multiplicative > 0: + features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features)) - if FLAGS.data_aug_features_multiplicative > 0: - features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features)) - - if FLAGS.data_aug_features_additive > 0: - features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features)) - + if FLAGS.data_aug_features_additive > 0: + features = features+tf.random.normal(mean=0.0, stddev=FLAGS.data_aug_features_additive, shape=tf.shape(features)) return features, features_len -def entry_to_features(wav_filename, transcript): +def entry_to_features_not_augmented(wav_filename, transcript): # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename) + features, features_len = audiofile_to_features(wav_filename, train_phase=False) + return wav_filename, features, features_len, tf.SparseTensor(*transcript) + +def entry_to_features_augmented(wav_filename, transcript): + # https://bugs.python.org/issue32117 + features, features_len = audiofile_to_features(wav_filename, train_phase=True) return wav_filename, features, features_len, tf.SparseTensor(*transcript) @@ -100,7 +106,7 @@ def to_sparse_tuple(sequence): return indices, sequence, shape -def create_dataset(csvs, batch_size, cache_path=''): +def create_dataset(csvs, batch_size, cache_path='', train_phase=True): df = read_csvs(csvs) df.sort_values(by='wav_filesize', inplace=True) @@ -112,6 +118,7 @@ def create_dataset(csvs, batch_size, cache_path=''): log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message)) exit(1) + entry_to_features = entry_to_features_augmented if train_phase else entry_to_features_not_augmented def generate_values(): for _, row in df.iterrows(): yield row.wav_filename, to_sparse_tuple(row.transcript) diff --git a/util/flags.py b/util/flags.py index 7119e4a1..68c7ca7d 100644 --- a/util/flags.py +++ b/util/flags.py @@ -39,7 +39,7 @@ def create_flags(): f.DEFINE_integer('augmentation_freq_and_time_masking_time_mask_range', 2, 'max range of masks in the time domain when performing freqtime-mask augmentation') f.DEFINE_integer('augmentation_freq_and_time_masking_number_time_masks', 3, 'number of masks in the time domain when performing freqtime-mask augmentation') - f.DEFINE_float('augmentation_speed_up_std', 0.5, 'std for speeding-up tempo. If std is 0, this augmentation is not performed') + f.DEFINE_float('augmentation_speed_up_std', 0, 'std for speeding-up tempo. If std is 0, this augmentation is not performed') f.DEFINE_integer('augmentation_pitch_and_tempo_scaling', 0, 'whether to use spectrogram speed and tempo scaling') f.DEFINE_float('augmentation_pitch_and_tempo_scaling_min_pitch', 0.95, 'min value of pitch scaling') diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py index 012c1bd2..fb47f713 100644 --- a/util/spectrogram_augmentations.py +++ b/util/spectrogram_augmentations.py @@ -37,7 +37,6 @@ def augment_freq_time_mask(mel_spectrogram, freq_max = tf.shape(mel_spectrogram)[1] time_max = tf.shape(mel_spectrogram)[2] # Frequency masking - # Testing without loop for _ in range(frequency_mask_num): f = tf.random.uniform(shape=(), minval=0, maxval=frequency_masking_para, dtype=tf.dtypes.int32) f0 = tf.random.uniform(shape=(), minval=0, maxval=freq_max - f, dtype=tf.dtypes.int32) @@ -50,7 +49,6 @@ def augment_freq_time_mask(mel_spectrogram, mel_spectrogram = mel_spectrogram*freq_mask # Time masking - # Testing without loop for _ in range(time_mask_num): t = tf.random.uniform(shape=(), minval=0, maxval=time_masking_para, dtype=tf.dtypes.int32) t0 = tf.random.uniform(shape=(), minval=0, maxval=time_max - t, dtype=tf.dtypes.int32)