clean code

This commit is contained in:
Yi-Hua Chiu 2019-12-03 16:19:32 +08:00
parent 368f0d413a
commit 450483c30b
3 changed files with 36 additions and 23 deletions

View File

@ -66,8 +66,6 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
if FLAGS.augmentation_speed_up_std > 0:
spectrogram = augment_speed_up(spectrogram, speed_std=FLAGS.augmentation_speed_up_std)
# spectrogram = augment_sparse_warp(spectrogram)
mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
mfccs = tf.reshape(mfccs, [-1, Config.n_input])

View File

@ -13,6 +13,11 @@
# limitations under the License.
# ==============================================================================
"""Image warping using sparse flow defined at control points."""
# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py
# But refactored for dynamic tensor shape compatibility
# The core idea is to replace every numpy implementation with tensorflow implementation
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -23,9 +28,7 @@ from tensorflow.compat import dimension_value
from tensorflow.contrib.image.python.ops import dense_image_warp
from tensorflow.contrib.image.python.ops import interpolate_spline
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
def _to_float32(value):

View File

@ -84,42 +84,54 @@ def augment_dropout(spectrogram,
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
def augment_sparse_warp(spectrogram: tf.Tensor, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):
def augment_sparse_warp(spectrogram, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
Args:
spectrogram: `[batch, time, frequency]` float `Tensor`
time_warping_para: 'W' parameter in paper
interpolation_order: used to put into `sparse_image_warp`
regularization_weight: used to put into `sparse_image_warp`
num_boundary_points: used to put into `sparse_image_warp`,
default=1 means boundary points on 4 corners of the image
Returns:
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
type as input image.
"""
if spectrogram.get_shape().ndims == 3:
spectrogram = tf.expand_dims(spectrogram, -1)
elif spectrogram.get_shape().ndims == 2:
spectrogram = tf.expand_dims(tf.expand_dims(spectrogram, 0), -1)
# spectrogram shape: (1, time steps, freq, 1)
# resize to fit `sparse_image_warp`'s input shape
spectrogram = tf.expand_dims(spectrogram, -1) # (1, time steps, freq, 1), batch_size must be 1
spec_shape = tf.shape(spectrogram)
tau, freq_size = spec_shape[1], spec_shape[2] # batch_size must be 1
time_warping_para = tf.math.minimum(tau, tf.math.floordiv(tau, 2)) # to protect short audio
original_shape = tf.shape(spectrogram)
tau, freq_size = original_shape[1], original_shape[2]
# to protect short audio
time_warping_para = tf.math.minimum(
time_warping_para, tf.math.floordiv(tau, 2))
mid_tau = tf.math.floordiv(tau, 2)
mid_freq = tf.math.floordiv(freq_size, 2)
left_mid_point = [0, mid_freq]
right_mid_point = [tau, mid_freq]
time_warping_para = tf.math.minimum(time_warping_para, tf.math.floordiv(tau, 2))
# dest control point must between (W, tau-W), which means the first and last W interval of the spectrogram won't be warped
random_dest_time_point = tfv1.random_uniform(
[], time_warping_para, tau - time_warping_para, tf.int32)
# warp
source_control_point_locations = tf.cast([[
left_mid_point,
[mid_tau, mid_freq],
[mid_tau, mid_freq], # source control point must start from the center of the image
right_mid_point
]], tf.float32)
dest_control_point_locations = tf.cast([[
left_mid_point,
[random_dest_time_point, mid_freq],
right_mid_point
]], tf.float32)
warped_image, _ = sparse_image_warp(spectrogram,
source_control_point_locations=source_control_point_locations,
dest_control_point_locations=dest_control_point_locations,
interpolation_order=interpolation_order,
regularization_weight=regularization_weight,
num_boundary_points=num_boundary_points)
return tf.reshape(warped_image, shape=(1, -1, freq_size))
warped_spectrogram, _ = sparse_image_warp(spectrogram,
source_control_point_locations=source_control_point_locations,
dest_control_point_locations=dest_control_point_locations,
interpolation_order=interpolation_order,
regularization_weight=regularization_weight,
num_boundary_points=num_boundary_points)
return tf.reshape(warped_spectrogram, shape=(1, -1, freq_size))