Follow-up on PR comments; removed warp augmentation; split pitch_and_tempo augmentation

This commit is contained in:
Tilman Kamp 2020-06-16 11:07:57 +02:00
parent ea21c7d24e
commit 5b6de213d8
4 changed files with 61 additions and 355 deletions

View File

@ -311,7 +311,7 @@ Augmentations are applied in the following order:
4. **features** domain: The sample's mel spectrogram features are represented as a tensor.
Within a single domain, augmentations are applied in the same order as they appear in the command-line (the **warp** augmentation being the only exception, as it is always applied first when enabled).
Within a single domain, augmentations are applied in the same order as they appear in the command-line.
Sample domain augmentations
@ -365,17 +365,15 @@ Sample domain augmentations
Spectrogram domain augmentations
--------------------------------
**Pitch and tempo augmentation** ``--augment pitch_and_tempo[p=<float>,pitch=<float-range>,tempo=<float-range>]``
Scales spectrogram on time and frequency axis and thus changes pitch and playback tempo.
**Pitch augmentation** ``--augment pitch[p=<float>,pitch=<float-range>]``
Scales spectrogram on frequency axis and thus changes pitch.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **pitch**: pitch factor by with the frequency axis is scaled (e.g. a value of 2.0 will raise audio frequency by one octave)
* **tempo**: tempo factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
**Speed augmentation** ``--augment speed[p=<float>,factor=<float-range>]``
**Tempo augmentation** ``--augment tempo[p=<float>,factor=<float-range>]``
Scales spectrogram on time axis and thus changes playback tempo.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
@ -383,22 +381,6 @@ Spectrogram domain augmentations
* **factor**: speed factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
**Warp augmentation** ``--augment warp[p=<float>,shift=<float-range>,order=<int-range>,nbp=<int-range>,ncp=<int-range>,regularization_weight=<float>]``
Applies a non-linear image warp to the spectrogram, where the warp is specified by the source and destination locations of a (potentially small) number of control points. Of all specified spectrogram augmentations this one will always be applied first. See the SpecAugment paper for more details - https://arxiv.org/abs/1904.08779
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **shift**: maximum shift distance of control points on time axis in ms
* **order**: polynomial order used by the spline interpolation
* **nbp**: how many zero-flow boundary points to include at each spectrogram edge
* **ncp**: how many control points to warp inside the spectrogram
* **regularization_weight**: weight on smoothness regularizer in interpolation
**Frequency mask augmentation** ``--augment frequency_mask[p=<float>,n=<int-range>,size=<int-range>]``
Sets frequency-intervals within the augmented samples to zero (silence) at random frequencies. See the SpecAugment paper for more details - https://arxiv.org/abs/1904.08779
@ -467,9 +449,8 @@ Example training with all augmentations:
--augment resample[p=0.1,rate=12000:8000~4000] \
--augment codec[p=0.1,bitrate=48000:16000] \
--augment volume[p=0.1,dbfs=-10:-40] \
--augment pitch_and_tempo[p=0.1,pitch=1~0.2,tempo=1~0.2] \
--augment speed[p=0.1,factor=1~0.5] \
--augment warp[p=0.1,shift=30:60~20,ncp=4~3] \
--augment pitch[p=0.1,pitch=1~0.2] \
--augment tempo[p=0.1,factor=1~0.5] \
--augment frequency_mask[p=0.1,n=1:3,size=1:5] \
--augment time_mask[p=0.1,domain=signal,n=3:10~2,size=50:100~40] \
--augment dropout[p=0.1,rate=0.05] \

View File

