clean code
This commit is contained in:
parent
368f0d413a
commit
450483c30b
@ -66,8 +66,6 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
|
||||
if FLAGS.augmentation_speed_up_std > 0:
|
||||
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
|
||||
|
||||
# spectrogram = augment_sparse_warp(spectrogram)
|
||||
|
||||
mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
|
||||
mfccs = tf.reshape(mfccs, [-1, Config.n_input])
|
||||
|
||||
|
@ -13,6 +13,11 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Image warping using sparse flow defined at control points."""
|
||||
|
||||
# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py
|
||||
# But refactored for dynamic tensor shape compatibility
|
||||
# The core idea is to replace every numpy implementation with tensorflow implementation
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -23,9 +28,7 @@ from tensorflow.compat import dimension_value
|
||||
from tensorflow.contrib.image.python.ops import dense_image_warp
|
||||
from tensorflow.contrib.image.python.ops import interpolate_spline
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
def _to_float32(value):
|
||||
|
@ -84,42 +84,54 @@ def augment_dropout(spectrogram,
|
||||
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
|
||||
|
||||
|
||||
def augment_sparse_warp(spectrogram: tf.Tensor, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):
|
||||
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):
|
||||
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
|
||||
Args:
|
||||
spectrogram: `[batch, time, frequency]` float `Tensor`
|
||||
time_warping_para: 'W' parameter in paper
|
||||
interpolation_order: used to put into `sparse_image_warp`
|
||||
regularization_weight: used to put into `sparse_image_warp`
|
||||
num_boundary_points: used to put into `sparse_image_warp`,
|
||||
default=1 means boundary points on 4 corners of the image
|
||||
Returns:
|
||||
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
|
||||
type as input image.
|
||||
"""
|
||||
|
||||
if spectrogram.get_shape().ndims == 3:
|
||||
spectrogram = tf.expand_dims(spectrogram, -1)
|
||||
elif spectrogram.get_shape().ndims == 2:
|
||||
spectrogram = tf.expand_dims(tf.expand_dims(spectrogram, 0), -1)
|
||||
# spectrogram shape: (1, time steps, freq, 1)
|
||||
# resize to fit `sparse_image_warp`'s input shape
|
||||
spectrogram = tf.expand_dims(spectrogram, -1) # (1, time steps, freq, 1), batch_size must be 1
|
||||
|
||||
spec_shape = tf.shape(spectrogram)
|
||||
tau, freq_size = spec_shape[1], spec_shape[2] # batch_size must be 1
|
||||
time_warping_para = tf.math.minimum(tau, tf.math.floordiv(tau, 2)) # to protect short audio
|
||||
original_shape = tf.shape(spectrogram)
|
||||
tau, freq_size = original_shape[1], original_shape[2]
|
||||
|
||||
# to protect short audio
|
||||
time_warping_para = tf.math.minimum(
|
||||
time_warping_para, tf.math.floordiv(tau, 2))
|
||||
|
||||
mid_tau = tf.math.floordiv(tau, 2)
|
||||
mid_freq = tf.math.floordiv(freq_size, 2)
|
||||
left_mid_point = [0, mid_freq]
|
||||
right_mid_point = [tau, mid_freq]
|
||||
|
||||
time_warping_para = tf.math.minimum(time_warping_para, tf.math.floordiv(tau, 2))
|
||||
# dest control point must between (W, tau-W), which means the first and last W interval of the spectrogram won't be warped
|
||||
random_dest_time_point = tfv1.random_uniform(
|
||||
[], time_warping_para, tau - time_warping_para, tf.int32)
|
||||
# warp
|
||||
source_control_point_locations = tf.cast([[
|
||||
left_mid_point,
|
||||
[mid_tau, mid_freq],
|
||||
[mid_tau, mid_freq], # source control point must start from the center of the image
|
||||
right_mid_point
|
||||
]], tf.float32)
|
||||
|
||||
dest_control_point_locations = tf.cast([[
|
||||
left_mid_point,
|
||||
[random_dest_time_point, mid_freq],
|
||||
right_mid_point
|
||||
]], tf.float32)
|
||||
|
||||
warped_image, _ = sparse_image_warp(spectrogram,
|
||||
source_control_point_locations=source_control_point_locations,
|
||||
dest_control_point_locations=dest_control_point_locations,
|
||||
interpolation_order=interpolation_order,
|
||||
regularization_weight=regularization_weight,
|
||||
num_boundary_points=num_boundary_points)
|
||||
return tf.reshape(warped_image, shape=(1, -1, freq_size))
|
||||
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
||||
source_control_point_locations=source_control_point_locations,
|
||||
dest_control_point_locations=dest_control_point_locations,
|
||||
interpolation_order=interpolation_order,
|
||||
regularization_weight=regularization_weight,
|
||||
num_boundary_points=num_boundary_points)
|
||||
return tf.reshape(warped_spectrogram, shape=(1, -1, freq_size))
|
||||
|
Loading…
x
Reference in New Issue
Block a user