diff --git a/DeepSpeech.py b/DeepSpeech.py index 3f620502..130f3b73 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -22,6 +22,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer from evaluate import evaluate from six.moves import zip, range from tensorflow.python.tools import freeze_graph, strip_unused_lib +from tensorflow.python.framework import errors_impl from util.config import Config, initialize_globals from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from util.flags import create_flags, FLAGS @@ -428,7 +429,8 @@ def train(): FLAGS.augmentation_spec_dropout_keeprate < 1 or FLAGS.augmentation_freq_and_time_masking or FLAGS.augmentation_pitch_and_tempo_scaling or - FLAGS.augmentation_speed_up_std > 0): + FLAGS.augmentation_speed_up_std > 0 or + FLAGS.augmentation_sparse_warp): do_cache_dataset = False # Create training and validation datasets @@ -598,6 +600,10 @@ def train(): _, current_step, batch_loss, problem_files, step_summary = \ session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], feed_dict=feed_dict) + except tf.errors.InvalidArgumentError as err: + if FLAGS.augmentation_sparse_warp: + log_info("skip sparse warp error: {}".format(err)) + continue except tf.errors.OutOfRangeError: break diff --git a/util/feeding.py b/util/feeding.py index 16d0e312..e772f0ed 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -14,7 +14,7 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio from util.config import Config from util.text import text_to_char_array from util.flags import FLAGS -from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up +from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT @@ -42,6 +42,15 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False): spectrogram = augment_dropout(spectrogram, keep_prob=FLAGS.augmentation_spec_dropout_keeprate) + # sparse warp must before freq/time masking + if FLAGS.augmentation_sparse_warp: + spectrogram = augment_sparse_warp(spectrogram, + time_warping_para=FLAGS.augmentation_sparse_warp_time_warping_para, + interpolation_order=FLAGS.augmentation_sparse_warp_interpolation_order, + regularization_weight=FLAGS.augmentation_sparse_warp_regularization_weight, + num_boundary_points=FLAGS.augmentation_sparse_warp_num_boundary_points, + num_control_points=FLAGS.augmentation_sparse_warp_num_control_points) + if FLAGS.augmentation_freq_and_time_masking: spectrogram = augment_freq_time_mask(spectrogram, frequency_masking_para=FLAGS.augmentation_freq_and_time_masking_freq_mask_range, diff --git a/util/flags.py b/util/flags.py index d8a2656c..624217e0 100644 --- a/util/flags.py +++ b/util/flags.py @@ -29,6 +29,13 @@ def create_flags(): f.DEFINE_float('augmentation_spec_dropout_keeprate', 1, 'keep rate of dropout augmentation on spectrogram (if 1, no dropout will be performed on spectrogram)') + f.DEFINE_boolean('augmentation_sparse_warp', False, 'whether to use spectrogram sparse warp. USE OF THIS FLAG IS UNSUPPORTED, enable sparse warp will increase training time drastically, and the paper also mentioned that this is not a major factor to improve accuracy.') + f.DEFINE_integer('augmentation_sparse_warp_num_control_points', 1, 'specify number of control points') + f.DEFINE_integer('augmentation_sparse_warp_time_warping_para', 20, 'time_warping_para') + f.DEFINE_integer('augmentation_sparse_warp_interpolation_order', 2, 'sparse_warp_interpolation_order') + f.DEFINE_float('augmentation_sparse_warp_regularization_weight', 0.0, 'sparse_warp_regularization_weight') + f.DEFINE_integer('augmentation_sparse_warp_num_boundary_points', 1, 'sparse_warp_num_boundary_points') + f.DEFINE_boolean('augmentation_freq_and_time_masking', False, 'whether to use frequency and time masking augmentation') f.DEFINE_integer('augmentation_freq_and_time_masking_freq_mask_range', 5, 'max range of masks in the frequency domain when performing freqtime-mask augmentation') f.DEFINE_integer('augmentation_freq_and_time_masking_number_freq_masks', 3, 'number of masks in the frequency domain when performing freqtime-mask augmentation') diff --git a/util/sparse_image_warp.py b/util/sparse_image_warp.py new file mode 100644 index 00000000..0fcdba0a --- /dev/null +++ b/util/sparse_image_warp.py @@ -0,0 +1,220 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using sparse flow defined at control points.""" + +# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py +# But refactored for dynamic tensor shape compatibility +# The core idea is to replace every numpy implementation with tensorflow implementation + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.compat.v1 as tfv1 +from tensorflow.compat import dimension_value +from tensorflow.contrib.image.python.ops import dense_image_warp +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + +def _to_float32(value): + return tf.cast(value, tf.float32) + +def _to_int32(value): + return tf.cast(value, tf.int32) + +def _get_grid_locations(image_height, image_width): + """Wrapper for np.meshgrid.""" + tfv1.assert_type(image_height, tf.int32) + tfv1.assert_type(image_width, tf.int32) + + y_range = tf.range(image_height) + x_range = tf.range(image_width) + y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij') + return tf.stack((y_grid, x_grid), -1) + + +def _expand_to_minibatch(tensor, batch_size): + """Tile arbitrarily-sized np_array to include new batch dimension.""" + ndim = tf.size(tf.shape(tensor)) + ones = tf.ones((ndim,), tf.int32) + + tiles = tf.concat(([batch_size], ones), 0) + return tf.tile(tf.expand_dims(tensor, 0), tiles) + + +def _get_boundary_locations(image_height, image_width, num_points_per_edge): + """Compute evenly-spaced indices along edge of image.""" + image_height_end = _to_float32(tf.math.subtract(image_height, 1)) + image_width_end = _to_float32(tf.math.subtract(image_width, 1)) + y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2) + x_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2) + ys, xs = tf.meshgrid(y_range, x_range, indexing='ij') + is_boundary = tf.logical_or( + tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)), + tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end))) + return tf.stack([tf.boolean_mask(ys, is_boundary), tf.boolean_mask(xs, is_boundary)], axis=-1) + + +def _add_zero_flow_controls_at_boundary(control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge): + """Add control points for zero-flow boundary conditions. + + Augment the set of control points with extra points on the + boundary of the image that have zero flow. + + Args: + control_point_locations: input control points + control_point_flows: their flows + image_height: image height + image_width: image width + boundary_points_per_edge: number of points to add in the middle of each + edge (not including the corners). + The total number of points added is + 4 + 4*(boundary_points_per_edge). + + Returns: + merged_control_point_locations: augmented set of control point locations + merged_control_point_flows: augmented set of control point flows + """ + + batch_size = dimension_value(tf.shape(control_point_locations)[0]) + + boundary_point_locations = _get_boundary_locations(image_height, image_width, + boundary_points_per_edge) + boundary_point_shape = tf.shape(boundary_point_locations) + boundary_point_flows = tf.zeros([boundary_point_shape[0], 2]) + + minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size) + type_to_use = control_point_locations.dtype + boundary_point_locations = tf.cast(minbatch_locations, type_to_use) + + minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size) + + boundary_point_flows = tf.cast(minbatch_flows, type_to_use) + + merged_control_point_locations = tf.concat( + [control_point_locations, boundary_point_locations], 1) + + merged_control_point_flows = tf.concat( + [control_point_flows, boundary_point_flows], 1) + + return merged_control_point_locations, merged_control_point_flows + + +def sparse_image_warp(image, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundary_points=0, + name='sparse_image_warp'): + """Image warping using correspondences between sparse control points. + + Apply a non-linear warp to the image, where the warp is specified by + the source and destination locations of a (potentially small) number of + control points. First, we use a polyharmonic spline + (`tf.contrib.image.interpolate_spline`) to interpolate the displacements + between the corresponding control points to a dense flow field. + Then, we warp the image using this dense flow field + (`tf.contrib.image.dense_image_warp`). + + Let t index our control points. For regularization_weight=0, we have: + warped_image[b, dest_control_point_locations[b, t, 0], + dest_control_point_locations[b, t, 1], :] = + image[b, source_control_point_locations[b, t, 0], + source_control_point_locations[b, t, 1], :]. + + For regularization_weight > 0, this condition is met approximately, since + regularized interpolation trades off smoothness of the interpolant vs. + reconstruction of the interpolant at the control points. + See `tf.contrib.image.interpolate_spline` for further documentation of the + interpolation_order and regularization_weight arguments. + + + Args: + image: `[batch, height, width, channels]` float `Tensor` + source_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + dest_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + interpolation_order: polynomial order used by the spline interpolation + regularization_weight: weight on smoothness regularizer in interpolation + num_boundary_points: How many zero-flow boundary points to include at + each image edge.Usage: + num_boundary_points=0: don't add zero-flow points + num_boundary_points=1: 4 corners of the image + num_boundary_points=2: 4 corners and one in the middle of each edge + (8 points total) + num_boundary_points=n: 4 corners and n-1 along each edge + name: A name for the operation (optional). + + Note that image and offsets can be of type tf.half, tf.float32, or + tf.float64, and do not necessarily have to be the same type. + + Returns: + warped_image: `[batch, height, width, channels]` float `Tensor` with same + type as input image. + flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense + flow field produced by the interpolation. + """ + + image = ops.convert_to_tensor(image) + source_control_point_locations = ops.convert_to_tensor( + source_control_point_locations) + dest_control_point_locations = ops.convert_to_tensor( + dest_control_point_locations) + + control_point_flows = ( + dest_control_point_locations - source_control_point_locations) + + clamp_boundaries = num_boundary_points > 0 + boundary_points_per_edge = num_boundary_points - 1 + + with ops.name_scope(name): + image_shape = tf.shape(image) + batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2] + + # This generates the dense locations where the interpolant + # will be evaluated. + grid_locations = _get_grid_locations(image_height, image_width) + + flattened_grid_locations = tf.reshape(grid_locations, + [tf.multiply(image_height, image_width), 2]) + + # flattened_grid_locations = constant_op.constant( + # _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) + flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size) + flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype) + + if clamp_boundaries: + (dest_control_point_locations, + control_point_flows) = _add_zero_flow_controls_at_boundary( + dest_control_point_locations, control_point_flows, image_height, + image_width, boundary_points_per_edge) + + flattened_flows = interpolate_spline.interpolate_spline( + dest_control_point_locations, control_point_flows, + flattened_grid_locations, interpolation_order, regularization_weight) + + dense_flows = array_ops.reshape(flattened_flows, + [batch_size, image_height, image_width, 2]) + + warped_image = dense_image_warp.dense_image_warp(image, dense_flows) + + return warped_image, dense_flows diff --git a/util/spectrogram_augmentations.py b/util/spectrogram_augmentations.py index 9cf36a24..58980e8c 100644 --- a/util/spectrogram_augmentations.py +++ b/util/spectrogram_augmentations.py @@ -1,4 +1,6 @@ import tensorflow as tf +import tensorflow.compat.v1 as tfv1 +from util.sparse_image_warp import sparse_image_warp def augment_freq_time_mask(mel_spectrogram, frequency_masking_para=30, @@ -64,3 +66,61 @@ def augment_speed_up(spectrogram, def augment_dropout(spectrogram, keep_prob=0.95): return tf.nn.dropout(spectrogram, rate=1-keep_prob) + + +def augment_sparse_warp(spectrogram, time_warping_para=20, interpolation_order=2, regularization_weight=0.0, num_boundary_points=1, num_control_points=1): + """Reference: https://arxiv.org/pdf/1904.08779.pdf + Args: + spectrogram: `[batch, time, frequency]` float `Tensor` + time_warping_para: 'W' parameter in paper + interpolation_order: used to put into `sparse_image_warp` + regularization_weight: used to put into `sparse_image_warp` + num_boundary_points: used to put into `sparse_image_warp`, + default=1 means boundary points on 4 corners of the image + num_control_points: number of control points + Returns: + warped_spectrogram: `[batch, time, frequency]` float `Tensor` with same + type as input image. + """ + # reshape to fit `sparse_image_warp`'s input shape + # (1, time steps, freq, 1), batch_size must be 1 + spectrogram = tf.expand_dims(spectrogram, -1) + + original_shape = tf.shape(spectrogram) + tau, freq_size = original_shape[1], original_shape[2] + + # to protect short audio + time_warping_para = tf.math.minimum( + time_warping_para, tf.math.subtract(tf.math.floordiv(tau, 2), 1)) + + # don't choose boundary frequency + choosen_freqs = tf.random.shuffle( + tf.add(tf.range(freq_size - 3), 1))[0: num_control_points] + + source_max = tau - time_warping_para + source_min = tf.math.minimum(source_max - num_control_points, time_warping_para) + + choosen_times = tf.random.shuffle(tf.range(source_min, limit=source_max))[0: num_control_points] + dest_time_widths = tfv1.random_uniform([num_control_points], tf.negative(time_warping_para), time_warping_para, tf.int32) + + sources = [] + dests = [] + for i in range(num_control_points): + # generate source points `t` of time axis between (W, tau-W) + rand_source_time = choosen_times[i] + rand_dest_time = rand_source_time + dest_time_widths[i] + + choosen_freq = choosen_freqs[i] + sources.append([rand_source_time, choosen_freq]) + dests.append([rand_dest_time, choosen_freq]) + + source_control_point_locations = tf.cast([sources], tf.float32) + dest_control_point_locations = tf.cast([dests], tf.float32) + + warped_spectrogram, _ = sparse_image_warp(spectrogram, + source_control_point_locations=source_control_point_locations, + dest_control_point_locations=dest_control_point_locations, + interpolation_order=interpolation_order, + regularization_weight=regularization_weight, + num_boundary_points=num_boundary_points) + return tf.reshape(warped_spectrogram, shape=(1, -1, freq_size))