@ -8,6 +8,7 @@ import numpy as np
from multiprocessing import Queue, Process
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
from .sample_collections import samples_from_source
BUFFER_SIZE = 1 * MEGABYTE
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
@ -36,21 +37,25 @@ class GraphAugmentation(Augmentation):
raise ValueError('Unsupported augmentation domain: {}'.format(domain))
self.domain = domain
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
raise NotImplementedError
def apply_with_probability(self, tensor, clock=0.0):
def apply_with_probability(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
rv = tf.random.stateless_uniform([], seed=(clock * tf.int32.min, clock * tf.int32.max))
return tf.cond(tf.less(rv, self.probability),
lambda: self.apply(tensor, clock=clock),
lambda: self.apply(tensor, transcript=transcript, clock=clock),
lambda: tensor)
def maybe_apply(self, domain, tensor, clock=0.0):
def maybe_apply(self, domain, tensor, transcript=None, clock=0.0):
if domain == self.domain:
return self.apply_with_probability(tensor, clock=clock)
return self.apply_with_probability(tensor, transcript=transcript, clock=clock)
return tensor
def units_per_ms(self):
from .flags import FLAGS # pylint: disable=import-outside-toplevel
return FLAGS.audio_sample_rate / 1000.0 if self.domain == 'signal' else 1.0 / FLAGS.feature_win_step
def parse_augmentation(augmentation_spec):
"""
@ -103,7 +108,7 @@ def parse_augmentations(augmentation_specs):
return [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, clock=0.0):
"""
Augments training sample tensor of a certain domain with matching augmentations of passed list.
@ -115,6 +120,7 @@ def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
Tensor to apply augmentations to.
augmentations : list of augmentation class instances from util.augmentations.*.
List of augmentations of which only the spectrogram ones will get applied to the samples.
transcript : SparseTensor
clock : Tensor of type float32
Time indicator for augmentation value-ranges. Running from 0.0 (start of training) to 1.0 (end of training).
@ -124,13 +130,9 @@ def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
The augmented spectrogram
"""
if augmentations is not None:
# Warp has to come before any spectrogram masking
for augmentation in augmentations:
if isinstance(augmentation, Warp):
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
for augmentation in augmentations:
if isinstance(augmentation, GraphAugmentation) and not isinstance(augmentation, Warp):
tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
if isinstance(augmentation, GraphAugmentation):
tensor = augmentation.maybe_apply(domain, tensor, transcript=transcript, clock=clock)
return tensor
@ -204,7 +206,7 @@ def apply_sample_augmentations(samples,
if final_clock is not None:
assert 0.0 <= final_clock <= 1.0
assert clock <= final_clock
augmentations = list(filter(lambda aug: isinstance(aug, SampleAugmentation), augmentations))
augmentations = [aug for aug in augmentations if isinstance(aug, SampleAugmentation)]
try:
for augmentation in augmentations:
augmentation.start(buffering=buffering)
@ -229,8 +231,6 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
These are then doing decompression, potential conversion and overlaying in parallel.
"""
# preventing cyclic import problems
from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel
samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
while True:
for sample in samples:
@ -238,7 +238,7 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
class Overlay(SampleAugmentation):
"""See "Overlay augmentation" in TRAINING.rst"""
"""See "Overlay augmentation" in training documentation"""
def __init__(self, source, p=1.0, snr=3.0, layers=1):
super(Overlay, self).__init__(p)
self.source = source
@ -288,7 +288,7 @@ class Overlay(SampleAugmentation):
class Codec(SampleAugmentation):
"""See "Codec augmentation" in TRAINING.rst"""
"""See "Codec augmentation" in training documentation"""
def __init__(self, p=1.0, bitrate=3200):
super(Codec, self).__init__(p)
self.bitrate = int_range(bitrate)
@ -300,7 +300,7 @@ class Codec(SampleAugmentation):
class Reverb(SampleAugmentation):
"""See "Reverb augmentation" in TRAINING.rst"""
"""See "Reverb augmentation" in training documentation"""
def __init__(self, p=1.0, delay=20.0, decay=10.0):
super(Reverb, self).__init__(p)
self.delay = float_range(delay)
@ -330,7 +330,7 @@ class Reverb(SampleAugmentation):
class Resample(SampleAugmentation):
"""See "Resample augmentation" in TRAINING.rst"""
"""See "Resample augmentation" in training documentation"""
def __init__(self, p=1.0, rate=8000):
super(Resample, self).__init__(p)
self.rate = int_range(rate)
@ -350,7 +350,7 @@ class Resample(SampleAugmentation):
class Volume(SampleAugmentation):
"""See "Volume augmentation" in TRAINING.rst"""
"""See "Volume augmentation" in training documentation"""
def __init__(self, p=1.0, dbfs=3.0103):
super(Volume, self).__init__(p)
self.target_dbfs = float_range(dbfs)
@ -361,25 +361,22 @@ class Volume(SampleAugmentation):
sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs)
class PitchAndTempo(GraphAugmentation):
"""See "Pitch and tempo augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, tempo=1.2, pitch=(1.075, 1.075, 0.125)):
super(PitchAndTempo, self).__init__(p, domain='spectrogram')
self.tempo = float_range(tempo)
class Pitch(GraphAugmentation):
"""See "Pitch augmentation" in training documentation"""
def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)):
super(Pitch, self).__init__(p, domain='spectrogram')
self.pitch = float_range(pitch)
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
original_shape = tf.shape(tensor)
pitch = tf_pick_value_from_range(self.pitch, clock=clock)
tempo = tf.math.maximum(1.0, tf_pick_value_from_range(self.tempo, clock=clock))
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / tempo, tf.int32)
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, new_freq_size])
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size])
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug,
offset_height=0,
offset_width=0,
target_height=tf.shape(spectrogram_aug)[1],
target_height=original_shape[1],
target_width=tf.math.minimum(original_shape[2], new_freq_size))
spectrogram_aug = tf.cond(pitch < 1,
lambda: tf.image.pad_to_bounding_box(spectrogram_aug,
@ -391,82 +388,34 @@ class PitchAndTempo(GraphAugmentation):
return spectrogram_aug[:, :, :, 0]
class Speed(GraphAugmentation):
"""See "Speed augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, factor=1.1):
super(Speed, self).__init__(p, domain='spectrogram')
class Tempo(GraphAugmentation):
"""See "Tempo augmentation" in training documentation"""
def __init__(self, p=1.0, factor=1.1, max_time=-1):
super(Tempo, self).__init__(p, domain='spectrogram')
self.factor = float_range(factor)
self.max_time = float(max_time)
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
factor = tf_pick_value_from_range(self.factor, clock=clock)
original_shape = tf.shape(tensor)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32)
if transcript is not None:
new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1])
if self.max_time > 0:
new_time_size = tf.math.minimum(new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32))
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]])
return spectrogram_aug[:, :, :, 0]
class Warp(GraphAugmentation):
"""See "Warp augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, shift=100.0, order=3, nbp=1, ncp=1, regularization_weight=0.0):
super(Warp, self).__init__(p, domain='spectrogram')
self.shift = float_range(shift)
self.order = int_range(order)
self.nbp = int_range(nbp)
self.ncp = int_range(ncp)
# Making this a value-range is impossible, as it would get a tensor which would downstream be used as parameter
# of a comparison inside tensorflow.contrib.image.python.ops.interpolate_spline. This is not supported.
self.regularization_weight = float(regularization_weight)
def apply(self, tensor, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
from .flags import FLAGS # pylint: disable=import-outside-toplevel
from .sparse_image_warp import sparse_image_warp # pylint: disable=import-outside-toplevel
# reshape to fit `sparse_image_warp`'s input shape (1, time steps, freq, 1), batch_size must be 1
expanded_spectrogram = tf.expand_dims(tensor, -1)
original_shape = tf.shape(expanded_spectrogram)
tau, freq_size = original_shape[1], original_shape[2]
seed = (clock * tf.int32.min, clock * tf.int32.max)
shift = tf_pick_value_from_range(self.shift, clock=clock)
shift *= FLAGS.audio_sample_rate / (FLAGS.feature_win_step * 1000.0) # number of windows
shift = tf.math.minimum(tf.cast(shift, dtype=tf.int32), tf.math.floordiv(tau, 2) - 1) # to protect short audio
nbp = tf_pick_value_from_range(self.nbp, clock=clock)
ncp = tf_pick_value_from_range(self.ncp, clock=clock)
# workaround for missing stateless shuffle support
frequencies = tf.random.stateless_uniform([2 * ncp], seed, minval=1, maxval=freq_size - 2, dtype=tf.int32)
frequencies = tf.unique(tf.concat([frequencies, tf.range(1, limit=freq_size - 3)], axis=0))[0][0:ncp]
source_max = tau - shift
source_min = tf.math.minimum(source_max - ncp, shift)
# workaround for missing stateless shuffle support
src_times = tf.random.stateless_uniform([2 * ncp], seed, minval=source_min, maxval=source_max, dtype=tf.int32)
src_times = tf.unique(tf.concat([src_times, tf.range(1, limit=source_max)], axis=0))[0][0:ncp]
dst_times = src_times + tf.random.stateless_uniform([ncp], seed, minval=-shift, maxval=shift, dtype=tf.int32)
scp_locations = tf.cast([tf.transpose(tf.stack([src_times, frequencies]))], dtype=tf.float32)
dcp_locations = tf.cast([tf.transpose(tf.stack([dst_times, frequencies]))], dtype=tf.float32)
order = tf_pick_value_from_range(self.order, clock=clock)
order = tf.math.maximum(3, order) # prevents "Input matrix is not invertible." exception
order = tf.cast(order, tf.float32)
spectrogram_aug, _ = sparse_image_warp(expanded_spectrogram,
source_control_point_locations=scp_locations,
dest_control_point_locations=dcp_locations,
interpolation_order=order,
regularization_weight=self.regularization_weight,
num_boundary_points=nbp)
return tf.reshape(spectrogram_aug, shape=(1, -1, freq_size))
class FrequencyMask(GraphAugmentation):
"""See "Frequency mask augmentation" in TRAINING.rst"""
"""See "Frequency mask augmentation" in training documentation"""
def __init__(self, p=1.0, n=3, size=2):
super(FrequencyMask, self).__init__(p, domain='spectrogram')
self.n = int_range(n) # pylint: disable=invalid-name
self.size = int_range(size)
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
time_max = tf.shape(tensor)[1]
freq_max = tf.shape(tensor)[2]
@ -486,25 +435,20 @@ class FrequencyMask(GraphAugmentation):
class TimeMask(GraphAugmentation):
"""See "Time mask augmentation" in TRAINING.rst"""
"""See "Time mask augmentation" in training documentation"""
def __init__(self, p=1.0, domain='spectrogram', n=3, size=10.0):
super(TimeMask, self).__init__(p, domain=domain)
self.n = int_range(n) # pylint: disable=invalid-name
self.size = float_range(size)
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
from .flags import FLAGS # pylint: disable=import-outside-toplevel
time_factor = FLAGS.audio_sample_rate / 1000.0 # samples per ms
if self.domain != 'signal':
time_factor /= FLAGS.feature_win_step # windows per ms
time_max = tf.shape(tensor)[0] if self.domain == 'signal' else tf.shape(tensor)[1]
time_max = tf.shape(tensor)[0 if self.domain == 'signal' else 1]
n = tf_pick_value_from_range(self.n, clock=clock)
def body(i, augmented):
size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * time_factor, dtype=tf.int32)
size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(), dtype=tf.int32)
size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size))
tf.print(size)
seed = tf.cast(clock * tf.int32.max, tf.int32) - i
t0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=time_max - size, dtype=tf.dtypes.int32)
rest = time_max - t0 - size
@ -521,12 +465,12 @@ class TimeMask(GraphAugmentation):
class Dropout(GraphAugmentation):
"""See "Dropout augmentation" in TRAINING.rst"""
"""See "Dropout augmentation" in training documentation"""
def __init__(self, p=1.0, domain='spectrogram', rate=0.05):
super(Dropout, self).__init__(p, domain=domain)
self.rate = float_range(rate)
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
rate = tf_pick_value_from_range(self.rate, clock=clock)
rate = tf.math.maximum(0.0, rate)
@ -539,12 +483,12 @@ class Dropout(GraphAugmentation):
class Add(GraphAugmentation):
"""See "Add augmentation" in TRAINING.rst"""
"""See "Add augmentation" in training documentation"""
def __init__(self, p=1.0, domain='features', stddev=5):
super(Add, self).__init__(p, domain=domain)
self.stddev = float_range(stddev)
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
seed = (clock * tf.int32.min, clock * tf.int32.max)
@ -552,12 +496,12 @@ class Add(GraphAugmentation):
class Multiply(GraphAugmentation):
"""See "Multiply augmentation" in TRAINING.rst"""
"""See "Multiply augmentation" in training documentation"""
def __init__(self, p=1.0, domain='features', stddev=5):
super(Multiply, self).__init__(p, domain=domain)
self.stddev = float_range(stddev)
def apply(self, tensor, clock=0.0):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
stddev = tf_pick_value_from_range(self.stddev, clock=clock)
seed = (clock * tf.int32.min, clock * tf.int32.max)

View File

@ -18,7 +18,7 @@ from .sample_collections import samples_from_sources
from .helpers import remember_exception, MEGABYTE
def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmentations=None, sample_id=None):
def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phase=False, augmentations=None, sample_id=None):
if train_phase:
# We need the lambdas to make TensorFlow happy.
# pylint: disable=unnecessary-lambda
@ -29,7 +29,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta
name='matching_sample_rate')
if train_phase and augmentations is not None:
audio = apply_graph_augmentations('signal', audio, augmentations, clock=clock)
audio = apply_graph_augmentations('signal', audio, augmentations, transcript=transcript, clock=clock)
spectrogram = contrib_audio.audio_spectrogram(audio,
window_size=Config.audio_window_samples,
@ -37,7 +37,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta
magnitude_squared=True)
if train_phase and augmentations is not None:
spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, clock=clock)
spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, transcript=transcript, clock=clock)
features = contrib_audio.mfcc(spectrogram=spectrogram,
sample_rate=sample_rate,
@ -46,7 +46,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta
features = tf.reshape(features, [-1, Config.n_input])
if train_phase and augmentations is not None:
features = apply_graph_augmentations('features', features, augmentations, clock=clock)
features = apply_graph_augmentations('features', features, augmentations, transcript=transcript, clock=clock)
return features, tf.shape(input=features)[0]
@ -64,13 +64,14 @@ def audiofile_to_features(wav_filename, clock=0.0, train_phase=False, augmentati
def entry_to_features(sample_id, audio, sample_rate, transcript, clock, train_phase=False, augmentations=None):
# https://bugs.python.org/issue32117
sparse_transcript = tf.SparseTensor(*transcript)
features, features_len = audio_to_features(audio,
sample_rate,
transcript=sparse_transcript,
clock=clock,
train_phase=train_phase,
augmentations=augmentations,
sample_id=sample_id)
sparse_transcript = tf.SparseTensor(*transcript)
return sample_id, features, features_len, sparse_transcript

View File

@ -1,220 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image warping using sparse flow defined at control points."""
# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py
# But refactored for dynamic tensor shape compatibility
# The core idea is to replace every numpy implementation with tensorflow implementation
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from tensorflow.compat import dimension_value
from tensorflow.contrib.image.python.ops import dense_image_warp
from tensorflow.contrib.image.python.ops import interpolate_spline
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
def _to_float32(value):
return tf.cast(value, tf.float32)
def _to_int32(value):
return tf.cast(value, tf.int32)
def _get_grid_locations(image_height, image_width):
"""Wrapper for np.meshgrid."""
tfv1.assert_type(image_height, tf.int32)
tfv1.assert_type(image_width, tf.int32)
y_range = tf.range(image_height)
x_range = tf.range(image_width)
y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij')
return tf.stack((y_grid, x_grid), -1)
def _expand_to_minibatch(tensor, batch_size):
"""Tile arbitrarily-sized np_array to include new batch dimension."""
ndim = tf.size(tf.shape(tensor))
ones = tf.ones((ndim,), tf.int32)
tiles = tf.concat(([batch_size], ones), 0)
return tf.tile(tf.expand_dims(tensor, 0), tiles)
def _get_boundary_locations(image_height, image_width, num_points_per_edge):
"""Compute evenly-spaced indices along edge of image."""
image_height_end = _to_float32(tf.math.subtract(image_height, 1))
image_width_end = _to_float32(tf.math.subtract(image_width, 1))
y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
x_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2)
ys, xs = tf.meshgrid(y_range, x_range, indexing='ij')
is_boundary = tf.logical_or(
tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)),
tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end)))
return tf.stack([tf.boolean_mask(ys, is_boundary), tf.boolean_mask(xs, is_boundary)], axis=-1)
def _add_zero_flow_controls_at_boundary(control_point_locations,
control_point_flows, image_height,
image_width, boundary_points_per_edge):
"""Add control points for zero-flow boundary conditions.
Augment the set of control points with extra points on the
boundary of the image that have zero flow.
Args:
control_point_locations: input control points
control_point_flows: their flows
image_height: image height
image_width: image width
boundary_points_per_edge: number of points to add in the middle of each
edge (not including the corners).
The total number of points added is
4 + 4*(boundary_points_per_edge).
Returns:
merged_control_point_locations: augmented set of control point locations
merged_control_point_flows: augmented set of control point flows
"""
batch_size = dimension_value(tf.shape(control_point_locations)[0])
boundary_point_locations = _get_boundary_locations(image_height, image_width,
boundary_points_per_edge)
boundary_point_shape = tf.shape(boundary_point_locations)
boundary_point_flows = tf.zeros([boundary_point_shape[0], 2])
minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size)
type_to_use = control_point_locations.dtype
boundary_point_locations = tf.cast(minbatch_locations, type_to_use)
minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size)
boundary_point_flows = tf.cast(minbatch_flows, type_to_use)
merged_control_point_locations = tf.concat(
[control_point_locations, boundary_point_locations], 1)
merged_control_point_flows = tf.concat(
[control_point_flows, boundary_point_flows], 1)
return merged_control_point_locations, merged_control_point_flows
def sparse_image_warp(image,
source_control_point_locations,
dest_control_point_locations,
interpolation_order=2,
regularization_weight=0.0,
num_boundary_points=0,
name='sparse_image_warp'):
"""Image warping using correspondences between sparse control points.
Apply a non-linear warp to the image, where the warp is specified by
the source and destination locations of a (potentially small) number of
control points. First, we use a polyharmonic spline
(`tf.contrib.image.interpolate_spline`) to interpolate the displacements
between the corresponding control points to a dense flow field.
Then, we warp the image using this dense flow field
(`tf.contrib.image.dense_image_warp`).
Let t index our control points. For regularization_weight=0, we have:
warped_image[b, dest_control_point_locations[b, t, 0],
dest_control_point_locations[b, t, 1], :] =
image[b, source_control_point_locations[b, t, 0],
source_control_point_locations[b, t, 1], :].
For regularization_weight > 0, this condition is met approximately, since
regularized interpolation trades off smoothness of the interpolant vs.
reconstruction of the interpolant at the control points.
See `tf.contrib.image.interpolate_spline` for further documentation of the
interpolation_order and regularization_weight arguments.
Args:
image: `[batch, height, width, channels]` float `Tensor`
source_control_point_locations: `[batch, num_control_points, 2]` float
`Tensor`
dest_control_point_locations: `[batch, num_control_points, 2]` float
`Tensor`
interpolation_order: polynomial order used by the spline interpolation
regularization_weight: weight on smoothness regularizer in interpolation
num_boundary_points: How many zero-flow boundary points to include at
each image edge.Usage:
num_boundary_points=0: don't add zero-flow points
num_boundary_points=1: 4 corners of the image
num_boundary_points=2: 4 corners and one in the middle of each edge
(8 points total)
num_boundary_points=n: 4 corners and n-1 along each edge
name: A name for the operation (optional).
Note that image and offsets can be of type tf.half, tf.float32, or
tf.float64, and do not necessarily have to be the same type.
Returns:
warped_image: `[batch, height, width, channels]` float `Tensor` with same
type as input image.
flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
flow field produced by the interpolation.
"""
image = ops.convert_to_tensor(image)
source_control_point_locations = ops.convert_to_tensor(
source_control_point_locations)
dest_control_point_locations = ops.convert_to_tensor(
dest_control_point_locations)
control_point_flows = (
dest_control_point_locations - source_control_point_locations)
clamp_boundaries = num_boundary_points > 0
boundary_points_per_edge = num_boundary_points - 1
with ops.name_scope(name):
image_shape = tf.shape(image)
batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2]
# This generates the dense locations where the interpolant
# will be evaluated.
grid_locations = _get_grid_locations(image_height, image_width)
flattened_grid_locations = tf.reshape(grid_locations,
[tf.multiply(image_height, image_width), 2])
# flattened_grid_locations = constant_op.constant(
# _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size)
flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype)
if clamp_boundaries:
(dest_control_point_locations,
control_point_flows) = _add_zero_flow_controls_at_boundary(
dest_control_point_locations, control_point_flows, image_height,
image_width, boundary_points_per_edge)
flattened_flows = interpolate_spline.interpolate_spline(
dest_control_point_locations, control_point_flows,
flattened_grid_locations, interpolation_order, regularization_weight)
dense_flows = array_ops.reshape(flattened_flows,
[batch_size, image_height, image_width, 2])
warped_image = dense_image_warp.dense_image_warp(image, dense_flows)
return warped_image, dense_flows