[FIX] constraint time_warping_para to protect short audio augment

This commit is contained in:
Yi-Hua Chiu 2019-12-18 14:26:39 +08:00
parent 5fbc7e8596
commit 533d15645f

View File

@ -84,7 +84,8 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
""" """
# resize to fit `sparse_image_warp`'s input shape # resize to fit `sparse_image_warp`'s input shape
spectrogram = tf.expand_dims(spectrogram, -1) # (1, time steps, freq, 1), batch_size must be 1 # (1, time steps, freq, 1), batch_size must be 1
spectrogram = tf.expand_dims(spectrogram, -1)
original_shape = tf.shape(spectrogram) original_shape = tf.shape(spectrogram)
tau, freq_size = original_shape[1], original_shape[2] tau, freq_size = original_shape[1], original_shape[2]
@ -93,26 +94,41 @@ def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2
time_warping_para = tf.math.minimum( time_warping_para = tf.math.minimum(
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1)) time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
mid_freq = tf.math.floordiv(freq_size, 2) choosen_freqs = tf.random.shuffle(tf.add(tf.range(freq_size), 1))[
left_mid_point = [0, mid_freq] 0: num_control_points]
right_mid_point = [tau, mid_freq]
random_source_times = [tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W) sources = []
[], time_warping_para, tf.math.subtract(tau, time_warping_para), tf.int32) for i in range(num_control_points)] dests = []
random_dest_times = [tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W) for i in range(num_control_points):
[], tf.math.subtract(source_time, time_warping_para), tf.math.add(source_time, time_warping_para), tf.int32) for source_time in random_source_times] source_max = tau - time_warping_para - 1
# to protect short audio
source_min = tf.math.minimum(source_max - 1, time_warping_para)
rand_source_time = tfv1.random_uniform( # generate source points `t` of time axis between (W, tau-W)
[], source_min, source_max, tf.int32)
rand_dest_time = tfv1.random_uniform( # generate dest points `t'` of time axis between (t-W, t+W)
[], tf.math.maximum(tf.math.subtract(rand_source_time, time_warping_para), 0), tf.math.add(rand_source_time, time_warping_para), tf.int32)
source_control_point_locations = tf.cast([ # if choosen_freq == tau -1 => crash
[left_mid_point] + choosen_freq = tf.cond(tf.equal(choosen_freqs[i], tau-1),
[[source_time, mid_freq] for source_time in random_source_times] + lambda: choosen_freqs[i] + # pylint: disable=cell-var-from-loop
[right_mid_point] 1, # pylint: disable=cell-var-from-loop
], tf.float32) lambda: choosen_freqs[i]) # pylint: disable=cell-var-from-loop
sources.append([0, choosen_freq])
sources.append([rand_source_time, choosen_freq])
sources.append([tau, choosen_freq])
dest_control_point_locations = tf.cast([ dests.append([0, choosen_freq])
[left_mid_point] + dests.append([rand_dest_time, choosen_freq])
[[dest_time, mid_freq] for dest_time in random_dest_times] + dests.append([tau, choosen_freq])
[right_mid_point]
], tf.float32) source_control_point_locations = tf.cast([sources], tf.float32)
dest_control_point_locations = tf.cast([dests], tf.float32)
# debug
# print('spectrogram', spectrogram)
# spectrogram = tf.Print(spectrogram, sources, message='sources', first_n=1000)
# spectrogram = tf.Print(spectrogram, dests, message='dests', first_n=1000)
warped_spectrogram, _ = sparse_image_warp(spectrogram, warped_spectrogram, _ = sparse_image_warp(spectrogram,
source_control_point_locations=source_control_point_locations, source_control_point_locations=source_control_point_locations,