Merge pull request #2560 from mychiux413/sparse_warp
[SpecAugment] Refactor sparse_image_warp for dynamic shape of spectrogram
This commit is contained in:
commit
e1d14eb9a9
@ -22,6 +22,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from tensorflow.python.tools import freeze_graph, strip_unused_lib
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from util.config import Config, initialize_globals
|
||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from util.flags import create_flags, FLAGS
|
||||
@ -428,7 +429,8 @@ def train():
|
||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||
FLAGS.augmentation_freq_and_time_masking or
|
||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||
FLAGS.augmentation_speed_up_std > 0):
|
||||
FLAGS.augmentation_speed_up_std > 0 or
|
||||
FLAGS.augmentation_sparse_warp):
|
||||
do_cache_dataset = False
|
||||
|
||||
# Create training and validation datasets
|
||||
@ -598,6 +600,10 @@ def train():
|
||||
_, current_step, batch_loss, problem_files, step_summary = \
|
||||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
except tf.errors.InvalidArgumentError as err:
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
log_info("skip sparse warp error: {}".format(err))
|
||||
continue
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
|
@ -14,7 +14,7 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio
|
||||
from util.config import Config
|
||||
from util.text import text_to_char_array
|
||||
from util.flags import FLAGS
|
||||
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
|
||||
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||
from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT
|
||||
|
||||
|
||||
@ -42,6 +42,15 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
|
||||
spectrogram = augment_dropout(spectrogram,
|
||||
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
|
||||
|
||||
# sparse warp must before freq/time masking
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
spectrogram = augment_sparse_warp(spectrogram,
|
||||
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
|
||||
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
|
||||
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
|
||||
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points,
|
||||
num_control_points=FLAGS.augmentation_sparse_warp_num_control_points)
|
||||
|
||||
if FLAGS.augmentation_freq_and_time_masking:
|
||||
spectrogram = augment_freq_time_mask(spectrogram,
|
||||
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,
|
||||
|
@ -29,6 +29,13 @@ def create_flags():
|
||||
|
||||
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
|
||||
|
||||
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp. USE OF THIS FLAG IS UNSUPPORTED, enable sparse warp will increase training time drastically, and the paper also mentioned that this is not a major factor to improve accuracy.')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
|
||||
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
|
||||
f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points')
|
||||
|
||||
f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
|
||||
f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
|
||||
f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')
|
||||
|
220
util/sparse_image_warp.py
Normal file
220
util/sparse_image_warp.py
Normal file
@ -0,0 +1,220 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Image warping using sparse flow defined at control points."""
|
||||
|
||||
# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py
|
||||
# But refactored for dynamic tensor shape compatibility
|
||||
# The core idea is to replace every numpy implementation with tensorflow implementation
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from tensorflow.compat import dimension_value
|
||||
from tensorflow.contrib.image.python.ops import dense_image_warp
|
||||
from tensorflow.contrib.image.python.ops import interpolate_spline
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
def _to_float32(value):
|
||||
return tf.cast(value, tf.float32)
|
||||
|
||||
def _to_int32(value):
|
||||
return tf.cast(value, tf.int32)
|
||||
|
||||
def _get_grid_locations(image_height, image_width):
|
||||
"""Wrapper for np.meshgrid."""
|
||||
tfv1.assert_type(image_height, tf.int32)
|
||||
tfv1.assert_type(image_width, tf.int32)
|
||||
|
||||
y_range = tf.range(image_height)
|
||||
x_range = tf.range(image_width)
|
||||
y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij')
|
||||
return tf.stack((y_grid, x_grid), -1)
|
||||
|
||||
|
||||
def _expand_to_minibatch(tensor, batch_size):
|
||||
"""Tile arbitrarily-sized np_array to include new batch dimension."""
|
||||
ndim = tf.size(tf.shape(tensor))
|
||||
ones = tf.ones((ndim,), tf.int32)
|
||||
|
||||
tiles = tf.concat(([batch_size], ones), 0)
|
||||
return tf.tile(tf.expand_dims(tensor, 0), tiles)
|
||||
|
||||
|
||||
def _get_boundary_locations(image_height, image_width, num_points_per_edge):
|
||||
"""Compute evenly-spaced indices along edge of image."""
|
||||
image_height_end = _to_float32(tf.math.subtract(image_height, 1))
|
||||
image_width_end = _to_float32(tf.math.subtract(image_width, 1))
|
||||
y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
|
||||
x_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
|
||||
ys, xs = tf.meshgrid(y_range, x_range, indexing='ij')
|
||||
is_boundary = tf.logical_or(
|
||||
tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)),
|
||||
tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end)))
|
||||
return tf.stack([tf.boolean_mask(ys, is_boundary), tf.boolean_mask(xs, is_boundary)], axis=-1)
|
||||
|
||||
|
||||
def _add_zero_flow_controls_at_boundary(control_point_locations,
|
||||
control_point_flows, image_height,
|
||||
image_width, boundary_points_per_edge):
|
||||
"""Add control points for zero-flow boundary conditions.
|
||||
|
||||
Augment the set of control points with extra points on the
|
||||
boundary of the image that have zero flow.
|
||||
|
||||
Args:
|
||||
control_point_locations: input control points
|
||||
control_point_flows: their flows
|
||||
image_height: image height
|
||||
image_width: image width
|
||||
boundary_points_per_edge: number of points to add in the middle of each
|
||||
edge (not including the corners).
|
||||
The total number of points added is
|
||||
4 + 4*(boundary_points_per_edge).
|
||||
|
||||
Returns:
|
||||
merged_control_point_locations: augmented set of control point locations
|
||||
merged_control_point_flows: augmented set of control point flows
|
||||
"""
|
||||
|
||||
batch_size = dimension_value(tf.shape(control_point_locations)[0])
|
||||
|
||||
boundary_point_locations = _get_boundary_locations(image_height, image_width,
|
||||
boundary_points_per_edge)
|
||||
boundary_point_shape = tf.shape(boundary_point_locations)
|
||||
boundary_point_flows = tf.zeros([boundary_point_shape[0], 2])
|
||||
|
||||
minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size)
|
||||
type_to_use = control_point_locations.dtype
|
||||
boundary_point_locations = tf.cast(minbatch_locations, type_to_use)
|
||||
|
||||
minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size)
|
||||
|
||||
boundary_point_flows = tf.cast(minbatch_flows, type_to_use)
|
||||
|
||||
merged_control_point_locations = tf.concat(
|
||||
[control_point_locations, boundary_point_locations], 1)
|
||||
|
||||
merged_control_point_flows = tf.concat(
|
||||
[control_point_flows, boundary_point_flows], 1)
|
||||
|
||||
return merged_control_point_locations, merged_control_point_flows
|
||||
|
||||
|
||||
def sparse_image_warp(image,
|
||||
source_control_point_locations,
|
||||
dest_control_point_locations,
|
||||
interpolation_order=2,
|
||||
regularization_weight=0.0,
|
||||
num_boundary_points=0,
|
||||
name='sparse_image_warp'):
|
||||
"""Image warping using correspondences between sparse control points.
|
||||
|
||||
Apply a non-linear warp to the image, where the warp is specified by
|
||||
the source and destination locations of a (potentially small) number of
|
||||
control points. First, we use a polyharmonic spline
|
||||
(`tf.contrib.image.interpolate_spline`) to interpolate the displacements
|
||||
between the corresponding control points to a dense flow field.
|
||||
Then, we warp the image using this dense flow field
|
||||
(`tf.contrib.image.dense_image_warp`).
|
||||
|
||||
Let t index our control points. For regularization_weight=0, we have:
|
||||
warped_image[b, dest_control_point_locations[b, t, 0],
|
||||
dest_control_point_locations[b, t, 1], :] =
|
||||
image[b, source_control_point_locations[b, t, 0],
|
||||
source_control_point_locations[b, t, 1], :].
|
||||
|
||||
For regularization_weight > 0, this condition is met approximately, since
|
||||
regularized interpolation trades off smoothness of the interpolant vs.
|
||||
reconstruction of the interpolant at the control points.
|
||||
See `tf.contrib.image.interpolate_spline` for further documentation of the
|
||||
interpolation_order and regularization_weight arguments.
|
||||
|
||||
|
||||
Args:
|
||||
image: `[batch, height, width, channels]` float `Tensor`
|
||||
source_control_point_locations: `[batch, num_control_points, 2]` float
|
||||
`Tensor`
|
||||
dest_control_point_locations: `[batch, num_control_points, 2]` float
|
||||
`Tensor`
|
||||
interpolation_order: polynomial order used by the spline interpolation
|
||||
regularization_weight: weight on smoothness regularizer in interpolation
|
||||
num_boundary_points: How many zero-flow boundary points to include at
|
||||
each image edge.Usage:
|
||||
num_boundary_points=0: don't add zero-flow points
|
||||
num_boundary_points=1: 4 corners of the image
|
||||
num_boundary_points=2: 4 corners and one in the middle of each edge
|
||||
(8 points total)
|
||||
num_boundary_points=n: 4 corners and n-1 along each edge
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Note that image and offsets can be of type tf.half, tf.float32, or
|
||||
tf.float64, and do not necessarily have to be the same type.
|
||||
|
||||
Returns:
|
||||
warped_image: `[batch, height, width, channels]` float `Tensor` with same
|
||||
type as input image.
|
||||
flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
|
||||
flow field produced by the interpolation.
|
||||
"""
|
||||
|
||||
image = ops.convert_to_tensor(image)
|
||||
source_control_point_locations = ops.convert_to_tensor(
|
||||
source_control_point_locations)
|
||||
dest_control_point_locations = ops.convert_to_tensor(
|
||||
dest_control_point_locations)
|
||||
|
||||
control_point_flows = (
|
||||
dest_control_point_locations - source_control_point_locations)
|
||||
|
||||
clamp_boundaries = num_boundary_points > 0
|
||||
boundary_points_per_edge = num_boundary_points - 1
|
||||
|
||||
with ops.name_scope(name):
|
||||
image_shape = tf.shape(image)
|
||||
batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2]
|
||||
|
||||
# This generates the dense locations where the interpolant
|
||||
# will be evaluated.
|
||||
grid_locations = _get_grid_locations(image_height, image_width)
|
||||
|
||||
flattened_grid_locations = tf.reshape(grid_locations,
|
||||
[tf.multiply(image_height, image_width), 2])
|
||||
|
||||
# flattened_grid_locations = constant_op.constant(
|
||||
# _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
|
||||
flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size)
|
||||
flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype)
|
||||
|
||||
if clamp_boundaries:
|
||||
(dest_control_point_locations,
|
||||
control_point_flows) = _add_zero_flow_controls_at_boundary(
|
||||
dest_control_point_locations, control_point_flows, image_height,
|
||||
image_width, boundary_points_per_edge)
|
||||
|
||||
flattened_flows = interpolate_spline.interpolate_spline(
|
||||
dest_control_point_locations, control_point_flows,
|
||||
flattened_grid_locations, interpolation_order, regularization_weight)
|
||||
|
||||
dense_flows = array_ops.reshape(flattened_flows,
|
||||
[batch_size, image_height, image_width, 2])
|
||||
|
||||
warped_image = dense_image_warp.dense_image_warp(image, dense_flows)
|
||||
|
||||
return warped_image, dense_flows
|
@ -1,4 +1,6 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from util.sparse_image_warp import sparse_image_warp
|
||||
|
||||
def augment_freq_time_mask(mel_spectrogram,
|
||||
frequency_masking_para=30,
|
||||
@ -64,3 +66,61 @@ def augment_speed_up(spectrogram,
|
||||
def augment_dropout(spectrogram,
|
||||
keep_prob=0.95):
|
||||
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
|
||||
|
||||
|
||||
def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
|
||||
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
|
||||
Args:
|
||||
spectrogram: `[batch, time, frequency]` float `Tensor`
|
||||
time_warping_para: 'W' parameter in paper
|
||||
interpolation_order: used to put into `sparse_image_warp`
|
||||
regularization_weight: used to put into `sparse_image_warp`
|
||||
num_boundary_points: used to put into `sparse_image_warp`,
|
||||
default=1 means boundary points on 4 corners of the image
|
||||
num_control_points: number of control points
|
||||
Returns:
|
||||
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
|
||||
type as input image.
|
||||
"""
|
||||
# reshape to fit `sparse_image_warp`'s input shape
|
||||
# (1, time steps, freq, 1), batch_size must be 1
|
||||
spectrogram = tf.expand_dims(spectrogram, -1)
|
||||
|
||||
original_shape = tf.shape(spectrogram)
|
||||
tau, freq_size = original_shape[1], original_shape[2]
|
||||
|
||||
# to protect short audio
|
||||
time_warping_para = tf.math.minimum(
|
||||
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
|
||||
|
||||
# don't choose boundary frequency
|
||||
choosen_freqs = tf.random.shuffle(
|
||||
tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
|
||||
|
||||
source_max = tau - time_warping_para
|
||||
source_min = tf.math.minimum(source_max - num_control_points, time_warping_para)
|
||||
|
||||
choosen_times = tf.random.shuffle(tf.range(source_min, limit=source_max))[0: num_control_points]
|
||||
dest_time_widths = tfv1.random_uniform([num_control_points], tf.negative(time_warping_para), time_warping_para, tf.int32)
|
||||
|
||||
sources = []
|
||||
dests = []
|
||||
for i in range(num_control_points):
|
||||
# generate source points `t` of time axis between (W, tau-W)
|
||||
rand_source_time = choosen_times[i]
|
||||
rand_dest_time = rand_source_time + dest_time_widths[i]
|
||||
|
||||
choosen_freq = choosen_freqs[i]
|
||||
sources.append([rand_source_time, choosen_freq])
|
||||
dests.append([rand_dest_time, choosen_freq])
|
||||
|
||||
source_control_point_locations = tf.cast([sources], tf.float32)
|
||||
dest_control_point_locations = tf.cast([dests], tf.float32)
|
||||
|
||||
warped_spectrogram, _ = sparse_image_warp(spectrogram,
|
||||
source_control_point_locations=source_control_point_locations,
|
||||
dest_control_point_locations=dest_control_point_locations,
|
||||
interpolation_order=interpolation_order,
|
||||
regularization_weight=regularization_weight,
|
||||
num_boundary_points=num_boundary_points)
|
||||
return tf.reshape(warped_spectrogram, shape=(1, -1, freq_size))
|
||||
|
Loading…
Reference in New Issue
Block a user