[MOD] change time_warping_para as 20 to make spec not sound too vague [FIX] make sure invertible error is not raised after many many epochs

This commit is contained in:
Yi-Hua Chiu 2019-12-31 16:31:46 +08:00
parent 3aedcc4222
commit fa41809a40
2 changed files with 15 additions and 12 deletions

View File

@ -31,7 +31,7 @@ def create_flags():
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time_warping_para')
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para')
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points')

View File

@ -68,7 +68,7 @@ def augment_dropout(spectrogram,
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
Args:
spectrogram: `[batch, time, frequency]` float `Tensor`
@ -82,8 +82,7 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
type as input image.
"""
# resize to fit `sparse_image_warp`'s input shape
# reshape to fit `sparse_image_warp`'s input shape
# (1, time steps, freq, 1), batch_size must be 1
spectrogram = tf.expand_dims(spectrogram, -1)
@ -94,24 +93,28 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
time_warping_para = tf.math.minimum(
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
choosen_freqs = tf.random.shuffle(tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
# don't choose boundary frequency
choosen_freqs = tf.random.shuffle(
tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
source_max = tau - time_warping_para
source_min = tf.math.minimum(source_max - num_control_points, time_warping_para)
choosen_times = tf.random.shuffle(tf.range(source_min, limit=source_max))[0: num_control_points]
dest_time_widths = tfv1.random_uniform([num_control_points], tf.negative(time_warping_para), time_warping_para, tf.int32)
sources = []
dests = []
for i in range(num_control_points):
source_max = tau - time_warping_para - 1
source_min = tf.math.minimum(source_max - 1, time_warping_para)
rand_source_time = tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
[], source_min, source_max, tf.int32)
rand_dest_time = tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
[], tf.math.maximum(tf.math.subtract(rand_source_time, time_warping_para), 0), tf.math.add(rand_source_time, time_warping_para), tf.int32)
# generate source points `t` of time axis between (W, tau-W)
rand_source_time = choosen_times[i]
rand_dest_time = rand_source_time + dest_time_widths[i]
choosen_freq = choosen_freqs[i]
sources.append([rand_source_time, choosen_freq])
dests.append([rand_dest_time, choosen_freq])
source_control_point_locations = tf.cast([sources], tf.float32)
dest_control_point_locations = tf.cast([dests], tf.float32)
# debug