Warp augmentation
This commit is contained in:
parent
6f2ba4b5b4
commit
eebf12134e
@ -20,6 +20,7 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||
--augment dropout \
|
||||
--augment pitch \
|
||||
--augment tempo \
|
||||
--augment warp \
|
||||
--augment time_mask \
|
||||
--augment frequency_mask \
|
||||
--augment add \
|
||||
|
@ -399,6 +399,20 @@ Spectrogram domain augmentations
|
||||
* **factor**: speed factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
|
||||
|
||||
|
||||
**Warp augmentation** ``--augment warp[p=<float>,nt=<int-range>,nf=<int-range>,wt=<float-range>,wf=<float-range>]``
|
||||
Applies a non-linear image warp to the spectrogram. This is achieved by randomly shifting a grid of equally distributed warp points along time and frequency axis.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **nt**: number of equally distributed warp grid lines along time axis of the spectrogram (excluding the edges)
|
||||
|
||||
* **nf**: number of equally distributed warp grid lines along frequency axis of the spectrogram (excluding the edges)
|
||||
|
||||
* **wt**: standard deviation of the random shift applied to warp points along time axis (0.0 = no warp, 1.0 = half the distance to the neighbour point)
|
||||
|
||||
* **wf**: standard deviation of the random shift applied to warp points along frequency axis (0.0 = no warp, 1.0 = half the distance to the neighbour point)
|
||||
|
||||
|
||||
**Frequency mask augmentation** ``--augment frequency_mask[p=<float>,n=<int-range>,size=<int-range>]``
|
||||
Sets frequency-intervals within the augmented samples to zero (silence) at random frequencies. See the SpecAugment paper for more details - https://arxiv.org/abs/1904.08779
|
||||
|
||||
@ -469,6 +483,7 @@ Example training with all augmentations:
|
||||
--augment volume[p=0.1,dbfs=-10:-40] \
|
||||
--augment pitch[p=0.1,pitch=1~0.2] \
|
||||
--augment tempo[p=0.1,factor=1~0.5] \
|
||||
--augment warp[p=0.1,nt=4,nf=1,wt=0.5:1.0,wf=0.1:0.2] \
|
||||
--augment frequency_mask[p=0.1,n=1:3,size=1:5] \
|
||||
--augment time_mask[p=0.1,domain=signal,n=3:10~2,size=50:100~40] \
|
||||
--augment dropout[p=0.1,rate=0.05] \
|
||||
|
@ -412,6 +412,35 @@ class Tempo(GraphAugmentation):
|
||||
return spectrogram_aug[:, :, :, 0]
|
||||
|
||||
|
||||
class Warp(GraphAugmentation):
|
||||
"""See "Warp augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, nt=1, nf=1, wt=0.1, wf=0.0):
|
||||
super(Warp, self).__init__(p, domain='spectrogram')
|
||||
self.num_t = int_range(nt)
|
||||
self.num_f = int_range(nf)
|
||||
self.warp_t = float_range(wt)
|
||||
self.warp_f = float_range(wf)
|
||||
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
original_shape = tf.shape(tensor)
|
||||
size_t, size_f = original_shape[1], original_shape[2]
|
||||
seed = (clock * tf.int32.min, clock * tf.int32.max)
|
||||
num_t = tf_pick_value_from_range(self.num_t, clock=clock)
|
||||
num_f = tf_pick_value_from_range(self.num_f, clock=clock)
|
||||
|
||||
def get_flows(n, size, warp):
|
||||
warp = tf_pick_value_from_range(warp, clock=clock)
|
||||
warp = warp * tf.cast(size, dtype=tf.float32) / tf.cast(2 * (n + 1), dtype=tf.float32)
|
||||
f = tf.random.stateless_normal([num_t, num_f], seed, mean=0.0, stddev=warp, dtype=tf.float32)
|
||||
return tf.pad(f, tf.constant([[1, 1], [1, 1]]), 'CONSTANT') # zero flow at all edges
|
||||
|
||||
flows = tf.stack([get_flows(num_t, size_t, self.warp_t), get_flows(num_f, size_f, self.warp_f)], axis=2)
|
||||
flows = tf.image.resize_bicubic(tf.expand_dims(flows, 0), [size_t, size_f])
|
||||
spectrogram_aug = tf.contrib.image.dense_image_warp(tf.expand_dims(tensor, -1), flows)
|
||||
return tf.reshape(spectrogram_aug, shape=(1, -1, size_f))
|
||||
|
||||
|
||||
class FrequencyMask(GraphAugmentation):
|
||||
"""See "Frequency mask augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, n=3, size=2):
|
||||
|
Loading…
Reference in New Issue
Block a user