[MOD] Change time_warping_para to 20 so that the warped spectrogram does not sound too vague. [FIX] Make sure the "not invertible" error is no longer raised after many epochs.
This commit is contained in:
parent
3aedcc4222
commit
fa41809a40
@ -31,7 +31,7 @@ def create_flags():
|
||||
|
||||
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time_warping_para')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
|
||||
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points')
|
||||
|
@ -68,7 +68,7 @@ def augment_dropout(spectrogram,
|
||||
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
|
||||
|
||||
|
||||
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
|
||||
def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
|
||||
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
|
||||
Args:
|
||||
spectrogram: `[batch, time, frequency]` float `Tensor`
|
||||
@ -82,8 +82,7 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
||||
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
|
||||
type as input image.
|
||||
"""
|
||||
|
||||
# resize to fit `sparse_image_warp`'s input shape
|
||||
# reshape to fit `sparse_image_warp`'s input shape
|
||||
# (1, time steps, freq, 1), batch_size must be 1
|
||||
spectrogram = tf.expand_dims(spectrogram, -1)
|
||||
|
||||
@ -94,24 +93,28 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
||||
time_warping_para = tf.math.minimum(
|
||||
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
||||
|
||||
choosen_freqs = tf.random.shuffle(tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
|
||||
# don't choose boundary frequency
|
||||
choosen_freqs = tf.random.shuffle(
|
||||
tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
|
||||
|
||||
source_max = tau - time_warping_para
|
||||
source_min = tf.math.minimum(source_max - num_control_points, time_warping_para)
|
||||
|
||||
choosen_times = tf.random.shuffle(tf.range(source_min, limit=source_max))[0: num_control_points]
|
||||
dest_time_widths = tfv1.random_uniform([num_control_points], tf.negative(time_warping_para), time_warping_para, tf.int32)
|
||||
|
||||
sources = []
|
||||
dests = []
|
||||
for i in range(num_control_points):
|
||||
source_max = tau - time_warping_para - 1
|
||||
source_min = tf.math.minimum(source_max - 1, time_warping_para)
|
||||
rand_source_time = tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
|
||||
[], source_min, source_max, tf.int32)
|
||||
rand_dest_time = tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
|
||||
[], tf.math.maximum(tf.math.subtract(rand_source_time, time_warping_para), 0), tf.math.add(rand_source_time, time_warping_para), tf.int32)
|
||||
# generate source points `t` of time axis between (W, tau-W)
|
||||
rand_source_time = choosen_times[i]
|
||||
rand_dest_time = rand_source_time + dest_time_widths[i]
|
||||
|
||||
choosen_freq = choosen_freqs[i]
|
||||
sources.append([rand_source_time, choosen_freq])
|
||||
dests.append([rand_dest_time, choosen_freq])
|
||||
|
||||
source_control_point_locations = tf.cast([sources], tf.float32)
|
||||
|
||||
dest_control_point_locations = tf.cast([dests], tf.float32)
|
||||
|
||||
# debug
|
||||
|
Loading…
Reference in New Issue
Block a user