sparse image warp to dynamic tensor

This commit is contained in:
Yi-Hua Chiu 2019-12-03 15:19:17 +08:00
parent 271a58e464
commit 368f0d413a
4 changed files with 94 additions and 77 deletions

View File

@ -42,6 +42,14 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
spectrogram = augment_dropout(spectrogram,
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
# sparse warp must be applied before freq/time masking
if FLAGS.augmentation_sparse_warp:
spectrogram = augment_sparse_warp(spectrogram,
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points)
if FLAGS.augmentation_freq_and_time_masking:
spectrogram = augment_freq_time_mask(spectrogram,
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,

View File

@ -29,6 +29,12 @@ def create_flags():
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp')
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 80, 'time_warping_para')
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points')
f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')

View File

@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.compat import v1 as tfv1
import tensorflow.compat.v1 as tfv1
from tensorflow.compat import dimension_value
from tensorflow.contrib.image.python.ops import dense_image_warp
from tensorflow.contrib.image.python.ops import interpolate_spline
@ -29,31 +28,43 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
def _to_float32(value):
    """Cast `value` to a `tf.float32` tensor."""
    as_float = tf.cast(value, dtype=tf.float32)
    return as_float
def _to_int32(value):
    """Cast `value` to a `tf.int32` tensor."""
    as_int = tf.cast(value, dtype=tf.int32)
    return as_int
def _get_grid_locations(image_height, image_width):
    """Build the dense grid of (y, x) pixel indices for an image.

    Implemented with TensorFlow ops only, so the height and width may be
    dynamic (graph-time unknown) scalars rather than Python ints.

    Args:
        image_height: int32 scalar, number of rows in the image.
        image_width: int32 scalar, number of columns in the image.

    Returns:
        An int32 tensor of shape `[image_height, image_width, 2]` whose last
        axis holds the `(y, x)` index of each grid position.
    """
    # Both dimensions are statically checked to be int32 before use.
    tfv1.assert_type(image_height, tf.int32)
    tfv1.assert_type(image_width, tf.int32)

    y_range = tf.range(image_height)
    x_range = tf.range(image_width)
    # 'ij' indexing keeps the row index on the first axis of both grids.
    y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij')
    return tf.stack((y_grid, x_grid), -1)
def _expand_to_minibatch(tensor, batch_size):
    """Tile an arbitrarily-shaped tensor along a new leading batch dimension.

    Args:
        tensor: tensor of any rank to replicate.
        batch_size: number of copies to stack along the new first axis; may be
            a Python int or a scalar tensor.

    Returns:
        A tensor whose first dimension is `batch_size` and whose remaining
        dimensions are those of `tensor`.
    """
    # The rank is computed at graph run time so tensors with dynamic shapes
    # are supported.
    ndim = tf.size(tf.shape(tensor))
    # Repeat `batch_size` times along the new axis, once along every other.
    ones = tf.ones((ndim,), tf.int32)
    tiles = tf.concat(([batch_size], ones), 0)
    return tf.tile(tf.expand_dims(tensor, 0), tiles)
def _get_boundary_locations(image_height, image_width, num_points_per_edge):
    """Compute evenly-spaced (y, x) locations along the edges of the image.

    Args:
        image_height: scalar number of rows in the image (may be a tensor).
        image_width: scalar number of columns in the image (may be a tensor).
        num_points_per_edge: number of interior points per edge; the two
            corner points of each edge are added on top of this.

    Returns:
        A float32 tensor of shape `[num_boundary_points, 2]` holding the
        `(y, x)` coordinates of every sampled point that lies on an edge.
    """
    image_height_end = _to_float32(tf.math.subtract(image_height, 1))
    image_width_end = _to_float32(tf.math.subtract(image_width, 1))

    y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
    # The x axis must span the image *width*; using the height here would
    # misplace the right-hand edge on non-square images.
    x_range = tf.linspace(0.0, image_width_end, num_points_per_edge + 2)
    ys, xs = tf.meshgrid(y_range, x_range, indexing='ij')

    # Keep only the grid points lying on one of the four edges.
    is_boundary = tf.logical_or(
        tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)),
        tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end)))
    boundary_ys = tf.boolean_mask(ys, is_boundary)
    boundary_xs = tf.boolean_mask(xs, is_boundary)
    return tf.stack([boundary_ys, boundary_xs], axis=-1)
