[FIX] use time_warping_para to constrain control width, add logic to disable cache [ADD] num_control_points
This commit is contained in:
parent
ec0ee65eb0
commit
5fbc7e8596
@ -427,7 +427,8 @@ def train():
|
|||||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||||
FLAGS.augmentation_freq_and_time_masking or
|
FLAGS.augmentation_freq_and_time_masking or
|
||||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||||
FLAGS.augmentation_speed_up_std > 0):
|
FLAGS.augmentation_speed_up_std > 0 or
|
||||||
|
FLAGS.augmentation_sparse_warp):
|
||||||
do_cache_dataset = False
|
do_cache_dataset = False
|
||||||
|
|
||||||
# Create training and validation datasets
|
# Create training and validation datasets
|
||||||
|
@ -48,7 +48,8 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
|
|||||||
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
|
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
|
||||||
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
|
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
|
||||||
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
|
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
|
||||||
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points)
|
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points,
|
||||||
|
num_control_points=FLAGS.augmentation_sparse_warp_num_control_points)
|
||||||
|
|
||||||
if FLAGS.augmentation_freq_and_time_masking:
|
if FLAGS.augmentation_freq_and_time_masking:
|
||||||
spectrogram = augment_freq_time_mask(spectrogram,
|
spectrogram = augment_freq_time_mask(spectrogram,
|
||||||
|
@ -30,6 +30,7 @@ def create_flags():
|
|||||||
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
|
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
|
||||||
|
|
||||||
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
|
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
|
||||||
|
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
|
||||||
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time_warping_para')
|
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time_warping_para')
|
||||||
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
|
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
|
||||||
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
|
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
|
||||||
|
@ -68,7 +68,7 @@ def augment_dropout(spectrogram,
|
|||||||
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
|
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
|
||||||
|
|
||||||
|
|
||||||
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):
|
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
|
||||||
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
|
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
|
||||||
Args:
|
Args:
|
||||||
spectrogram: `[batch, time, frequency]` float `Tensor`
|
spectrogram: `[batch, time, frequency]` float `Tensor`
|
||||||
@ -77,6 +77,7 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
|||||||
regularization_weight: used to put into `sparse_image_warp`
|
regularization_weight: used to put into `sparse_image_warp`
|
||||||
num_boundary_points: used to put into `sparse_image_warp`,
|
num_boundary_points: used to put into `sparse_image_warp`,
|
||||||
default=1 means boundary points on 4 corners of the image
|
default=1 means boundary points on 4 corners of the image
|
||||||
|
num_control_points: number of control points
|
||||||
Returns:
|
Returns:
|
||||||
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
|
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
|
||||||
type as input image.
|
type as input image.
|
||||||
@ -92,25 +93,26 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
|||||||
time_warping_para = tf.math.minimum(
|
time_warping_para = tf.math.minimum(
|
||||||
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
||||||
|
|
||||||
mid_tau = tf.math.floordiv(tau, 2)
|
|
||||||
mid_freq = tf.math.floordiv(freq_size, 2)
|
mid_freq = tf.math.floordiv(freq_size, 2)
|
||||||
left_mid_point = [0, mid_freq]
|
left_mid_point = [0, mid_freq]
|
||||||
right_mid_point = [tau, mid_freq]
|
right_mid_point = [tau, mid_freq]
|
||||||
|
|
||||||
# dest control point must between (W, tau-W), which means the first and last W interval of the spectrogram won't be warped
|
random_source_times = [tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
|
||||||
random_dest_time_point = tfv1.random_uniform(
|
[], time_warping_para, tf.math.subtract(tau, time_warping_para), tf.int32) for i in range(num_control_points)]
|
||||||
[], time_warping_para, tau - time_warping_para, tf.int32)
|
random_dest_times = [tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
|
||||||
source_control_point_locations = tf.cast([[
|
[], tf.math.subtract(source_time, time_warping_para), tf.math.add(source_time, time_warping_para), tf.int32) for source_time in random_source_times]
|
||||||
left_mid_point,
|
|
||||||
[mid_tau, mid_freq], # source control point must start from the center of the image
|
|
||||||
right_mid_point
|
|
||||||
]], tf.float32)
|
|
||||||
|
|
||||||
dest_control_point_locations = tf.cast([[
|
source_control_point_locations = tf.cast([
|
||||||
left_mid_point,
|
[left_mid_point] +
|
||||||
[random_dest_time_point, mid_freq],
|
[[source_time, mid_freq] for source_time in random_source_times] +
|
||||||
right_mid_point
|
[right_mid_point]
|
||||||
]], tf.float32)
|
], tf.float32)
|
||||||
|
|
||||||
|
dest_control_point_locations = tf.cast([
|
||||||
|
[left_mid_point] +
|
||||||
|
[[dest_time, mid_freq] for dest_time in random_dest_times] +
|
||||||
|
[right_mid_point]
|
||||||
|
], tf.float32)
|
||||||
|
|
||||||
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
||||||
source_control_point_locations=source_control_point_locations,
|
source_control_point_locations=source_control_point_locations,
|
||||||
|
Loading…
Reference in New Issue
Block a user