Merge pull request #2560 from mychiux413/sparse_warp

[SpecAugment] Refactor sparse_image_warp for dynamic shape of spectrogram
This commit is contained in:
Reuben Morais 2020-01-03 13:25:31 +01:00 committed by GitHub
commit e1d14eb9a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 304 additions and 2 deletions

View File

@ -22,6 +22,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph, strip_unused_lib
from tensorflow.python.framework import errors_impl
from util.config import Config, initialize_globals
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
@ -428,7 +429,8 @@ def train():
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0):
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
# Create training and validation datasets
@ -598,6 +600,10 @@ def train():
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("skip sparse warp error: {}".format(err))
continue
except tf.errors.OutOfRangeError:
break

View File

@ -14,7 +14,7 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio
from util.config import Config
from util.text import text_to_char_array
from util.flags import FLAGS
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT
@ -42,6 +42,15 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
spectrogram = augment_dropout(spectrogram,
keep_prob=FLAGS.augmentation_spec_dropout_keeprate)
# sparse warp must before freq/time masking
if FLAGS.augmentation_sparse_warp:
spectrogram = augment_sparse_warp(spectrogram,
time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para,
interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order,
regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight,
num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points,
num_control_points=FLAGS.augmentation_sparse_warp_num_control_points)
if FLAGS.augmentation_freq_and_time_masking:
spectrogram = augment_freq_time_mask(spectrogram,
frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range,

View File

@ -29,6 +29,13 @@ def create_flags():
f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)')
f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp. USE OF THIS FLAG IS UNSUPPORTED, enable sparse warp will increase training time drastically, and the paper also mentioned that this is not a major factor to improve accuracy.')
f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points')
f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para')
f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order')
f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight')
f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points')
f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation')
f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation')

220
util/sparse_image_warp.py Normal file
View File

