From 5b6de213d8dc8893647261f38dcfaf454854225c Mon Sep 17 00:00:00 2001
From: Tilman Kamp <5991088+tilmankamp@users.noreply.github.com>
Date: Tue, 16 Jun 2020 11:07:57 +0200
Subject: [PATCH] Follow-up on PR comments; removed warp augmentation; split
 pitch_and_tempo augmentation

---
 doc/TRAINING.rst                              |  31 +--
 .../deepspeech_training/util/augmentations.py | 154 ++++--------
 training/deepspeech_training/util/feeding.py  |  11 +-
 .../util/sparse_image_warp.py                 | 220 ------------------
 4 files changed, 61 insertions(+), 355 deletions(-)
 delete mode 100644 training/deepspeech_training/util/sparse_image_warp.py

diff --git a/doc/TRAINING.rst b/doc/TRAINING.rst
index 17230914..be904120 100644
--- a/doc/TRAINING.rst
+++ b/doc/TRAINING.rst
@@ -311,7 +311,7 @@ Augmentations are applied in the following order:
 
 4. **features** domain: The sample's mel spectrogram features are represented as a tensor.
 
-Within a single domain, augmentations are applied in the same order as they appear in the command-line (the **warp** augmentation being the only exception, as it is always applied first when enabled).
+Within a single domain, augmentations are applied in the same order as they appear on the command line.
 
 
 Sample domain augmentations
@@ -365,17 +365,15 @@ Sample domain augmentations
 Spectrogram domain augmentations
 --------------------------------
 
-**Pitch and tempo augmentation** ``--augment pitch_and_tempo[p=<float>,pitch=<float-range>,tempo=<float-range>]``
-  Scales spectrogram on time and frequency axis and thus changes pitch and playback tempo.
+**Pitch augmentation** ``--augment pitch[p=<float>,pitch=<float-range>]``
+  Scales spectrogram on frequency axis and thus changes pitch.
 
   * **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
 
   * **pitch**: pitch factor by which the frequency axis is scaled (e.g. a value of 2.0 will raise audio frequency by one octave)
 
-  * **tempo**: tempo factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
-
-
-**Speed augmentation** ``--augment speed[p=<float>,factor=<float-range>]``
+**Tempo augmentation** ``--augment tempo[p=<float>,factor=<float-range>]``
   Scales spectrogram on time axis and thus changes playback tempo.
 
   * **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
 
   * **factor**: speed factor by which the time axis is stretched or shrunken (e.g. a value of 2.0 will double playback tempo)
 
 
-**Warp augmentation** ``--augment warp[p=<float>,shift=<float-range>,order=<int-range>,nbp=<int-range>,ncp=<int-range>,regularization_weight=<float-range>]``
-  Applies a non-linear image warp to the spectrogram, where the warp is specified by the source and destination locations of a (potentially small) number of control points. Of all specified spectrogram augmentations this one will always be applied first.
-  See the SpecAugment paper for more details - https://arxiv.org/abs/1904.08779
-
-  * **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
-
-  * **shift**: maximum shift distance of control points on time axis in ms
-
-  * **order**: polynomial order used by the spline interpolation
-
-  * **nbp**: how many zero-flow boundary points to include at each spectrogram edge
-
-  * **ncp**: how many control points to warp inside the spectrogram
-
-  * **regularization_weight**: weight on smoothness regularizer in interpolation
-
-
 **Frequency mask augmentation** ``--augment frequency_mask[p=<float>,n=<int-range>,size=<int-range>]``
   Sets frequency-intervals within the augmented samples to zero (silence) at random frequencies.
   See the SpecAugment paper for more details - https://arxiv.org/abs/1904.08779
@@ -467,9 +449,8 @@ Example training with all augmentations:
 --augment resample[p=0.1,rate=12000:8000~4000] \
 --augment codec[p=0.1,bitrate=48000:16000] \
 --augment volume[p=0.1,dbfs=-10:-40] \
---augment pitch_and_tempo[p=0.1,pitch=1~0.2,tempo=1~0.2] \
---augment speed[p=0.1,factor=1~0.5] \
---augment warp[p=0.1,shift=30:60~20,ncp=4~3] \
+--augment pitch[p=0.1,pitch=1~0.2] \
+--augment tempo[p=0.1,factor=1~0.5] \
 --augment frequency_mask[p=0.1,n=1:3,size=1:5] \
 --augment time_mask[p=0.1,domain=signal,n=3:10~2,size=50:100~40] \
 --augment dropout[p=0.1,rate=0.05] \
diff --git a/training/deepspeech_training/util/augmentations.py b/training/deepspeech_training/util/augmentations.py
index eff08fbc..5c17dedd 100644
--- a/training/deepspeech_training/util/augmentations.py
+++ b/training/deepspeech_training/util/augmentations.py
@@ -8,6 +8,7 @@ import numpy as np
 from multiprocessing import Queue, Process
 from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
 from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
+from .sample_collections import samples_from_source
 
 BUFFER_SIZE = 1 * MEGABYTE
 SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
@@ -36,21 +37,26 @@ class GraphAugmentation(Augmentation):
             raise ValueError('Unsupported augmentation domain: {}'.format(domain))
         self.domain = domain
 
-    def apply(self, tensor, clock=0.0):
+    def apply(self, tensor, transcript=None, clock=0.0):
         raise NotImplementedError
 
-    def apply_with_probability(self, tensor, clock=0.0):
+    def apply_with_probability(self, tensor, transcript=None, clock=0.0):
         import tensorflow as tf  # pylint: disable=import-outside-toplevel
         rv = tf.random.stateless_uniform([], seed=(clock * tf.int32.min, clock * tf.int32.max))
         return tf.cond(tf.less(rv, self.probability),
-                       lambda: self.apply(tensor, clock=clock),
+                       lambda: self.apply(tensor, transcript=transcript, clock=clock),
                        lambda: tensor)
 
-    def maybe_apply(self, domain, tensor, clock=0.0):
+    def maybe_apply(self, domain, tensor, transcript=None, clock=0.0):
         if domain == self.domain:
-            return self.apply_with_probability(tensor, clock=clock)
+            return self.apply_with_probability(tensor, transcript=transcript, clock=clock)
         return tensor
 
+    def units_per_ms(self):
+        """Tensor time units per millisecond: samples for the signal domain, windows otherwise."""
+        from .flags import FLAGS  # pylint: disable=import-outside-toplevel
+        return FLAGS.audio_sample_rate / 1000.0 if self.domain == 'signal' else 1.0 / FLAGS.feature_win_step
+
 
 def parse_augmentation(augmentation_spec):
     """
@@ -103,7 +108,7 @@ def parse_augmentations(augmentation_specs):
     return [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
 
 
-def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
+def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, clock=0.0):
     """
     Augments training sample tensor of a certain domain with matching augmentations of passed list.
@@ -115,6 +120,8 @@ def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
         Tensor to apply augmentations to.
     augmentations : list of augmentation class instances from util.augmentations.*.
         List of augmentations of which only the spectrogram ones will get applied to the samples.
+    transcript : SparseTensor
+        Sparse tensor of the sample's transcript, used by time-scaling augmentations (e.g. tempo) as a lower bound for the time axis.
     clock : Tensor of type float32
         Time indicator for augmentation value-ranges. Running from 0.0 (start of training) to 1.0 (end of training).
 
@@ -124,13 +130,9 @@ def apply_graph_augmentations(domain, tensor, augmentations, clock=0.0):
         The augmented spectrogram
     """
     if augmentations is not None:
-        # Warp has to come before any spectrogram masking
         for augmentation in augmentations:
-            if isinstance(augmentation, Warp):
-                tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
-        for augmentation in augmentations:
-            if isinstance(augmentation, GraphAugmentation) and not isinstance(augmentation, Warp):
-                tensor = augmentation.maybe_apply(domain, tensor, clock=clock)
+            if isinstance(augmentation, GraphAugmentation):
+                tensor = augmentation.maybe_apply(domain, tensor, transcript=transcript, clock=clock)
     return tensor
 
 
@@ -204,7 +206,7 @@ def apply_sample_augmentations(samples,
     if final_clock is not None:
         assert 0.0 <= final_clock <= 1.0
         assert clock <= final_clock
-    augmentations = list(filter(lambda aug: isinstance(aug, SampleAugmentation), augmentations))
+    augmentations = [aug for aug in augmentations if isinstance(aug, SampleAugmentation)]
     try:
         for augmentation in augmentations:
             augmentation.start(buffering=buffering)
@@ -229,8 +231,6 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
     It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
     These are then doing decompression, potential conversion and overlaying in parallel.
""" - # preventing cyclic import problems - from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel samples = samples_from_source(sample_source, buffering=buffering, labeled=False) while True: for sample in samples: @@ -238,7 +238,7 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE): class Overlay(SampleAugmentation): - """See "Overlay augmentation" in TRAINING.rst""" + """See "Overlay augmentation" in training documentation""" def __init__(self, source, p=1.0, snr=3.0, layers=1): super(Overlay, self).__init__(p) self.source = source @@ -288,7 +288,7 @@ class Overlay(SampleAugmentation): class Codec(SampleAugmentation): - """See "Codec augmentation" in TRAINING.rst""" + """See "Codec augmentation" in training documentation""" def __init__(self, p=1.0, bitrate=3200): super(Codec, self).__init__(p) self.bitrate = int_range(bitrate) @@ -300,7 +300,7 @@ class Codec(SampleAugmentation): class Reverb(SampleAugmentation): - """See "Reverb augmentation" in TRAINING.rst""" + """See "Reverb augmentation" in training documentation""" def __init__(self, p=1.0, delay=20.0, decay=10.0): super(Reverb, self).__init__(p) self.delay = float_range(delay) @@ -330,7 +330,7 @@ class Reverb(SampleAugmentation): class Resample(SampleAugmentation): - """See "Resample augmentation" in TRAINING.rst""" + """See "Resample augmentation" in training documentation""" def __init__(self, p=1.0, rate=8000): super(Resample, self).__init__(p) self.rate = int_range(rate) @@ -350,7 +350,7 @@ class Resample(SampleAugmentation): class Volume(SampleAugmentation): - """See "Volume augmentation" in TRAINING.rst""" + """See "Volume augmentation" in training documentation""" def __init__(self, p=1.0, dbfs=3.0103): super(Volume, self).__init__(p) self.target_dbfs = float_range(dbfs) @@ -361,25 +361,22 @@ class Volume(SampleAugmentation): sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs) -class PitchAndTempo(GraphAugmentation): - """See "Pitch and tempo augmentation" in TRAINING.rst""" - def __init__(self, p=1.0, tempo=1.2, pitch=(1.075, 1.075, 0.125)): - super(PitchAndTempo, self).__init__(p, domain='spectrogram') - self.tempo = float_range(tempo) +class Pitch(GraphAugmentation): + """See "Pitch augmentation" in training documentation""" + def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)): + super(Pitch, self).__init__(p, domain='spectrogram') self.pitch = float_range(pitch) - def apply(self, tensor, clock=0.0): + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel original_shape = tf.shape(tensor) pitch = tf_pick_value_from_range(self.pitch, clock=clock) - tempo = tf.math.maximum(1.0, tf_pick_value_from_range(self.tempo, clock=clock)) new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32) - new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / tempo, tf.int32) - spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, new_freq_size]) + spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size]) spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, offset_height=0, offset_width=0, - target_height=tf.shape(spectrogram_aug)[1], + target_height=original_shape[1], target_width=tf.math.minimum(original_shape[2], new_freq_size)) spectrogram_aug = tf.cond(pitch < 1, lambda: tf.image.pad_to_bounding_box(spectrogram_aug, @@ -391,82 +388,34 @@ class 
PitchAndTempo(GraphAugmentation): return spectrogram_aug[:, :, :, 0] -class Speed(GraphAugmentation): - """See "Speed augmentation" in TRAINING.rst""" - def __init__(self, p=1.0, factor=1.1): - super(Speed, self).__init__(p, domain='spectrogram') +class Tempo(GraphAugmentation): + """See "Tempo augmentation" in training documentation""" + def __init__(self, p=1.0, factor=1.1, max_time=-1): + super(Tempo, self).__init__(p, domain='spectrogram') self.factor = float_range(factor) + self.max_time = float(max_time) - def apply(self, tensor, clock=0.0): + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel factor = tf_pick_value_from_range(self.factor, clock=clock) original_shape = tf.shape(tensor) new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32) + if transcript is not None: + new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1]) + if self.max_time > 0: + new_time_size = tf.math.minimum(new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32)) spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]]) return spectrogram_aug[:, :, :, 0] -class Warp(GraphAugmentation): - """See "Warp augmentation" in TRAINING.rst""" - def __init__(self, p=1.0, shift=100.0, order=3, nbp=1, ncp=1, regularization_weight=0.0): - super(Warp, self).__init__(p, domain='spectrogram') - self.shift = float_range(shift) - self.order = int_range(order) - self.nbp = int_range(nbp) - self.ncp = int_range(ncp) - # Making this a value-range is impossible, as it would get a tensor which would downstream be used as parameter - # of a comparison inside tensorflow.contrib.image.python.ops.interpolate_spline. This is not supported. 
- self.regularization_weight = float(regularization_weight) - - def apply(self, tensor, clock=0.0): - import tensorflow as tf # pylint: disable=import-outside-toplevel - from .flags import FLAGS # pylint: disable=import-outside-toplevel - from .sparse_image_warp import sparse_image_warp # pylint: disable=import-outside-toplevel - - # reshape to fit `sparse_image_warp`'s input shape (1, time steps, freq, 1), batch_size must be 1 - expanded_spectrogram = tf.expand_dims(tensor, -1) - original_shape = tf.shape(expanded_spectrogram) - tau, freq_size = original_shape[1], original_shape[2] - seed = (clock * tf.int32.min, clock * tf.int32.max) - - shift = tf_pick_value_from_range(self.shift, clock=clock) - shift *= FLAGS.audio_sample_rate / (FLAGS.feature_win_step * 1000.0) # number of windows - shift = tf.math.minimum(tf.cast(shift, dtype=tf.int32), tf.math.floordiv(tau, 2) - 1) # to protect short audio - nbp = tf_pick_value_from_range(self.nbp, clock=clock) - ncp = tf_pick_value_from_range(self.ncp, clock=clock) - # workaround for missing stateless shuffle support - frequencies = tf.random.stateless_uniform([2 * ncp], seed, minval=1, maxval=freq_size - 2, dtype=tf.int32) - frequencies = tf.unique(tf.concat([frequencies, tf.range(1, limit=freq_size - 3)], axis=0))[0][0:ncp] - source_max = tau - shift - source_min = tf.math.minimum(source_max - ncp, shift) - # workaround for missing stateless shuffle support - src_times = tf.random.stateless_uniform([2 * ncp], seed, minval=source_min, maxval=source_max, dtype=tf.int32) - src_times = tf.unique(tf.concat([src_times, tf.range(1, limit=source_max)], axis=0))[0][0:ncp] - dst_times = src_times + tf.random.stateless_uniform([ncp], seed, minval=-shift, maxval=shift, dtype=tf.int32) - scp_locations = tf.cast([tf.transpose(tf.stack([src_times, frequencies]))], dtype=tf.float32) - dcp_locations = tf.cast([tf.transpose(tf.stack([dst_times, frequencies]))], dtype=tf.float32) - - order = tf_pick_value_from_range(self.order, clock=clock) - order = tf.math.maximum(3, order) # prevents "Input matrix is not invertible." 
exception - order = tf.cast(order, tf.float32) - - spectrogram_aug, _ = sparse_image_warp(expanded_spectrogram, - source_control_point_locations=scp_locations, - dest_control_point_locations=dcp_locations, - interpolation_order=order, - regularization_weight=self.regularization_weight, - num_boundary_points=nbp) - return tf.reshape(spectrogram_aug, shape=(1, -1, freq_size)) - - class FrequencyMask(GraphAugmentation): - """See "Frequency mask augmentation" in TRAINING.rst""" + """See "Frequency mask augmentation" in training documentation""" def __init__(self, p=1.0, n=3, size=2): super(FrequencyMask, self).__init__(p, domain='spectrogram') self.n = int_range(n) # pylint: disable=invalid-name self.size = int_range(size) - def apply(self, tensor, clock=0.0): + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel time_max = tf.shape(tensor)[1] freq_max = tf.shape(tensor)[2] @@ -486,25 +435,20 @@ class FrequencyMask(GraphAugmentation): class TimeMask(GraphAugmentation): - """See "Time mask augmentation" in TRAINING.rst""" + """See "Time mask augmentation" in training documentation""" def __init__(self, p=1.0, domain='spectrogram', n=3, size=10.0): super(TimeMask, self).__init__(p, domain=domain) self.n = int_range(n) # pylint: disable=invalid-name self.size = float_range(size) - def apply(self, tensor, clock=0.0): + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel - from .flags import FLAGS # pylint: disable=import-outside-toplevel - time_factor = FLAGS.audio_sample_rate / 1000.0 # samples per ms - if self.domain != 'signal': - time_factor /= FLAGS.feature_win_step # windows per ms - time_max = tf.shape(tensor)[0] if self.domain == 'signal' else tf.shape(tensor)[1] + time_max = tf.shape(tensor)[0 if self.domain == 'signal' else 1] n = tf_pick_value_from_range(self.n, clock=clock) def body(i, augmented): - size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * time_factor, dtype=tf.int32) + size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(), dtype=tf.int32) size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size)) - tf.print(size) seed = tf.cast(clock * tf.int32.max, tf.int32) - i t0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=time_max - size, dtype=tf.dtypes.int32) rest = time_max - t0 - size @@ -521,12 +465,12 @@ class TimeMask(GraphAugmentation): class Dropout(GraphAugmentation): - """See "Dropout augmentation" in TRAINING.rst""" + """See "Dropout augmentation" in training documentation""" def __init__(self, p=1.0, domain='spectrogram', rate=0.05): super(Dropout, self).__init__(p, domain=domain) self.rate = float_range(rate) - def apply(self, tensor, clock=0.0): + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel rate = tf_pick_value_from_range(self.rate, clock=clock) rate = tf.math.maximum(0.0, rate) @@ -539,12 +483,12 @@ class Dropout(GraphAugmentation): class Add(GraphAugmentation): - """See "Add augmentation" in TRAINING.rst""" + """See "Add augmentation" in training documentation""" def __init__(self, p=1.0, domain='features', stddev=5): super(Add, self).__init__(p, domain=domain) self.stddev = float_range(stddev) - def apply(self, tensor, clock=0.0): + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel stddev = 
tf_pick_value_from_range(self.stddev, clock=clock) seed = (clock * tf.int32.min, clock * tf.int32.max) @@ -552,12 +496,12 @@ class Add(GraphAugmentation): class Multiply(GraphAugmentation): - """See "Multiply augmentation" in TRAINING.rst""" + """See "Multiply augmentation" in training documentation""" def __init__(self, p=1.0, domain='features', stddev=5): super(Multiply, self).__init__(p, domain=domain) self.stddev = float_range(stddev) - def apply(self, tensor, clock=0.0): + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel stddev = tf_pick_value_from_range(self.stddev, clock=clock) seed = (clock * tf.int32.min, clock * tf.int32.max) diff --git a/training/deepspeech_training/util/feeding.py b/training/deepspeech_training/util/feeding.py index 9dbdae2b..4c9b681d 100644 --- a/training/deepspeech_training/util/feeding.py +++ b/training/deepspeech_training/util/feeding.py @@ -18,7 +18,7 @@ from .sample_collections import samples_from_sources from .helpers import remember_exception, MEGABYTE -def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmentations=None, sample_id=None): +def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phase=False, augmentations=None, sample_id=None): if train_phase: # We need the lambdas to make TensorFlow happy. # pylint: disable=unnecessary-lambda @@ -29,7 +29,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta name='matching_sample_rate') if train_phase and augmentations is not None: - audio = apply_graph_augmentations('signal', audio, augmentations, clock=clock) + audio = apply_graph_augmentations('signal', audio, augmentations, transcript=transcript, clock=clock) spectrogram = contrib_audio.audio_spectrogram(audio, window_size=Config.audio_window_samples, @@ -37,7 +37,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta magnitude_squared=True) if train_phase and augmentations is not None: - spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, clock=clock) + spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, transcript=transcript, clock=clock) features = contrib_audio.mfcc(spectrogram=spectrogram, sample_rate=sample_rate, @@ -46,7 +46,7 @@ def audio_to_features(audio, sample_rate, clock=0.0, train_phase=False, augmenta features = tf.reshape(features, [-1, Config.n_input]) if train_phase and augmentations is not None: - features = apply_graph_augmentations('features', features, augmentations, clock=clock) + features = apply_graph_augmentations('features', features, augmentations, transcript=transcript, clock=clock) return features, tf.shape(input=features)[0] @@ -64,13 +64,14 @@ def audiofile_to_features(wav_filename, clock=0.0, train_phase=False, augmentati def entry_to_features(sample_id, audio, sample_rate, transcript, clock, train_phase=False, augmentations=None): # https://bugs.python.org/issue32117 + sparse_transcript = tf.SparseTensor(*transcript) features, features_len = audio_to_features(audio, sample_rate, + transcript=sparse_transcript, clock=clock, train_phase=train_phase, augmentations=augmentations, sample_id=sample_id) - sparse_transcript = tf.SparseTensor(*transcript) return sample_id, features, features_len, sparse_transcript diff --git a/training/deepspeech_training/util/sparse_image_warp.py b/training/deepspeech_training/util/sparse_image_warp.py deleted file mode 100644 index 0fcdba0a..00000000 --- 
a/training/deepspeech_training/util/sparse_image_warp.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Image warping using sparse flow defined at control points.""" - -# The following code is from: https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/contrib/image/python/ops/sparse_image_warp.py -# But refactored for dynamic tensor shape compatibility -# The core idea is to replace every numpy implementation with tensorflow implementation - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -import tensorflow.compat.v1 as tfv1 -from tensorflow.compat import dimension_value -from tensorflow.contrib.image.python.ops import dense_image_warp -from tensorflow.contrib.image.python.ops import interpolate_spline - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops - -def _to_float32(value): - return tf.cast(value, tf.float32) - -def _to_int32(value): - return tf.cast(value, tf.int32) - -def _get_grid_locations(image_height, image_width): - """Wrapper for np.meshgrid.""" - tfv1.assert_type(image_height, tf.int32) - tfv1.assert_type(image_width, tf.int32) - - y_range = tf.range(image_height) - x_range = tf.range(image_width) - y_grid, x_grid = tf.meshgrid(y_range, x_range, indexing='ij') - return tf.stack((y_grid, x_grid), -1) - - -def _expand_to_minibatch(tensor, batch_size): - """Tile arbitrarily-sized np_array to include new batch dimension.""" - ndim = tf.size(tf.shape(tensor)) - ones = tf.ones((ndim,), tf.int32) - - tiles = tf.concat(([batch_size], ones), 0) - return tf.tile(tf.expand_dims(tensor, 0), tiles) - - -def _get_boundary_locations(image_height, image_width, num_points_per_edge): - """Compute evenly-spaced indices along edge of image.""" - image_height_end = _to_float32(tf.math.subtract(image_height, 1)) - image_width_end = _to_float32(tf.math.subtract(image_width, 1)) - y_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2) - x_range = tf.linspace(0.0, image_height_end, num_points_per_edge + 2) - ys, xs = tf.meshgrid(y_range, x_range, indexing='ij') - is_boundary = tf.logical_or( - tf.logical_or(tf.equal(xs, 0.0), tf.equal(xs, image_width_end)), - tf.logical_or(tf.equal(ys, 0.0), tf.equal(ys, image_height_end))) - return tf.stack([tf.boolean_mask(ys, is_boundary), tf.boolean_mask(xs, is_boundary)], axis=-1) - - -def _add_zero_flow_controls_at_boundary(control_point_locations, - control_point_flows, image_height, - image_width, boundary_points_per_edge): - """Add control points for zero-flow boundary conditions. - - Augment the set of control points with extra points on the - boundary of the image that have zero flow. 
- - Args: - control_point_locations: input control points - control_point_flows: their flows - image_height: image height - image_width: image width - boundary_points_per_edge: number of points to add in the middle of each - edge (not including the corners). - The total number of points added is - 4 + 4*(boundary_points_per_edge). - - Returns: - merged_control_point_locations: augmented set of control point locations - merged_control_point_flows: augmented set of control point flows - """ - - batch_size = dimension_value(tf.shape(control_point_locations)[0]) - - boundary_point_locations = _get_boundary_locations(image_height, image_width, - boundary_points_per_edge) - boundary_point_shape = tf.shape(boundary_point_locations) - boundary_point_flows = tf.zeros([boundary_point_shape[0], 2]) - - minbatch_locations = _expand_to_minibatch(boundary_point_locations, batch_size) - type_to_use = control_point_locations.dtype - boundary_point_locations = tf.cast(minbatch_locations, type_to_use) - - minbatch_flows = _expand_to_minibatch(boundary_point_flows, batch_size) - - boundary_point_flows = tf.cast(minbatch_flows, type_to_use) - - merged_control_point_locations = tf.concat( - [control_point_locations, boundary_point_locations], 1) - - merged_control_point_flows = tf.concat( - [control_point_flows, boundary_point_flows], 1) - - return merged_control_point_locations, merged_control_point_flows - - -def sparse_image_warp(image, - source_control_point_locations, - dest_control_point_locations, - interpolation_order=2, - regularization_weight=0.0, - num_boundary_points=0, - name='sparse_image_warp'): - """Image warping using correspondences between sparse control points. - - Apply a non-linear warp to the image, where the warp is specified by - the source and destination locations of a (potentially small) number of - control points. First, we use a polyharmonic spline - (`tf.contrib.image.interpolate_spline`) to interpolate the displacements - between the corresponding control points to a dense flow field. - Then, we warp the image using this dense flow field - (`tf.contrib.image.dense_image_warp`). - - Let t index our control points. For regularization_weight=0, we have: - warped_image[b, dest_control_point_locations[b, t, 0], - dest_control_point_locations[b, t, 1], :] = - image[b, source_control_point_locations[b, t, 0], - source_control_point_locations[b, t, 1], :]. - - For regularization_weight > 0, this condition is met approximately, since - regularized interpolation trades off smoothness of the interpolant vs. - reconstruction of the interpolant at the control points. - See `tf.contrib.image.interpolate_spline` for further documentation of the - interpolation_order and regularization_weight arguments. - - - Args: - image: `[batch, height, width, channels]` float `Tensor` - source_control_point_locations: `[batch, num_control_points, 2]` float - `Tensor` - dest_control_point_locations: `[batch, num_control_points, 2]` float - `Tensor` - interpolation_order: polynomial order used by the spline interpolation - regularization_weight: weight on smoothness regularizer in interpolation - num_boundary_points: How many zero-flow boundary points to include at - each image edge.Usage: - num_boundary_points=0: don't add zero-flow points - num_boundary_points=1: 4 corners of the image - num_boundary_points=2: 4 corners and one in the middle of each edge - (8 points total) - num_boundary_points=n: 4 corners and n-1 along each edge - name: A name for the operation (optional). 
- - Note that image and offsets can be of type tf.half, tf.float32, or - tf.float64, and do not necessarily have to be the same type. - - Returns: - warped_image: `[batch, height, width, channels]` float `Tensor` with same - type as input image. - flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense - flow field produced by the interpolation. - """ - - image = ops.convert_to_tensor(image) - source_control_point_locations = ops.convert_to_tensor( - source_control_point_locations) - dest_control_point_locations = ops.convert_to_tensor( - dest_control_point_locations) - - control_point_flows = ( - dest_control_point_locations - source_control_point_locations) - - clamp_boundaries = num_boundary_points > 0 - boundary_points_per_edge = num_boundary_points - 1 - - with ops.name_scope(name): - image_shape = tf.shape(image) - batch_size, image_height, image_width = image_shape[0], image_shape[1], image_shape[2] - - # This generates the dense locations where the interpolant - # will be evaluated. - grid_locations = _get_grid_locations(image_height, image_width) - - flattened_grid_locations = tf.reshape(grid_locations, - [tf.multiply(image_height, image_width), 2]) - - # flattened_grid_locations = constant_op.constant( - # _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) - flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size) - flattened_grid_locations = tf.cast(flattened_grid_locations, dtype=image.dtype) - - if clamp_boundaries: - (dest_control_point_locations, - control_point_flows) = _add_zero_flow_controls_at_boundary( - dest_control_point_locations, control_point_flows, image_height, - image_width, boundary_points_per_edge) - - flattened_flows = interpolate_spline.interpolate_spline( - dest_control_point_locations, control_point_flows, - flattened_grid_locations, interpolation_order, regularization_weight) - - dense_flows = array_ops.reshape(flattened_flows, - [batch_size, image_height, image_width, 2]) - - warped_image = dense_image_warp.dense_image_warp(image, dense_flows) - - return warped_image, dense_flows
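
Note on the replacement augmentations (illustrative, not part of the patch): the split Pitch and Tempo classes both come down to a bilinear resize of the [batch, time, freq] spectrogram along a single axis. The standalone sketch below reproduces that geometry so it can be checked outside the training graph; it assumes TF2 and substitutes tf.image.resize for the deprecated tf.image.resize_bilinear used in the patched code, and the function names are invented for the example.

import tensorflow as tf

def tempo_sketch(spectrogram, factor):
    # Stretch or shrink the time axis; a factor of 2.0 halves the frame count,
    # i.e. doubles the apparent playback tempo (mirrors Tempo.apply above).
    shape = tf.shape(spectrogram)
    new_time = tf.cast(tf.cast(shape[1], tf.float32) / factor, tf.int32)
    resized = tf.image.resize(tf.expand_dims(spectrogram, -1), [new_time, shape[2]])
    return resized[:, :, :, 0]

def pitch_sketch(spectrogram, pitch):
    # Scale the frequency axis, then crop (pitch > 1) or zero-pad (pitch < 1)
    # back to the original number of frequency bins (mirrors Pitch.apply above).
    shape = tf.shape(spectrogram)
    new_freq = tf.cast(tf.cast(shape[2], tf.float32) * pitch, tf.int32)
    resized = tf.image.resize(tf.expand_dims(spectrogram, -1), [shape[1], new_freq])
    cropped = tf.image.crop_to_bounding_box(resized, 0, 0, shape[1],
                                            tf.minimum(shape[2], new_freq))
    return tf.image.pad_to_bounding_box(cropped, 0, 0, shape[1], shape[2])[:, :, :, 0]

spec = tf.random.uniform([1, 100, 80])   # dummy 100-frame, 80-bin spectrogram
slower = tempo_sketch(spec, 0.8)         # ~125 time frames, same bin count
higher = pitch_sketch(spec, 1.1)         # same shape, content shifted up in pitch

During training the same effects are requested per sample with the updated flags, e.g. ``--augment pitch[p=0.1,pitch=1~0.2] --augment tempo[p=0.1,factor=1~0.5]`` as in the revised example command above.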