Follow-up on PR comments; removed warp augmentation; split pitch_and_tempo augmentation
This commit is contained in:
parent
ea21c7d24e
commit
5b6de213d8
@ -311,7 +311,7 @@ Augmentations are applied in the following order:
|
||||
|
||||
4. **features** domain: The sample's mel spectrogram features are represented as a tensor.
|
||||
|
||||
Within a single domain, augmentations are applied in the same order as they appear in the command-line (the **warp** augmentation being the only exception, as it is always applied first when enabled).
|
||||
Within a single domain, augmentations are applied in the same order as they appear in the command-line.
|
||||
|
||||
|
||||
Sample domain augmentations
|
||||
@ -365,17 +365,15 @@ Sample domain augmentations
|
||||
Spectrogram domain augmentations
|
||||
--------------------------------
|
||||
|
||||
**Pitch and tempo augmentation** ``--augment pitch_and_tempo[p=<float>,pitch=<float-range>,tempo=<float-range>]``
|
||||
Scales spectrogram on time and frequency axis and thus changes pitch and playback tempo.
|
||||
**Pitch augmentation** ``--augment pitch[p=<float>,pitch=<float-range>]``
|
||||
Scales spectrogram on frequency axis and thus changes pitch.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **pitch**: pitch factor by with the frequency axis is scaled (e.g. a value of 2.0 will raise audio frequency by one octave)
|
||||
|
||||
* **tempo**: tempo factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
|
||||
|
||||
|
||||
**Speed augmentation** ``--augment speed[p=<float>,factor=<float-range>]``
|
||||
**Tempo augmentation** ``--augment tempo[p=<float>,factor=<float-range>]``
|
||||
Scales spectrogram on time axis and thus changes playback tempo.
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
@ -383,22 +381,6 @@ Spectrogram domain augmentations
|
||||
* **factor**: speed factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
|
||||
|
||||
|
||||
**Warp augmentation** ``--augment warp[p=<float>,shift=<float-range>,order=<int-range>,nbp=<int-range>,ncp=<int-range>,regularization_weight=<float>]``
|
||||
Applies a non-linear image warp to the spectrogram, where the warp is specified by the source and destination locations of a (potentially small) number of control points. Of all specified spectrogram augmentations this one will always be applied first. See the SpecAugment paper for more details - https://arxiv.org/abs/1904.08779
|
||||
|
||||
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
|
||||
|
||||
* **shift**: maximum shift distance of control points on time axis in ms
|
||||
|
||||
* **order**: polynomial order used by the spline interpolation
|
||||
|
||||
* **nbp**: how many zero-flow boundary points to include at each spectrogram edge
|
||||
|
||||
* **ncp**: how many control points to warp inside the spectrogram
|
||||
|
||||
* **regularization_weight**: weight on smoothness regularizer in interpolation
|
||||
|
||||
|
||||
**Frequency mask augmentation** ``--augment frequency_mask[p=<float>,n=<int-range>,size=<int-range>]``
|
||||
Sets frequency-intervals within the augmented samples to zero (silence) at random frequencies. See the SpecAugment paper for more details - https://arxiv.org/abs/1904.08779
|
||||
|
||||
@ -467,9 +449,8 @@ Example training with all augmentations:
|
||||
--augment resample[p=0.1,rate=12000:8000~4000] \
|
||||
--augment codec[p=0.1,bitrate=48000:16000] \
|
||||
--augment volume[p=0.1,dbfs=-10:-40] \
|
||||
--augment pitch_and_tempo[p=0.1,pitch=1~0.2,tempo=1~0.2] \
|
||||
--augment speed[p=0.1,factor=1~0.5] \
|
||||
--augment warp[p=0.1,shift=30:60~20,ncp=4~3] \
|
||||
--augment pitch[p=0.1,pitch=1~0.2] \
|
||||
--augment tempo[p=0.1,factor=1~0.5] \
|
||||
--augment frequency_mask[p=0.1,n=1:3,size=1:5] \
|
||||
--augment time_mask[p=0.1,domain=signal,n=3:10~2,size=50:100~40] \
|
||||
--augment dropout[p=0.1,rate=0.05] \
|
||||
|
@ -8,6 +8,7 @@ import numpy as np
|
||||
from multiprocessing import Queue, Process
|
||||
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
|
||||
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
|
||||
from .sample_collections import samples_from_source
|
||||
|
||||
BUFFER_SIZE = 1 * MEGABYTE
|
||||
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
|
||||
@ -36,21 +37,25 @@ class GraphAugmentation(Augmentation):
|
||||
raise ValueError('Unsupported augmentation domain: {}'.format(domain))
|
||||
self.domain = domain
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_with_probability(self, tensor, clock=0.0):
|
||||
def apply_with_probability(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
rv = tf.random.stateless_uniform([], seed=(clock * tf.int32.min, clock * tf.int32.max))
|
||||
return tf.cond(tf.less(rv, self.probability),
|
||||
lambda: self.apply(tensor, clock=clock),
|
||||
lambda: self.apply(tensor, transcript=transcript, clock=clock),
|
||||
lambda: tensor)
|
||||
|
||||
def maybe_apply(self, domain, tensor, clock=0.0):
|
||||
def maybe_apply(self, domain, tensor, transcript=None, clock=0.0):
|
||||
if domain == self.domain:
|
||||
return self.apply_with_probability(tensor, clock=clock)
|
||||
return self.apply_with_probability(tensor, transcript=transcript, clock=clock)
|
||||
return tensor
|
||||
|
||||
def units_per_ms(self):
|
||||
from .flags import FLAGS # pylint: disable=import-outside-toplevel
|
||||
return FLAGS.audio_sample_rate / 1000.0 if self.domain == 'signal' else 1.0 / FLAGS.feature_win_step
|
||||
|
||||
|
||||
def parse_augmentation(augmentation_spec):
|
||||
"""
|
||||
@ -103,7 +108,7 @@ def parse_augmentations(augmentation_specs):
|
||||
return [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
|
||||
|
||||
|
||||
def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
|
||||
def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, clock=0.0):
|
||||
"""
|
||||
Augments training sample tensor of a certain domain with matching augmentations of passed list.
|
||||
|
||||
@ -115,6 +120,7 @@ def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
|
||||
Tensor to apply augmentations to.
|
||||
augmentations : list of augmentation class instances from util.augmentations.*.
|
||||
List of augmentations of which only the spectrogram ones will get applied to the samples.
|
||||
transcript : SparseTensor
|
||||
clock : Tensor of type float32
|
||||
Time indicator for augmentation value-ranges. Running from 0.0 (start of training) to 1.0 (end of training).
|
||||
|
||||
@ -124,13 +130,9 @@ def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
|
||||
The augmented spectrogram
|
||||
"""
|
||||
if augmentations is not None:
|
||||
# Warp has to come before any spectrogram masking
|
||||
for augmentation in augmentations:
|
||||
if isinstance(augmentation, Warp):
|
||||
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
|
||||
for augmentation in augmentations:
|
||||
if isinstance(augmentation, GraphAugmentation) and not isinstance(augmentation, Warp):
|
||||
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
|
||||
if isinstance(augmentation, GraphAugmentation):
|
||||
tensor = augmentation.maybe_apply(domain, tensor, transcript=transcript, clock=clock)
|
||||
return tensor
|
||||
|
||||
|
||||
@ -204,7 +206,7 @@ def apply_sample_augmentations(samples,
|
||||
if final_clock is not None:
|
||||
assert 0.0 <= final_clock <= 1.0
|
||||
assert clock <= final_clock
|
||||
augmentations = list(filter(lambda aug: isinstance(aug, SampleAugmentation), augmentations))
|
||||
augmentations = [aug for aug in augmentations if isinstance(aug, SampleAugmentation)]
|
||||
try:
|
||||
for augmentation in augmentations:
|
||||
augmentation.start(buffering=buffering)
|
||||
@ -229,8 +231,6 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
|
||||
It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
|
||||
These are then doing decompression, potential conversion and overlaying in parallel.
|
||||
"""
|
||||
# preventing cyclic import problems
|
||||
from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel
|
||||
samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
|
||||
while True:
|
||||
for sample in samples:
|
||||
@ -238,7 +238,7 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
|
||||
|
||||
|
||||
class Overlay(SampleAugmentation):
|
||||
"""See "Overlay augmentation" in TRAINING.rst"""
|
||||
"""See "Overlay augmentation" in training documentation"""
|
||||
def __init__(self, source, p=1.0, snr=3.0, layers=1):
|
||||
super(Overlay, self).__init__(p)
|
||||
self.source = source
|
||||
@ -288,7 +288,7 @@ class Overlay(SampleAugmentation):
|
||||
|
||||
|
||||
class Codec(SampleAugmentation):
|
||||
"""See "Codec augmentation" in TRAINING.rst"""
|
||||
"""See "Codec augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, bitrate=3200):
|
||||
super(Codec, self).__init__(p)
|
||||
self.bitrate = int_range(bitrate)
|
||||
@ -300,7 +300,7 @@ class Codec(SampleAugmentation):
|
||||
|
||||
|
||||
class Reverb(SampleAugmentation):
|
||||
"""See "Reverb augmentation" in TRAINING.rst"""
|
||||
"""See "Reverb augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, delay=20.0, decay=10.0):
|
||||
super(Reverb, self).__init__(p)
|
||||
self.delay = float_range(delay)
|
||||
@ -330,7 +330,7 @@ class Reverb(SampleAugmentation):
|
||||
|
||||
|
||||
class Resample(SampleAugmentation):
|
||||
"""See "Resample augmentation" in TRAINING.rst"""
|
||||
"""See "Resample augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, rate=8000):
|
||||
super(Resample, self).__init__(p)
|
||||
self.rate = int_range(rate)
|
||||
@ -350,7 +350,7 @@ class Resample(SampleAugmentation):
|
||||
|
||||
|
||||
class Volume(SampleAugmentation):
|
||||
"""See "Volume augmentation" in TRAINING.rst"""
|
||||
"""See "Volume augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, dbfs=3.0103):
|
||||
super(Volume, self).__init__(p)
|
||||
self.target_dbfs = float_range(dbfs)
|
||||
@ -361,25 +361,22 @@ class Volume(SampleAugmentation):
|
||||
sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs)
|
||||
|
||||
|
||||
class PitchAndTempo(GraphAugmentation):
|
||||
"""See "Pitch and tempo augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, tempo=1.2, pitch=(1.075, 1.075, 0.125)):
|
||||
super(PitchAndTempo, self).__init__(p, domain='spectrogram')
|
||||
self.tempo = float_range(tempo)
|
||||
class Pitch(GraphAugmentation):
|
||||
"""See "Pitch augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)):
|
||||
super(Pitch, self).__init__(p, domain='spectrogram')
|
||||
self.pitch = float_range(pitch)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
original_shape = tf.shape(tensor)
|
||||
pitch = tf_pick_value_from_range(self.pitch, clock=clock)
|
||||
tempo = tf.math.maximum(1.0, tf_pick_value_from_range(self.tempo, clock=clock))
|
||||
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32)
|
||||
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / tempo, tf.int32)
|
||||
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, new_freq_size])
|
||||
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size])
|
||||
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug,
|
||||
offset_height=0,
|
||||
offset_width=0,
|
||||
target_height=tf.shape(spectrogram_aug)[1],
|
||||
target_height=original_shape[1],
|
||||
target_width=tf.math.minimum(original_shape[2], new_freq_size))
|
||||
spectrogram_aug = tf.cond(pitch < 1,
|
||||
lambda: tf.image.pad_to_bounding_box(spectrogram_aug,
|
||||
@ -391,82 +388,34 @@ class PitchAndTempo(GraphAugmentation):
|
||||
return spectrogram_aug[:, :, :, 0]
|
||||
|
||||
|
||||
class Speed(GraphAugmentation):
|
||||
"""See "Speed augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, factor=1.1):
|
||||
super(Speed, self).__init__(p, domain='spectrogram')
|
||||
class Tempo(GraphAugmentation):
|
||||
"""See "Tempo augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, factor=1.1, max_time=-1):
|
||||
super(Tempo, self).__init__(p, domain='spectrogram')
|
||||
self.factor = float_range(factor)
|
||||
self.max_time = float(max_time)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
factor = tf_pick_value_from_range(self.factor, clock=clock)
|
||||
original_shape = tf.shape(tensor)
|
||||
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32)
|
||||
if transcript is not None:
|
||||
new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1])
|
||||
if self.max_time > 0:
|
||||
new_time_size = tf.math.minimum(new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32))
|
||||
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]])
|
||||
return spectrogram_aug[:, :, :, 0]
|
||||
|
||||
|
||||
class Warp(GraphAugmentation):
|
||||
"""See "Warp augmentation" in TRAINING.rst"""
|
||||
def __init__(self, p=1.0, shift=100.0, order=3, nbp=1, ncp=1, regularization_weight=0.0):
|
||||
super(Warp, self).__init__(p, domain='spectrogram')
|
||||
self.shift = float_range(shift)
|
||||
self.order = int_range(order)
|
||||
self.nbp = int_range(nbp)
|
||||
self.ncp = int_range(ncp)
|
||||
# Making this a value-range is impossible, as it would get a tensor which would downstream be used as parameter
|
||||
# of a comparison inside tensorflow.contrib.image.python.ops.interpolate_spline. This is not supported.
|
||||
self.regularization_weight = float(regularization_weight)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
from .flags import FLAGS # pylint: disable=import-outside-toplevel
|
||||
from .sparse_image_warp import sparse_image_warp # pylint: disable=import-outside-toplevel
|
||||
|
||||
# reshape to fit `sparse_image_warp`'s input shape (1, time steps, freq, 1), batch_size must be 1
|
||||
expanded_spectrogram = tf.expand_dims(tensor, -1)
|
||||
original_shape = tf.shape(expanded_spectrogram)
|
||||
tau, freq_size = original_shape[1], original_shape[2]
|
||||
seed = (clock * tf.int32.min, clock * tf.int32.max)
|
||||
|
||||
shift = tf_pick_value_from_range(self.shift, clock=clock)
|
||||
shift *= FLAGS.audio_sample_rate / (FLAGS.feature_win_step * 1000.0) # number of windows
|
||||
shift = tf.math.minimum(tf.cast(shift, dtype=tf.int32), tf.math.floordiv(tau, 2) - 1) # to protect short audio
|
||||
nbp = tf_pick_value_from_range(self.nbp, clock=clock)
|
||||
ncp = tf_pick_value_from_range(self.ncp, clock=clock)
|
||||
# workaround for missing stateless shuffle support
|
||||
frequencies = tf.random.stateless_uniform([2 * ncp], seed, minval=1, maxval=freq_size - 2, dtype=tf.int32)
|
||||
frequencies = tf.unique(tf.concat([frequencies, tf.range(1, limit=freq_size - 3)], axis=0))[0][0:ncp]
|
||||
source_max = tau - shift
|
||||
source_min = tf.math.minimum(source_max - ncp, shift)
|
||||
# workaround for missing stateless shuffle support
|
||||
src_times = tf.random.stateless_uniform([2 * ncp], seed, minval=source_min, maxval=source_max, dtype=tf.int32)
|
||||
src_times = tf.unique(tf.concat([src_times, tf.range(1, limit=source_max)], axis=0))[0][0:ncp]
|
||||
dst_times = src_times + tf.random.stateless_uniform([ncp], seed, minval=-shift, maxval=shift, dtype=tf.int32)
|
||||
scp_locations = tf.cast([tf.transpose(tf.stack([src_times, frequencies]))], dtype=tf.float32)
|
||||
dcp_locations = tf.cast([tf.transpose(tf.stack([dst_times, frequencies]))], dtype=tf.float32)
|
||||
|
||||
order = tf_pick_value_from_range(self.order, clock=clock)
|
||||
order = tf.math.maximum(3, order) # prevents "Input matrix is not invertible." exception
|
||||
order = tf.cast(order, tf.float32)
|
||||
|
||||
spectrogram_aug, _ = sparse_image_warp(expanded_spectrogram,
|
||||
source_control_point_locations=scp_locations,
|
||||
dest_control_point_locations=dcp_locations,
|
||||
interpolation_order=order,
|
||||
regularization_weight=self.regularization_weight,
|
||||
num_boundary_points=nbp)
|
||||
return tf.reshape(spectrogram_aug, shape=(1, -1, freq_size))
|
||||
|
||||
|
||||
class FrequencyMask(GraphAugmentation):
|
||||
"""See "Frequency mask augmentation" in TRAINING.rst"""
|
||||
"""See "Frequency mask augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, n=3, size=2):
|
||||
super(FrequencyMask, self).__init__(p, domain='spectrogram')
|
||||
self.n = int_range(n) # pylint: disable=invalid-name
|
||||
self.size = int_range(size)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
time_max = tf.shape(tensor)[1]
|
||||
freq_max = tf.shape(tensor)[2]
|
||||
@ -486,25 +435,20 @@ class FrequencyMask(GraphAugmentation):
|
||||
|
||||
|
||||
class TimeMask(GraphAugmentation):
|
||||
"""See "Time mask augmentation" in TRAINING.rst"""
|
||||
"""See "Time mask augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, domain='spectrogram', n=3, size=10.0):
|
||||
super(TimeMask, self).__init__(p, domain=domain)
|
||||
self.n = int_range(n) # pylint: disable=invalid-name
|
||||
self.size = float_range(size)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
from .flags import FLAGS # pylint: disable=import-outside-toplevel
|
||||
time_factor = FLAGS.audio_sample_rate / 1000.0 # samples per ms
|
||||
if self.domain != 'signal':
|
||||
time_factor /= FLAGS.feature_win_step # windows per ms
|
||||
time_max = tf.shape(tensor)[0] if self.domain == 'signal' else tf.shape(tensor)[1]
|
||||
time_max = tf.shape(tensor)[0 if self.domain == 'signal' else 1]
|
||||
n = tf_pick_value_from_range(self.n, clock=clock)
|
||||
|
||||
def body(i, augmented):
|
||||
size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * time_factor, dtype=tf.int32)
|
||||
size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(), dtype=tf.int32)
|
||||
size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size))
|
||||
tf.print(size)
|
||||
seed = tf.cast(clock * tf.int32.max, tf.int32) - i
|
||||
t0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=time_max - size, dtype=tf.dtypes.int32)
|
||||
rest = time_max - t0 - size
|
||||
@ -521,12 +465,12 @@ class TimeMask(GraphAugmentation):
|
||||
|
||||
|
||||
class Dropout(GraphAugmentation):
|
||||
"""See "Dropout augmentation" in TRAINING.rst"""
|
||||
"""See "Dropout augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, domain='spectrogram', rate=0.05):
|
||||
super(Dropout, self).__init__(p, domain=domain)
|
||||
self.rate = float_range(rate)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
rate = tf_pick_value_from_range(self.rate, clock=clock)
|
||||
rate = tf.math.maximum(0.0, rate)
|
||||
@ -539,12 +483,12 @@ class Dropout(GraphAugmentation):
|
||||
|
||||
|
||||
class Add(GraphAugmentation):
|
||||
"""See "Add augmentation" in TRAINING.rst"""
|
||||
"""See "Add augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, domain='features', stddev=5):
|
||||
super(Add, self).__init__(p, domain=domain)
|
||||
self.stddev = float_range(stddev)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
|
||||
seed = (clock * tf.int32.min, clock * tf.int32.max)
|
||||
@ -552,12 +496,12 @@ class Add(GraphAugmentation):
|
||||
|
||||
|
||||
class Multiply(GraphAugmentation):
|
||||
"""See "Multiply augmentation" in TRAINING.rst"""
|
||||
"""See "Multiply augmentation" in training documentation"""
|
||||
def __init__(self, p=1.0, domain='features', stddev=5):
|
||||
super(Multiply, self).__init__(p, domain=domain)
|
||||
self.stddev = float_range(stddev)
|
||||
|
||||
def apply(self, tensor, clock=0.0):
|
||||
def apply(self, tensor, transcript=None, clock=0.0):
|
||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
|
||||
seed = (clock * tf.int32.min, clock * tf.int32.max)
|
||||
|
@ -18,7 +18,7 @@ from .sample_collections import samples_from_sources
|
||||
from .helpers import remember_exception, MEGABYTE
|
||||
|
||||
|
||||
def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmentations=None, sample_id=None):
|
||||
def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phase=False, augmentations=None, sample_id=None):
|
||||
if train_phase:
|
||||
# We need the lambdas to make TensorFlow happy.
|
||||
# pylint: disable=unnecessary-lambda
|
||||
@ -29,7 +29,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta
|
||||
name='matching_sample_rate')
|
||||
|
||||
if train_phase and augmentations is not None:
|
||||
audio = apply_graph_augmentations('signal', audio, augmentations, clock=clock)
|
||||
audio = apply_graph_augmentations('signal', audio, augmentations, transcript=transcript, clock=clock)
|
||||
|
||||
spectrogram = contrib_audio.audio_spectrogram(audio,
|
||||
window_size=Config.audio_window_samples,
|
||||
@ -37,7 +37,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta
|
||||
magnitude_squared=True)
|
||||
|
||||
if train_phase and augmentations is not None:
|
||||
spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, clock=clock)
|
||||
spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, transcript=transcript, clock=clock)
|
||||
|
||||
features = contrib_audio.mfcc(spectrogram=spectrogram,
|
||||
sample_rate=sample_rate,
|
||||
@ -46,7 +46,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta
|
||||
features = tf.reshape(features, [-1, Config.n_input])
|
||||
|
||||
if train_phase and augmentations is not None:
|
||||
features = apply_graph_augmentations('features', features, augmentations, clock=clock)
|
||||
features = apply_graph_augmentations('features', features, augmentations, transcript=transcript, clock=clock)
|
||||
|
||||
return features, tf.shape(input=features)[0]
|
||||
|
||||
@ -64,13 +64,14 @@ def audiofile_to_features(wav_filename, clock=0.0, train_phase=False, augmentati
|
||||
|
||||
def entry_to_features(sample_id, audio, sample_rate, transcript, clock, train_phase=False, augmentations=None):
|
||||
# https://bugs.python.org/issue32117
|
||||
sparse_transcript = tf.SparseTensor(*transcript)
|
||||
features, features_len = audio_to_features(audio,
|
||||
sample_rate,
|
||||
transcript=sparse_transcript,
|
||||
clock=clock,
|
||||
train_phase=train_phase,
|
||||
augmentations=augmentations,
|
||||
sample_id=sample_id)
|
||||
sparse_transcript = tf.SparseTensor(*transcript)
|
||||
return sample_id, features, features_len, sparse_transcript
|
||||
|
||||
|
||||
|
@ -1,220 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Image warping using sparse flow defined at control points."""
|
||||
|
||||
# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py
|
||||
# But refactored for dynamic tensor shape compatibility
|
||||
# The core idea is to replace every numpy implementation with tensorflow implementation
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from tensorflow.compat import dimension_value
|
||||
from tensorflow.contrib.image.python.ops import dense_image_warp
|
||||
from tensorflow.contrib.image.python.ops import interpolate_spline
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
def _to_float32(value):
|
||||
return tf.cast(value, tf.float32)
|
||||
|
||||
def _to_int32(value):
|
||||
return tf.cast(value, tf.int32)
|
||||
|
||||
def _get_grid_locations(image_height, image_width):
|
||||
"""Wrapper for np.meshgrid."""
|
||||
tfv1.assert_type(image_height, tf.int32)
|
||||
tfv1.assert_type(image_width, tf.int32)
|
||||
|
||||
y_range = tf.range(image_height)
|
||||
x_range = tf.range(image_width)
|
||||
y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij')
|
||||
return tf.stack((y_grid, x_grid), -1)
|
||||
|
||||
|
||||
def _expand_to_minibatch(tensor, batch_size):
|
||||
"""Tile arbitrarily-sized np_array to include new batch dimension."""
|
||||
ndim = tf.size(tf.shape(tensor))
|
||||
ones = tf.ones((ndim,), tf.int32)
|
||||
|
||||
tiles = tf.concat(([batch_size], ones), 0)
|
||||
return tf.tile(tf.expand_dims(tensor, 0), tiles)
|
||||
|
||||
|
||||
def _get_boundary_locations(image_height, image_width, num_points_per_edge):
|
||||
"""Compute evenly-spaced indices along edge of image."""
|
||||
image_height_end = _to_float32(tf.math.subtract(image_height, 1))
|
||||
image_width_end = _to_float32(tf.math.subtract(image_width, 1))
|
||||
y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
|
||||
x_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
|
||||
ys, xs = tf.meshgrid(y_range, x_range, indexing='ij')
|
||||
is_boundary = tf.logical_or(
|
||||
tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)),
|
||||
tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end)))
|
||||
return tf.stack([tf.boolean_mask(ys, is_boundary), tf.boolean_mask(xs, is_boundary)], axis=-1)
|
||||
|
||||
|
||||
def _add_zero_flow_controls_at_boundary(control_point_locations,
|
||||
control_point_flows, image_height,
|
||||
image_width, boundary_points_per_edge):
|
||||
"""Add control points for zero-flow boundary conditions.
|
||||
|
||||
Augment the set of control points with extra points on the
|
||||
boundary of the image that have zero flow.
|
||||
|
||||
Args:
|
||||
control_point_locations: input control points
|
||||
control_point_flows: their flows
|
||||
image_height: image height
|
||||
image_width: image width
|
||||
boundary_points_per_edge: number of points to add in the middle of each
|
||||
edge (not including the corners).
|
||||
The total number of points added is
|
||||
4 + 4*(boundary_points_per_edge).
|
||||
|
||||
Returns:
|
||||
merged_control_point_locations: augmented set of control point locations
|
||||
merged_control_point_flows: augmented set of control point flows
|
||||
"""
|
||||
|
||||
batch_size = dimension_value(tf.shape(control_point_locations)[0])
|
||||
|
||||
boundary_point_locations = _get_boundary_locations(image_height, image_width,
|
||||
boundary_points_per_edge)
|
||||
boundary_point_shape = tf.shape(boundary_point_locations)
|
||||
boundary_point_flows = tf.zeros([boundary_point_shape[0], 2])
|
||||
|
||||
minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size)
|
||||
type_to_use = control_point_locations.dtype
|
||||
boundary_point_locations = tf.cast(minbatch_locations, type_to_use)
|
||||
|
||||
minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size)
|
||||
|
||||
boundary_point_flows = tf.cast(minbatch_flows, type_to_use)
|
||||
|
||||
merged_control_point_locations = tf.concat(
|
||||
[control_point_locations, boundary_point_locations], 1)
|
||||
|
||||
merged_control_point_flows = tf.concat(
|
||||
[control_point_flows, boundary_point_flows], 1)
|
||||
|
||||
return merged_control_point_locations, merged_control_point_flows
|
||||
|
||||
|
||||
def sparse_image_warp(image,
|
||||
source_control_point_locations,
|
||||
dest_control_point_locations,
|
||||
interpolation_order=2,
|
||||
regularization_weight=0.0,
|
||||
num_boundary_points=0,
|
||||
name='sparse_image_warp'):
|
||||
"""Image warping using correspondences between sparse control points.
|
||||
|
||||
Apply a non-linear warp to the image, where the warp is specified by
|
||||
the source and destination locations of a (potentially small) number of
|
||||
control points. First, we use a polyharmonic spline
|
||||
(`tf.contrib.image.interpolate_spline`) to interpolate the displacements
|
||||
between the corresponding control points to a dense flow field.
|
||||
Then, we warp the image using this dense flow field
|
||||
(`tf.contrib.image.dense_image_warp`).
|
||||
|
||||
Let t index our control points. For regularization_weight=0, we have:
|
||||
warped_image[b, dest_control_point_locations[b, t, 0],
|
||||
dest_control_point_locations[b, t, 1], :] =
|
||||
image[b, source_control_point_locations[b, t, 0],
|
||||
source_control_point_locations[b, t, 1], :].
|
||||
|
||||
For regularization_weight > 0, this condition is met approximately, since
|
||||
regularized interpolation trades off smoothness of the interpolant vs.
|
||||
reconstruction of the interpolant at the control points.
|
||||
See `tf.contrib.image.interpolate_spline` for further documentation of the
|
||||
interpolation_order and regularization_weight arguments.
|
||||
|
||||
|
||||
Args:
|
||||
image: `[batch, height, width, channels]` float `Tensor`
|
||||
source_control_point_locations: `[batch, num_control_points, 2]` float
|
||||
`Tensor`
|
||||
dest_control_point_locations: `[batch, num_control_points, 2]` float
|
||||
`Tensor`
|
||||
interpolation_order: polynomial order used by the spline interpolation
|
||||
regularization_weight: weight on smoothness regularizer in interpolation
|
||||
num_boundary_points: How many zero-flow boundary points to include at
|
||||
each image edge.Usage:
|
||||
num_boundary_points=0: don't add zero-flow points
|
||||
num_boundary_points=1: 4 corners of the image
|
||||
num_boundary_points=2: 4 corners and one in the middle of each edge
|
||||
(8 points total)
|
||||
num_boundary_points=n: 4 corners and n-1 along each edge
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Note that image and offsets can be of type tf.half, tf.float32, or
|
||||
tf.float64, and do not necessarily have to be the same type.
|
||||
|
||||
Returns:
|
||||
warped_image: `[batch, height, width, channels]` float `Tensor` with same
|
||||
type as input image.
|
||||
flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
|
||||
flow field produced by the interpolation.
|
||||
"""
|
||||
|
||||
image = ops.convert_to_tensor(image)
|
||||
source_control_point_locations = ops.convert_to_tensor(
|
||||
source_control_point_locations)
|
||||
dest_control_point_locations = ops.convert_to_tensor(
|
||||
dest_control_point_locations)
|
||||
|
||||
control_point_flows = (
|
||||
dest_control_point_locations - source_control_point_locations)
|
||||
|
||||
clamp_boundaries = num_boundary_points > 0
|
||||
boundary_points_per_edge = num_boundary_points - 1
|
||||
|
||||
with ops.name_scope(name):
|
||||
image_shape = tf.shape(image)
|
||||
batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2]
|
||||
|
||||
# This generates the dense locations where the interpolant
|
||||
# will be evaluated.
|
||||
grid_locations = _get_grid_locations(image_height, image_width)
|
||||
|
||||
flattened_grid_locations = tf.reshape(grid_locations,
|
||||
[tf.multiply(image_height, image_width), 2])
|
||||
|
||||
# flattened_grid_locations = constant_op.constant(
|
||||
# _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
|
||||
flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size)
|
||||
flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype)
|
||||
|
||||
if clamp_boundaries:
|
||||
(dest_control_point_locations,
|
||||
control_point_flows) = _add_zero_flow_controls_at_boundary(
|
||||
dest_control_point_locations, control_point_flows, image_height,
|
||||
image_width, boundary_points_per_edge)
|
||||
|
||||
flattened_flows = interpolate_spline.interpolate_spline(
|
||||
dest_control_point_locations, control_point_flows,
|
||||
flattened_grid_locations, interpolation_order, regularization_weight)
|
||||
|
||||
dense_flows = array_ops.reshape(flattened_flows,
|
||||
[batch_size, image_height, image_width, 2])
|
||||
|
||||
warped_image = dense_image_warp.dense_image_warp(image, dense_flows)
|
||||
|
||||
return warped_image, dense_flows
|
Loading…
x
Reference in New Issue
Block a user