@ -0,0 +1,220 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image warping using sparse flow defined at control points."""
# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py
# But refactored for dynamic tensor shape compatibility
# The core idea is to replace every numpy implementation with tensorflow implementation
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from tensorflow.compat import dimension_value
from tensorflow.contrib.image.python.ops import dense_image_warp
from tensorflow.contrib.image.python.ops import interpolate_spline
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
def _to_float32(value):
return tf.cast(value, tf.float32)
def _to_int32(value):
return tf.cast(value, tf.int32)
def _get_grid_locations(image_height, image_width):
"""Wrapper for np.meshgrid."""
tfv1.assert_type(image_height, tf.int32)
tfv1.assert_type(image_width, tf.int32)
y_range = tf.range(image_height)
x_range = tf.range(image_width)
y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij')
return tf.stack((y_grid, x_grid), -1)
def _expand_to_minibatch(tensor, batch_size):
"""Tile arbitrarily-sized np_array to include new batch dimension."""
ndim = tf.size(tf.shape(tensor))
ones = tf.ones((ndim,), tf.int32)
tiles = tf.concat(([batch_size], ones), 0)
return tf.tile(tf.expand_dims(tensor, 0), tiles)
def _get_boundary_locations(image_height, image_width, num_points_per_edge):
"""Compute evenly-spaced indices along edge of image."""
image_height_end = _to_float32(tf.math.subtract(image_height, 1))
image_width_end = _to_float32(tf.math.subtract(image_width, 1))
y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
x_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
ys, xs = tf.meshgrid(y_range, x_range, indexing='ij')
is_boundary = tf.logical_or(
tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)),
tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end)))
return tf.stack([tf.boolean_mask(ys, is_boundary), tf.boolean_mask(xs, is_boundary)], axis=-1)
def _add_zero_flow_controls_at_boundary(control_point_locations,
control_point_flows, image_height,
image_width, boundary_points_per_edge):
"""Add control points for zero-flow boundary conditions.
Augment the set of control points with extra points on the
boundary of the image that have zero flow.
Args:
control_point_locations: input control points
control_point_flows: their flows
image_height: image height
image_width: image width
boundary_points_per_edge: number of points to add in the middle of each
edge (not including the corners).
The total number of points added is
4 + 4*(boundary_points_per_edge).
Returns:
merged_control_point_locations: augmented set of control point locations
merged_control_point_flows: augmented set of control point flows
"""
batch_size = dimension_value(tf.shape(control_point_locations)[0])
boundary_point_locations = _get_boundary_locations(image_height, image_width,
boundary_points_per_edge)
boundary_point_shape = tf.shape(boundary_point_locations)
boundary_point_flows = tf.zeros([boundary_point_shape[0], 2])
minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size)
type_to_use = control_point_locations.dtype
boundary_point_locations = tf.cast(minbatch_locations, type_to_use)
minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size)
boundary_point_flows = tf.cast(minbatch_flows, type_to_use)
merged_control_point_locations = tf.concat(
[control_point_locations, boundary_point_locations], 1)
merged_control_point_flows = tf.concat(
[control_point_flows, boundary_point_flows], 1)
return merged_control_point_locations, merged_control_point_flows
def sparse_image_warp(image,
source_control_point_locations,
dest_control_point_locations,
interpolation_order=2,
regularization_weight=0.0,
num_boundary_points=0,
name='sparse_image_warp'):
"""Image warping using correspondences between sparse control points.
Apply a non-linear warp to the image, where the warp is specified by
the source and destination locations of a (potentially small) number of
control points. First, we use a polyharmonic spline
(`tf.contrib.image.interpolate_spline`) to interpolate the displacements
between the corresponding control points to a dense flow field.
Then, we warp the image using this dense flow field
(`tf.contrib.image.dense_image_warp`).
Let t index our control points. For regularization_weight=0, we have:
warped_image[b, dest_control_point_locations[b, t, 0],
dest_control_point_locations[b, t, 1], :] =
image[b, source_control_point_locations[b, t, 0],
source_control_point_locations[b, t, 1], :].
For regularization_weight > 0, this condition is met approximately, since
regularized interpolation trades off smoothness of the interpolant vs.
reconstruction of the interpolant at the control points.
See `tf.contrib.image.interpolate_spline` for further documentation of the
interpolation_order and regularization_weight arguments.
Args:
image: `[batch, height, width, channels]` float `Tensor`
source_control_point_locations: `[batch, num_control_points, 2]` float
`Tensor`
dest_control_point_locations: `[batch, num_control_points, 2]` float
`Tensor`
interpolation_order: polynomial order used by the spline interpolation
regularization_weight: weight on smoothness regularizer in interpolation
num_boundary_points: How many zero-flow boundary points to include at
each image edge.Usage:
num_boundary_points=0: don't add zero-flow points
num_boundary_points=1: 4 corners of the image
num_boundary_points=2: 4 corners and one in the middle of each edge
(8 points total)
num_boundary_points=n: 4 corners and n-1 along each edge
name: A name for the operation (optional).
Note that image and offsets can be of type tf.half, tf.float32, or
tf.float64, and do not necessarily have to be the same type.
Returns:
warped_image: `[batch, height, width, channels]` float `Tensor` with same
type as input image.
flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
flow field produced by the interpolation.
"""
image = ops.convert_to_tensor(image)
source_control_point_locations = ops.convert_to_tensor(
source_control_point_locations)
dest_control_point_locations = ops.convert_to_tensor(
dest_control_point_locations)
control_point_flows = (
dest_control_point_locations - source_control_point_locations)
clamp_boundaries = num_boundary_points > 0
boundary_points_per_edge = num_boundary_points - 1
with ops.name_scope(name):
image_shape = tf.shape(image)
batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2]
# This generates the dense locations where the interpolant
# will be evaluated.
grid_locations = _get_grid_locations(image_height, image_width)
flattened_grid_locations = tf.reshape(grid_locations,
[tf.multiply(image_height, image_width), 2])
# flattened_grid_locations = constant_op.constant(
# _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size)
flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype)
if clamp_boundaries:
(dest_control_point_locations,
control_point_flows) = _add_zero_flow_controls_at_boundary(
dest_control_point_locations, control_point_flows, image_height,
image_width, boundary_points_per_edge)
flattened_flows = interpolate_spline.interpolate_spline(
dest_control_point_locations, control_point_flows,
flattened_grid_locations, interpolation_order, regularization_weight)
dense_flows = array_ops.reshape(flattened_flows,
[batch_size, image_height, image_width, 2])
warped_image = dense_image_warp.dense_image_warp(image, dense_flows)
return warped_image, dense_flows

View File

@ -1,4 +1,6 @@
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from util.sparse_image_warp import sparse_image_warp
def augment_freq_time_mask(mel_spectrogram,
frequency_masking_para=30,
@ -64,3 +66,61 @@ def augment_speed_up(spectrogram,
def augment_dropout(spectrogram,
keep_prob=0.95):
return tf.nn.dropout(spectrogram, rate=1-keep_prob)
def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1):
"""Reference: https://arxiv.org/pdf/1904.08779.pdf
Args:
spectrogram: `[batch, time, frequency]` float `Tensor`
time_warping_para: 'W' parameter in paper
interpolation_order: used to put into `sparse_image_warp`
regularization_weight: used to put into `sparse_image_warp`
num_boundary_points: used to put into `sparse_image_warp`,
default=1 means boundary points on 4 corners of the image
num_control_points: number of control points
Returns:
warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same
type as input image.
"""
# reshape to fit `sparse_image_warp`'s input shape
# (1, time steps, freq, 1), batch_size must be 1
spectrogram = tf.expand_dims(spectrogram, -1)
original_shape = tf.shape(spectrogram)
tau, freq_size = original_shape[1], original_shape[2]
# to protect short audio
time_warping_para = tf.math.minimum(
time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1))
# don't choose boundary frequency
choosen_freqs = tf.random.shuffle(
tf.add(tf.range(freq_size - 3), 1))[0: num_control_points]
source_max = tau - time_warping_para
source_min = tf.math.minimum(source_max - num_control_points, time_warping_para)
choosen_times = tf.random.shuffle(tf.range(source_min, limit=source_max))[0: num_control_points]
dest_time_widths = tfv1.random_uniform([num_control_points], tf.negative(time_warping_para), time_warping_para, tf.int32)
sources = []
dests = []
for i in range(num_control_points):
# generate source points `t` of time axis between (W, tau-W)
rand_source_time = choosen_times[i]
rand_dest_time = rand_source_time + dest_time_widths[i]
choosen_freq = choosen_freqs[i]
sources.append([rand_source_time, choosen_freq])
dests.append([rand_dest_time, choosen_freq])
source_control_point_locations = tf.cast([sources], tf.float32)
dest_control_point_locations = tf.cast([dests], tf.float32)
warped_spectrogram, _ = sparse_image_warp(spectrogram,
source_control_point_locations=source_control_point_locations,
dest_control_point_locations=dest_control_point_locations,
interpolation_order=interpolation_order,
regularization_weight=regularization_weight,
num_boundary_points=num_boundary_points)
return tf.reshape(warped_spectrogram, shape=(1, -1, freq_size))