[FIX] constrain time_warping_para to protect short-audio augmentation
This commit is contained in:
parent
5fbc7e8596
commit
533d15645f
@ -84,7 +84,8 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# resize to fit `sparse_image_warp`'s input shape
|
# resize to fit `sparse_image_warp`'s input shape
|
||||||
spectrogram = tf.expand_dims(spectrogram, -1) # (1, time steps, freq, 1), batch_size must be 1
|
# (1, time steps, freq, 1), batch_size must be 1
|
||||||
|
spectrogram = tf.expand_dims(spectrogram, -1)
|
||||||
|
|
||||||
original_shape = tf.shape(spectrogram)
|
original_shape = tf.shape(spectrogram)
|
||||||
tau, freq_size = original_shape[1], original_shape[2]
|
tau, freq_size = original_shape[1], original_shape[2]
|
||||||
@ -93,26 +94,41 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
|
|||||||
time_warping_para = tf.math.minimum(
|
time_warping_para = tf.math.minimum(
|
||||||
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
||||||
|
|
||||||
mid_freq = tf.math.floordiv(freq_size, 2)
|
choosen_freqs = tf.random.shuffle(tf.add(tf.range(freq_size), 1))[
|
||||||
left_mid_point = [0, mid_freq]
|
0: num_control_points]
|
||||||
right_mid_point = [tau, mid_freq]
|
|
||||||
|
|
||||||
random_source_times = [tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
|
sources = []
|
||||||
[], time_warping_para, tf.math.subtract(tau, time_warping_para), tf.int32) for i in range(num_control_points)]
|
dests = []
|
||||||
random_dest_times = [tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
|
for i in range(num_control_points):
|
||||||
[], tf.math.subtract(source_time, time_warping_para), tf.math.add(source_time, time_warping_para), tf.int32) for source_time in random_source_times]
|
source_max = tau - time_warping_para - 1
|
||||||
|
# to protect short audio
|
||||||
|
source_min = tf.math.minimum(source_max - 1, time_warping_para)
|
||||||
|
rand_source_time = tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
|
||||||
|
[], source_min, source_max, tf.int32)
|
||||||
|
rand_dest_time = tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
|
||||||
|
[], tf.math.maximum(tf.math.subtract(rand_source_time, time_warping_para), 0), tf.math.add(rand_source_time, time_warping_para), tf.int32)
|
||||||
|
|
||||||
source_control_point_locations = tf.cast([
|
# if choosen_freq == tau -1 => crash
|
||||||
[left_mid_point] +
|
choosen_freq = tf.cond(tf.equal(choosen_freqs[i], tau-1),
|
||||||
[[source_time, mid_freq] for source_time in random_source_times] +
|
lambda: choosen_freqs[i] + # pylint: disable=cell-var-from-loop
|
||||||
[right_mid_point]
|
1, # pylint: disable=cell-var-from-loop
|
||||||
], tf.float32)
|
lambda: choosen_freqs[i]) # pylint: disable=cell-var-from-loop
|
||||||
|
sources.append([0, choosen_freq])
|
||||||
|
sources.append([rand_source_time, choosen_freq])
|
||||||
|
sources.append([tau, choosen_freq])
|
||||||
|
|
||||||
dest_control_point_locations = tf.cast([
|
dests.append([0, choosen_freq])
|
||||||
[left_mid_point] +
|
dests.append([rand_dest_time, choosen_freq])
|
||||||
[[dest_time, mid_freq] for dest_time in random_dest_times] +
|
dests.append([tau, choosen_freq])
|
||||||
[right_mid_point]
|
|
||||||
], tf.float32)
|
source_control_point_locations = tf.cast([sources], tf.float32)
|
||||||
|
|
||||||
|
dest_control_point_locations = tf.cast([dests], tf.float32)
|
||||||
|
|
||||||
|
# debug
|
||||||
|
# print('spectrogram', spectrogram)
|
||||||
|
# spectrogram = tf.Print(spectrogram, sources, message='sources', first_n=1000)
|
||||||
|
# spectrogram = tf.Print(spectrogram, dests, message='dests', first_n=1000)
|
||||||
|
|
||||||
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
||||||
source_control_point_locations=source_control_point_locations,
|
source_control_point_locations=source_control_point_locations,
|
||||||
|
Loading…
Reference in New Issue
Block a user