[FIX] use time_warping_para to constraint control width, add logic to disable cache [ADD] num_control_points

This commit is contained in:
Yi-Hua Chiu 2019-12-04 10:25:36 +08:00
parent ec0ee65eb0
commit 5fbc7e8596
4 changed files with 22 additions and 17 deletions

View File

@ -427,7 +427,8 @@ def train():
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0):
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
# Create training and validation datasets

View File

@ -48,7 +48,8 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points)
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points,
num_control_points=FLAGS.augmentation_sparse_warp_num_control_points)
if FLAGS.augmentation_freq_and_time_masking:
spectrogram = augment_freq_time_mask(spectrogram,

View File

@ -30,6 +30,7 @@ def create_flags():
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time_warping_para')
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')

View File

@ -68,7 +68,7 @@ def augment_dropout(spectrogram,
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
Args:
spectrogram: `[batch, time, frequency]` float `Tensor`
@ -77,6 +77,7 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
regularization_weight: used to put into `sparse_image_warp`
num_boundary_points: used to put into `sparse_image_warp`,
default=1 means boundary points on 4 corners of the image
num_control_points: number of control points
Returns:
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
type as input image.
@ -92,25 +93,26 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
time_warping_para = tf.math.minimum(
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
mid_tau = tf.math.floordiv(tau, 2)
mid_freq = tf.math.floordiv(freq_size, 2)
left_mid_point = [0, mid_freq]
right_mid_point = [tau, mid_freq]
# dest control point must between (W, tau-W), which means the first and last W interval of the spectrogram won't be warped
random_dest_time_point = tfv1.random_uniform(
[], time_warping_para, tau - time_warping_para, tf.int32)
source_control_point_locations = tf.cast([[
left_mid_point,
[mid_tau, mid_freq], # source control point must start from the center of the image
right_mid_point
]], tf.float32)
random_source_times = [tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
[], time_warping_para, tf.math.subtract(tau, time_warping_para), tf.int32) for i in range(num_control_points)]
random_dest_times = [tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
[], tf.math.subtract(source_time, time_warping_para), tf.math.add(source_time, time_warping_para), tf.int32) for source_time in random_source_times]
dest_control_point_locations = tf.cast([[
left_mid_point,
[random_dest_time_point, mid_freq],
right_mid_point
]], tf.float32)
source_control_point_locations = tf.cast([
[left_mid_point] +
[[source_time, mid_freq] for source_time in random_source_times] +
[right_mid_point]
], tf.float32)
dest_control_point_locations = tf.cast([
[left_mid_point] +
[[dest_time, mid_freq] for dest_time in random_dest_times] +
[right_mid_point]
], tf.float32)
warped_spectrogram, _ = sparse_image_warp(spectrogram,
source_control_point_locations=source_control_point_locations,