diff --git a/DeepSpeech.py b/DeepSpeech.py index 42c1d782..161ae45b 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -427,7 +427,8 @@ def train(): FLAGS.augmentation_spec_dropout_keeprate < 1 or FLAGS.augmentation_freq_and_time_masking or FLAGS.augmentation_pitch_and_tempo_scaling or - FLAGS.augmentation_speed_up_std > 0): + FLAGS.augmentation_speed_up_std > 0 or + FLAGS.augmentation_sparse_warp): do_cache_dataset = False # Create training and validation datasets diff --git a/util/feeding.py b/util/feeding.py index 612a940b..ac06f871 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -48,7 +48,8 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False): time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para, interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order, regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight, - num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points) + num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points, + num_control_points=FLAGS.augmentation_sparse_warp_num_control_points) if FLAGS.augmentation_freq_and_time_masking: spectrogram = augment_freq_time_mask(spectrogram, diff --git a/util/flags.py b/util/flags.py index d6eecb31..dba3cbff 100644 --- a/util/flags.py +++ b/util/flags.py @@ -30,6 +30,7 @@ def create_flags(): f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)') f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp') + f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points') f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time_warping_para') f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order') f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight') diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py index fce975cf..24337f54 100644 --- a/util/spectrogram_augmentations.py +++ b/util/spectrogram_augmentations.py @@ -68,7 +68,7 @@ def augment_dropout(spectrogram, return tf.nn.dropout(spectrogram, rate=1-keep_prob) -def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1): +def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1): """Reference: https://arxiv.org/pdf/1904.08779.pdf Args: spectrogram: `[batch, time, frequency]` float `Tensor` @@ -77,6 +77,7 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2 regularization_weight: used to put into `sparse_image_warp` num_boundary_points: used to put into `sparse_image_warp`, default=1 means boundary points on 4 corners of the image + num_control_points: number of control points Returns: warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same type as input image. @@ -92,25 +93,26 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2 time_warping_para = tf.math.minimum( time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1)) - mid_tau = tf.math.floordiv(tau, 2) mid_freq = tf.math.floordiv(freq_size, 2) left_mid_point = [0, mid_freq] right_mid_point = [tau, mid_freq] - # dest control point must between (W, tau-W), which means the first and last W interval of the spectrogram won't be warped - random_dest_time_point = tfv1.random_uniform( - [], time_warping_para, tau - time_warping_para, tf.int32) - source_control_point_locations = tf.cast([[ - left_mid_point, - [mid_tau, mid_freq], # source control point must start from the center of the image - right_mid_point - ]], tf.float32) + random_source_times = [tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W) + [], time_warping_para, tf.math.subtract(tau, time_warping_para), tf.int32) for i in range(num_control_points)] + random_dest_times = [tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W) + [], tf.math.subtract(source_time, time_warping_para), tf.math.add(source_time, time_warping_para), tf.int32) for source_time in random_source_times] - dest_control_point_locations = tf.cast([[ - left_mid_point, - [random_dest_time_point, mid_freq], - right_mid_point - ]], tf.float32) + source_control_point_locations = tf.cast([ + [left_mid_point] + + [[source_time, mid_freq] for source_time in random_source_times] + + [right_mid_point] + ], tf.float32) + + dest_control_point_locations = tf.cast([ + [left_mid_point] + + [[dest_time, mid_freq] for dest_time in random_dest_times] + + [right_mid_point] + ], tf.float32) warped_spectrogram, _ = sparse_image_warp(spectrogram, source_control_point_locations=source_control_point_locations,