Update augmentation code to TF2

This commit is contained in:
Reuben Morais 2021-01-02 15:05:07 +00:00
parent 85b9f0fd3d
commit 159697738c
2 changed files with 9 additions and 5 deletions

View File

@ -67,6 +67,7 @@ def main():
'llvmlite == 0.31.0', # for numba==0.47.0
'librosa',
'soundfile',
'tensorflow_addons >= 0.12.0',
]
decoder_pypi_dep = [

View File

@ -1,4 +1,3 @@
import os
import re
import math
@ -388,10 +387,11 @@ class Pitch(GraphAugmentation):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
import tensorflow.compat.v1 as tfv1 # pylint: disable=import-outside-toplevel
original_shape = tf.shape(tensor)
pitch = tf_pick_value_from_range(self.pitch, clock=clock)
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32)
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size])
spectrogram_aug = tfv1.image.resize_bilinear(tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size])
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug,
offset_height=0,
offset_width=0,
@ -416,6 +416,7 @@ class Tempo(GraphAugmentation):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
import tensorflow.compat.v1 as tfv1 # pylint: disable=import-outside-toplevel
factor = tf_pick_value_from_range(self.factor, clock=clock)
original_shape = tf.shape(tensor)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32)
@ -423,7 +424,7 @@ class Tempo(GraphAugmentation):
new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1])
if self.max_time > 0:
new_time_size = tf.math.minimum(new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32))
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]])
spectrogram_aug = tfv1.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]])
return spectrogram_aug[:, :, :, 0]
@ -438,6 +439,8 @@ class Warp(GraphAugmentation):
def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel
import tensorflow.compat.v1 as tfv1 # pylint: disable=import-outside-toplevel
import tensorflow_addons as tfa # pylint: disable=import-outside-toplevel
original_shape = tf.shape(tensor)
size_t, size_f = original_shape[1], original_shape[2]
seed = (clock * tf.int32.min, clock * tf.int32.max)
@ -451,8 +454,8 @@ class Warp(GraphAugmentation):
return tf.pad(f, tf.constant([[1, 1], [1, 1]]), 'CONSTANT') # zero flow at all edges
flows = tf.stack([get_flows(num_t, size_t, self.warp_t), get_flows(num_f, size_f, self.warp_f)], axis=2)
flows = tf.image.resize_bicubic(tf.expand_dims(flows, 0), [size_t, size_f])
spectrogram_aug = tf.contrib.image.dense_image_warp(tf.expand_dims(tensor, -1), flows)
flows = tfv1.image.resize_bicubic(tf.expand_dims(flows, 0), [size_t, size_f])
spectrogram_aug = tfa.image.dense_image_warp(tf.expand_dims(tensor, -1), flows)
return tf.reshape(spectrogram_aug, shape=(1, -1, size_f))