def _add_zero_flow_controls_at_boundary(control_point_locations,
@ -79,25 +90,25 @@ def _add_zero_flow_controls_at_boundary(control_point_locations,
merged_control_point_flows: augmented set of control point flows
"""
batch_size = tensor_shape.dimension_value(control_point_locations.shape[0])
batch_size = dimension_value(tf.shape(control_point_locations)[0])
boundary_point_locations = _get_boundary_locations(image_height, image_width,
boundary_points_per_edge)
boundary_point_shape = tf.shape(boundary_point_locations)
boundary_point_flows = tf.zeros([boundary_point_shape[0], 2])
boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2])
minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size)
type_to_use = control_point_locations.dtype
boundary_point_locations = constant_op.constant(
_expand_to_minibatch(boundary_point_locations, batch_size),
dtype=type_to_use)
boundary_point_locations = tf.cast(minbatch_locations, type_to_use)
boundary_point_flows = constant_op.constant(
_expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use)
minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size)
merged_control_point_locations = array_ops.concat(
boundary_point_flows = tf.cast(minbatch_flows, type_to_use)
merged_control_point_locations = tf.concat(
[control_point_locations, boundary_point_locations], 1)
merged_control_point_flows = array_ops.concat(
merged_control_point_flows = tf.concat(
[control_point_flows, boundary_point_flows], 1)
return merged_control_point_locations, merged_control_point_flows
@ -173,17 +184,20 @@ def sparse_image_warp(image,
boundary_points_per_edge = num_boundary_points - 1
with ops.name_scope(name):
batch_size, image_height, image_width, _ = image.get_shape().as_list()
image_shape = tf.shape(image)
batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2]
# This generates the dense locations where the interpolant
# will be evaluated.
grid_locations = _get_grid_locations(image_height, image_width)
flattened_grid_locations = np.reshape(grid_locations,
[image_height * image_width, 2])
flattened_grid_locations = tf.reshape(grid_locations,
[tf.multiply(image_height, image_width), 2])
flattened_grid_locations = constant_op.constant(
_expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
# flattened_grid_locations = constant_op.constant(
# _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size)
flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype)
if clamp_boundaries:
(dest_control_point_locations,

View File

@ -1,4 +1,5 @@
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from util.sparse_image_warp import sparse_image_warp
def augment_freq_time_mask(mel_spectrogram,
@ -83,54 +84,42 @@ def augment_dropout(spectrogram,
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
def augment_sparse_warp(spectrogram: tf.Tensor, time_warping_para=80, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1):
    """Time-warp a spectrogram with `sparse_image_warp` (SpecAugment step 1).

    The centre point of the spectrogram (middle time step, middle frequency
    bin) is moved to a random position along the time axis, while the two
    control points at the ends of the middle-frequency line stay where they
    are. This must run before frequency/time masking.

    Args:
        spectrogram: 2-D or 3-D tensor. It is expanded to the 4-D layout
            `(1, time steps, freq, 1)` used for warping; the batch size must
            be 1.
        time_warping_para: SpecAugment "time warp parameter W". The warped
            centre point is drawn at least this many steps away from either
            end of the time axis; the value is reduced for short inputs.
        interpolation_order: forwarded to `sparse_image_warp`.
        regularization_weight: forwarded to `sparse_image_warp`.
        num_boundary_points: forwarded to `sparse_image_warp`.

    Returns:
        The warped spectrogram reshaped to `(1, time steps, freq)`.
    """
    # Bring the input to the 4-D image layout: add a channel axis, and a
    # batch axis as well when only (time steps, freq) was given.
    ndims = spectrogram.get_shape().ndims
    if ndims == 3:
        spectrogram = tf.expand_dims(spectrogram, -1)
    elif ndims == 2:
        spectrogram = tf.expand_dims(tf.expand_dims(spectrogram, 0), -1)
    # spectrogram shape: (1, time steps, freq, 1)

    spec_shape = tf.shape(spectrogram)
    tau, freq_size = spec_shape[1], spec_shape[2]  # batch_size must be 1

    mid_tau = tf.math.floordiv(tau, 2)
    mid_freq = tf.math.floordiv(freq_size, 2)
    left_mid_point = [0, mid_freq]
    right_mid_point = [tau, mid_freq]

    # Protect short audio: the integer sampler below needs a non-empty range
    # [W, tau - W), so W is capped strictly below tau / 2 and never negative.
    # The caller-supplied value is respected whenever the audio is long enough.
    max_warp = tf.math.maximum(tf.math.subtract(mid_tau, 1), 0)
    time_warping_para = tf.math.minimum(time_warping_para, max_warp)

    # Random point along the time axis that the centre point is warped to.
    random_dest_time_point = tfv1.random_uniform(
        [], time_warping_para, tau - time_warping_para, tf.int32)

    # Control points are (time, freq) pairs with a leading batch axis of 1.
    source_control_point_locations = tf.cast([[
        left_mid_point,
        [mid_tau, mid_freq],
        right_mid_point
    ]], tf.float32)
    dest_control_point_locations = tf.cast([[
        left_mid_point,
        [random_dest_time_point, mid_freq],
        right_mid_point
    ]], tf.float32)

    warped_image, _ = sparse_image_warp(
        spectrogram,
        source_control_point_locations=source_control_point_locations,
        dest_control_point_locations=dest_control_point_locations,
        interpolation_order=interpolation_order,
        regularization_weight=regularization_weight,
        num_boundary_points=num_boundary_points)

    # Drop the channel axis again.
    return tf.reshape(warped_image, shape=(1, -1, freq_size))