[FIX] constrain time_warping_para to protect short-audio augmentation
This commit is contained in:
parent
5fbc7e8596
commit
533d15645f
@ -84,7 +84,8 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
||||
"""
|
||||
|
||||
# resize to fit `sparse_image_warp`'s input shape
|
||||
spectrogram = tf.expand_dims(spectrogram, -1) # (1, time steps, freq, 1), batch_size must be 1
|
||||
# (1, time steps, freq, 1), batch_size must be 1
|
||||
spectrogram = tf.expand_dims(spectrogram, -1)
|
||||
|
||||
original_shape = tf.shape(spectrogram)
|
||||
tau, freq_size = original_shape[1], original_shape[2]
|
||||
@ -93,26 +94,41 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
||||
time_warping_para = tf.math.minimum(
|
||||
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
||||
|
||||
mid_freq = tf.math.floordiv(freq_size, 2)
|
||||
left_mid_point = [0, mid_freq]
|
||||
right_mid_point = [tau, mid_freq]
|
||||
choosen_freqs = tf.random.shuffle(tf.add(tf.range(freq_size), 1))[
|
||||
0: num_control_points]
|
||||
|
||||
random_source_times = [tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
|
||||
[], time_warping_para, tf.math.subtract(tau, time_warping_para), tf.int32) for i in range(num_control_points)]
|
||||
random_dest_times = [tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
|
||||
[], tf.math.subtract(source_time, time_warping_para), tf.math.add(source_time, time_warping_para), tf.int32) for source_time in random_source_times]
|
||||
sources = []
|
||||
dests = []
|
||||
for i in range(num_control_points):
|
||||
source_max = tau - time_warping_para - 1
|
||||
# to protect short audio
|
||||
source_min = tf.math.minimum(source_max - 1, time_warping_para)
|
||||
rand_source_time = tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
|
||||
[], source_min, source_max, tf.int32)
|
||||
rand_dest_time = tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
|
||||
[], tf.math.maximum(tf.math.subtract(rand_source_time, time_warping_para), 0), tf.math.add(rand_source_time, time_warping_para), tf.int32)
|
||||
|
||||
source_control_point_locations = tf.cast([
|
||||
[left_mid_point] +
|
||||
[[source_time, mid_freq] for source_time in random_source_times] +
|
||||
[right_mid_point]
|
||||
], tf.float32)
|
||||
# if choosen_freq == tau -1 => crash
|
||||
choosen_freq = tf.cond(tf.equal(choosen_freqs[i], tau-1),
|
||||
lambda: choosen_freqs[i] + # pylint: disable=cell-var-from-loop
|
||||
1, # pylint: disable=cell-var-from-loop
|
||||
lambda: choosen_freqs[i]) # pylint: disable=cell-var-from-loop
|
||||
sources.append([0, choosen_freq])
|
||||
sources.append([rand_source_time, choosen_freq])
|
||||
sources.append([tau, choosen_freq])
|
||||
|
||||
dest_control_point_locations = tf.cast([
|
||||
[left_mid_point] +
|
||||
[[dest_time, mid_freq] for dest_time in random_dest_times] +
|
||||
[right_mid_point]
|
||||
], tf.float32)
|
||||
dests.append([0, choosen_freq])
|
||||
dests.append([rand_dest_time, choosen_freq])
|
||||
dests.append([tau, choosen_freq])
|
||||
|
||||
source_control_point_locations = tf.cast([sources], tf.float32)
|
||||
|
||||
dest_control_point_locations = tf.cast([dests], tf.float32)
|
||||
|
||||
# debug
|
||||
# print('spectrogram', spectrogram)
|
||||
# spectrogram = tf.Print(spectrogram, sources, message='sources', first_n=1000)
|
||||
# spectrogram = tf.Print(spectrogram, dests, message='dests', first_n=1000)
|
||||
|
||||
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
||||
source_control_point_locations=source_control_point_locations,
|
||||
|
Loading…
Reference in New Issue
Block a user