sparse image warp to dynamic tensor
commit 368f0d413a
parent 271a58e464
@@ -42,6 +42,14 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
         spectrogram = augment_dropout(spectrogram,
                                       keep_prob=FLAGS.augmentation_spec_dropout_keeprate)

+        # sparse warp must be applied before freq/time masking
+        if FLAGS.augmentation_sparse_warp:
+            spectrogram = augment_sparse_warp(spectrogram,
+                                              time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
+                                              interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
+                                              regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
+                                              num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points)
+
         if FLAGS.augmentation_freq_and_time_masking:
             spectrogram = augment_freq_time_mask(spectrogram,
                                                  frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
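For context, a minimal standalone sketch of exercising the new branch end to end. It is not part of the commit: the import path and the 80-bin feature size are assumptions, and it simply calls augment_sparse_warp with the signature introduced above. The point it illustrates is that the time dimension can stay unknown until run time.

    # Hypothetical sketch; module path and feature size are assumptions, not from the commit.
    import numpy as np
    import tensorflow.compat.v1 as tfv1

    from util.spec_augmentation import augment_sparse_warp  # placeholder path for this repo

    # (batch, time, freq) with the time dimension unknown at graph-construction time,
    # which is the case the dynamic-tensor rewrite is meant to support.
    spectrogram = tfv1.placeholder(tfv1.float32, shape=[1, None, 80])
    warped = augment_sparse_warp(spectrogram,
                                 time_warping_para=80,
                                 interpolation_order=2,
                                 regularization_weight=0.0,
                                 num_boundary_points=1)

    with tfv1.Session() as session:
        batch = np.random.rand(1, 500, 80).astype(np.float32)
        print(session.run(warped, feed_dict={spectrogram: batch}).shape)  # (1, 500, 80)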
@@ -29,6 +29,12 @@ def create_flags():

     f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')

+    f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
+    f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time warp parameter W for sparse warp (maximum time displacement of the warped control point)')
+    f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'order of the polyharmonic spline used to interpolate the sparse warp')
+    f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'regularization weight of the sparse warp interpolation spline')
+    f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'number of zero-flow boundary points to pin at each spectrogram edge during sparse warp')
+
     f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
     f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
     f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')
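As a usage note (not part of the commit): assuming the f alias in create_flags() resolves to absl.flags, the new switches behave like any other command-line flag. A self-contained sketch of that mechanism, with a hypothetical entry point:

    # Minimal absl.flags sketch; flag names mirror the ones defined above,
    # the train.py entry point is hypothetical.
    from absl import app, flags

    f = flags
    f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
    f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time warp parameter W')
    FLAGS = flags.FLAGS

    def main(_):
        # e.g.  python train.py --augmentation_sparse_warp --augmentation_sparse_warp_time_warping_para=40
        if FLAGS.augmentation_sparse_warp:
            print('sparse warp enabled, W =', FLAGS.augmentation_sparse_warp_time_warping_para)

    if __name__ == '__main__':
        app.run(main)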
@@ -17,10 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import numpy as np
-
 import tensorflow as tf
-from tensorflow.compat import v1 as tfv1
+import tensorflow.compat.v1 as tfv1
+from tensorflow.compat import dimension_value
 from tensorflow.contrib.image.python.ops import dense_image_warp
 from tensorflow.contrib.image.python.ops import interpolate_spline

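A short aside (not taken from the commit) on why the dynamic-shape utilities matter here: the numpy-based helpers read shapes through get_shape(), which reports None for any dimension only known at run time, while tf.shape() always yields a usable int32 tensor.

    # Sketch contrasting static and dynamic shape access in TF 1.x graph mode.
    import numpy as np
    import tensorflow.compat.v1 as tfv1

    x = tfv1.placeholder(tfv1.float32, shape=[1, None, 80, 1])
    print(x.get_shape().as_list())   # [1, None, 80, 1] -- time dimension unknown at graph time
    time_steps = tfv1.shape(x)[1]    # int32 scalar tensor, resolved when the graph runs

    with tfv1.Session() as session:
        print(session.run(time_steps, feed_dict={x: np.zeros((1, 123, 80, 1), np.float32)}))  # 123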
@@ -29,31 +28,43 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops


+def _to_float32(value):
+  return tf.cast(value, tf.float32)
+
+
+def _to_int32(value):
+  return tf.cast(value, tf.int32)
+
+
 def _get_grid_locations(image_height, image_width):
   """Wrapper for np.meshgrid."""
+  tfv1.assert_type(image_height, tf.int32)
+  tfv1.assert_type(image_width, tf.int32)
+
-  y_range = np.linspace(0, image_height - 1, image_height)
-  x_range = np.linspace(0, image_width - 1, image_width)
-  y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij')
-  return np.stack((y_grid, x_grid), -1)
+  y_range = tf.range(image_height)
+  x_range = tf.range(image_width)
+  y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij')
+  return tf.stack((y_grid, x_grid), -1)


-def _expand_to_minibatch(np_array, batch_size):
+def _expand_to_minibatch(tensor, batch_size):
   """Tile arbitrarily-sized np_array to include new batch dimension."""
-  tiles = [batch_size] + [1] * np_array.ndim
-  return np.tile(np.expand_dims(np_array, 0), tiles)
+  ndim = tf.size(tf.shape(tensor))
+  ones = tf.ones((ndim,), tf.int32)
+
+  tiles = tf.concat(([batch_size], ones), 0)
+  return tf.tile(tf.expand_dims(tensor, 0), tiles)

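A quick standalone check (not part of the commit) that the tensor-based tiling above reproduces the numpy behaviour it replaces:

    # Re-implements the same ops as the new _expand_to_minibatch for comparison purposes.
    import numpy as np
    import tensorflow.compat.v1 as tfv1

    def expand_to_minibatch_tf(tensor, batch_size):
        ndim = tfv1.size(tfv1.shape(tensor))
        ones = tfv1.ones((ndim,), tfv1.int32)
        tiles = tfv1.concat(([batch_size], ones), 0)
        return tfv1.tile(tfv1.expand_dims(tensor, 0), tiles)

    points = np.arange(6, dtype=np.float32).reshape(3, 2)
    with tfv1.Session() as session:
        tf_result = session.run(expand_to_minibatch_tf(tfv1.constant(points), 4))
    np_result = np.tile(np.expand_dims(points, 0), [4, 1, 1])
    assert np.array_equal(tf_result, np_result)   # both are (4, 3, 2)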
 def _get_boundary_locations(image_height, image_width, num_points_per_edge):
   """Compute evenly-spaced indices along edge of image."""
-  y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2)
-  x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2)
-  ys, xs = np.meshgrid(y_range, x_range, indexing='ij')
-  is_boundary = np.logical_or(
-      np.logical_or(xs == 0, xs == image_width - 1),
-      np.logical_or(ys == 0, ys == image_height - 1))
-  return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1)
+  image_height_end = _to_float32(tf.math.subtract(image_height, 1))
+  image_width_end = _to_float32(tf.math.subtract(image_width, 1))
+  y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
+  x_range = tf.linspace(0.0, image_width_end, num_points_per_edge + 2)
+  ys, xs = tf.meshgrid(y_range, x_range, indexing='ij')
+  is_boundary = tf.logical_or(
+      tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)),
+      tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end)))
+  return tf.stack([tf.boolean_mask(ys, is_boundary), tf.boolean_mask(xs, is_boundary)], axis=-1)


 def _add_zero_flow_controls_at_boundary(control_point_locations,
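Similarly, a standalone check (not from the commit) that the tensor-based boundary selection picks the same points, in the same row-major order, as the numpy original, here for a 5x4 image with one extra point per edge:

    import numpy as np
    import tensorflow.compat.v1 as tfv1

    height, width, points_per_edge = 5, 4, 1

    # numpy version (the code being removed above)
    y = np.linspace(0, height - 1, points_per_edge + 2)
    x = np.linspace(0, width - 1, points_per_edge + 2)
    ys, xs = np.meshgrid(y, x, indexing='ij')
    keep = np.logical_or(np.logical_or(xs == 0, xs == width - 1),
                         np.logical_or(ys == 0, ys == height - 1))
    np_points = np.stack([ys[keep], xs[keep]], axis=-1)

    # tensor version (the code being added above)
    y_t = tfv1.linspace(0.0, float(height - 1), points_per_edge + 2)
    x_t = tfv1.linspace(0.0, float(width - 1), points_per_edge + 2)
    ys_t, xs_t = tfv1.meshgrid(y_t, x_t, indexing='ij')
    keep_t = tfv1.logical_or(
        tfv1.logical_or(tfv1.equal(xs_t, 0.0), tfv1.equal(xs_t, float(width - 1))),
        tfv1.logical_or(tfv1.equal(ys_t, 0.0), tfv1.equal(ys_t, float(height - 1))))
    tf_points = tfv1.stack([tfv1.boolean_mask(ys_t, keep_t),
                            tfv1.boolean_mask(xs_t, keep_t)], axis=-1)

    with tfv1.Session() as session:
        assert np.allclose(session.run(tf_points), np_points)   # 8 boundary points each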
@@ -79,25 +90,25 @@ def _add_zero_flow_controls_at_boundary(control_point_locations,
     merged_control_point_flows: augmented set of control point flows
   """

-  batch_size = tensor_shape.dimension_value(control_point_locations.shape[0])
+  batch_size = dimension_value(tf.shape(control_point_locations)[0])

   boundary_point_locations = _get_boundary_locations(image_height, image_width,
                                                      boundary_points_per_edge)
+  boundary_point_shape = tf.shape(boundary_point_locations)
+  boundary_point_flows = tf.zeros([boundary_point_shape[0], 2])

-  boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2])
-
+  minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size)
   type_to_use = control_point_locations.dtype
-  boundary_point_locations = constant_op.constant(
-      _expand_to_minibatch(boundary_point_locations, batch_size),
-      dtype=type_to_use)
+  boundary_point_locations = tf.cast(minbatch_locations, type_to_use)

-  boundary_point_flows = constant_op.constant(
-      _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use)
+  minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size)
+  boundary_point_flows = tf.cast(minbatch_flows, type_to_use)

-  merged_control_point_locations = array_ops.concat(
+  merged_control_point_locations = tf.concat(
       [control_point_locations, boundary_point_locations], 1)

-  merged_control_point_flows = array_ops.concat(
+  merged_control_point_flows = tf.concat(
       [control_point_flows, boundary_point_flows], 1)

   return merged_control_point_locations, merged_control_point_flows
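One line of reasoning behind this hunk, shown as a sketch rather than taken from the commit: constant_op.constant needs a concrete numpy or python value at graph-construction time, so it cannot wrap the output of the now tensor-valued _expand_to_minibatch; tf.cast, by contrast, just changes the dtype of whatever tensor flows through.

    import numpy as np
    import tensorflow.compat.v1 as tfv1

    batch_size = tfv1.placeholder(tfv1.int32, shape=[])          # dynamic, like the new code path
    locations = tfv1.tile(tfv1.expand_dims(tfv1.range(4), 0), [batch_size, 1])

    # tfv1.constant(locations, dtype=tfv1.float32) would fail here, because `locations`
    # has no concrete value yet; casting keeps everything symbolic instead.
    locations = tfv1.cast(locations, tfv1.float32)

    with tfv1.Session() as session:
        print(session.run(locations, feed_dict={batch_size: 3}).shape)   # (3, 4)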
@@ -173,17 +184,20 @@ def sparse_image_warp(image,
     boundary_points_per_edge = num_boundary_points - 1

   with ops.name_scope(name):
-    batch_size, image_height, image_width, _ = image.get_shape().as_list()
+    image_shape = tf.shape(image)
+    batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2]

     # This generates the dense locations where the interpolant
     # will be evaluated.
     grid_locations = _get_grid_locations(image_height, image_width)

-    flattened_grid_locations = np.reshape(grid_locations,
-                                          [image_height * image_width, 2])
+    flattened_grid_locations = tf.reshape(grid_locations,
+                                          [tf.multiply(image_height, image_width), 2])

-    flattened_grid_locations = constant_op.constant(
-        _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
+    flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size)
+    flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype)

     if clamp_boundaries:
       (dest_control_point_locations,
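The intended call pattern after this change, as a hedged sketch: the input image can have unknown height and width at graph-construction time. Whether this runs end to end also depends on the downstream contrib interpolate_spline and dense_image_warp ops tolerating dynamic shapes, so treat it as an illustration of the interface, not a guarantee.

    # Sketch only; import path is this repository's patched module.
    import numpy as np
    import tensorflow.compat.v1 as tfv1
    from util.sparse_image_warp import sparse_image_warp

    image = tfv1.placeholder(tfv1.float32, shape=[1, None, None, 1])     # height and width unknown
    source_points = tfv1.constant([[[10.0, 10.0], [30.0, 20.0]]])        # (1, num_points, 2) as (y, x)
    dest_points = source_points + [[[5.0, 0.0], [0.0, 5.0]]]             # move each point slightly

    warped, dense_flows = sparse_image_warp(image,
                                            source_points,
                                            dest_points,
                                            interpolation_order=2,
                                            regularization_weight=0.0,
                                            num_boundary_points=1)
    with tfv1.Session() as session:
        out = session.run(warped, feed_dict={image: np.random.rand(1, 64, 48, 1)})
        print(out.shape)   # (1, 64, 48, 1)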
@@ -1,4 +1,5 @@
 import tensorflow as tf
+import tensorflow.compat.v1 as tfv1
 from util.sparse_image_warp import sparse_image_warp


 def augment_freq_time_mask(mel_spectrogram,
@@ -83,54 +84,42 @@ def augment_dropout(spectrogram,
     return tf.nn.dropout(spectrogram, rate=1-keep_prob)


-def augment_sparse_warp(spectrogram: tf.Tensor, time_warping_para=80):
-    """Spec augmentation Calculation Function.
-    'SpecAugment' have 3 steps for audio data augmentation.
-    first step is time warping using Tensorflow's image_sparse_warp function.
-    Second step is frequency masking, last step is time masking.
-    # Arguments:
-      mel_spectrogram(numpy array): audio file path of you want to warping and masking.
-      time_warping_para(float): Augmentation parameter, "time warp parameter W".
-        If none, default = 80 for LibriSpeech.
-    # Returns
-      mel_spectrogram(numpy array): warped and masked mel spectrogram.
-    """
-    if spectrogram.get_shape().ndims == 2:
-        spectrogram = tf.reshape(spectrogram, shape=[1, -1, spectrogram.shape[1], 1])
-    elif spectrogram.get_shape().ndims == 3:
-        spectrogram = tf.reshape(spectrogram, shape=[spectrogram.shape[0], -1, spectrogram.shape[2], 1])
-    assert spectrogram.get_shape().ndims == 4
-    fbank_size = tf.shape(spectrogram)
-    n, v = fbank_size[1], fbank_size[2]
+def augment_sparse_warp(spectrogram: tf.Tensor, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):

-    # Step 1 : Time warping
-    # Image warping control point setting.
-    # Source
-    # radnom point along the time axis
-    pt = tf.random.uniform([], time_warping_para, n -
-                           time_warping_para, tf.int32)
-    src_ctr_pt_freq = tf.range(tf.floordiv(v, 2))  # control points on freq-axis
-    # control points on time-axis
-    src_ctr_pt_time = tf.ones_like(src_ctr_pt_freq) * pt
-    src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
-    src_ctr_pts = tf.cast(src_ctr_pts, dtype=tf.float32)
+    if spectrogram.get_shape().ndims == 3:
+        spectrogram = tf.expand_dims(spectrogram, -1)
+    elif spectrogram.get_shape().ndims == 2:
+        spectrogram = tf.expand_dims(tf.expand_dims(spectrogram, 0), -1)
+    # spectrogram shape: (1, time steps, freq, 1)

-    # Destination
-    w = tf.random.uniform([], -time_warping_para,
-                          time_warping_para, tf.int32)  # distance
-    dest_ctr_pt_freq = src_ctr_pt_freq
-    dest_ctr_pt_time = src_ctr_pt_time + w
-    dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
-    dest_ctr_pts = tf.cast(dest_ctr_pts, dtype=tf.float32)
+    spec_shape = tf.shape(spectrogram)
+    tau, freq_size = spec_shape[1], spec_shape[2]  # batch_size must be 1
+
+    mid_tau = tf.math.floordiv(tau, 2)
+    mid_freq = tf.math.floordiv(freq_size, 2)
+    left_mid_point = [0, mid_freq]
+    right_mid_point = [tau, mid_freq]
+
+    time_warping_para = tf.math.minimum(time_warping_para, tf.math.floordiv(tau, 2))  # to protect short audio
+    random_dest_time_point = tfv1.random_uniform(
+        [], time_warping_para, tau - time_warping_para, tf.int32)

-    # warp
-    source_control_point_locations = tf.expand_dims(
-        src_ctr_pts, 0)  # (1, v//2, 2)
-    dest_control_point_locations = tf.expand_dims(
-        dest_ctr_pts, 0)  # (1, v//2, 2)
+    source_control_point_locations = tf.cast([[
+        left_mid_point,
+        [mid_tau, mid_freq],
+        right_mid_point
+    ]], tf.float32)
+    dest_control_point_locations = tf.cast([[
+        left_mid_point,
+        [random_dest_time_point, mid_freq],
+        right_mid_point
+    ]], tf.float32)

     warped_image, _ = sparse_image_warp(spectrogram,
-                                        source_control_point_locations,
-                                        dest_control_point_locations)
-    return warped_image
+                                        source_control_point_locations=source_control_point_locations,
+                                        dest_control_point_locations=dest_control_point_locations,
+                                        interpolation_order=interpolation_order,
+                                        regularization_weight=regularization_weight,
+                                        num_boundary_points=num_boundary_points)
+    return tf.reshape(warped_image, shape=(1, -1, freq_size))
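To make the new warping geometry concrete, a small pure-python illustration (numbers are made up, not from the commit): for a spectrogram with 500 time steps and 80 frequency bins, only the centre control point moves along the time axis, while the two mid-frequency edge points stay pinned (and num_boundary_points additionally pins the image border inside sparse_image_warp).

    tau, freq_size = 500, 80                 # example spectrogram: 500 time steps, 80 bins
    mid_tau, mid_freq = tau // 2, freq_size // 2
    time_warping_para = min(80, tau // 2)    # the "protect short audio" clamp: 80 here

    # at run time the destination is drawn uniformly from [W, tau - W); pick one draw here
    random_dest_time_point = 300

    source = [[0, mid_freq], [mid_tau, mid_freq], [tau, mid_freq]]
    dest = [[0, mid_freq], [random_dest_time_point, mid_freq], [tau, mid_freq]]
    print(source)   # [[0, 40], [250, 40], [500, 40]]
    print(dest)     # [[0, 40], [300, 40], [500, 40]]

Everything between the pinned points is then smoothly stretched or compressed by the spline interpolation inside sparse_image_warp.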