From d83630fef4ba0d0da0e10fb1693e9a867b49b008 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 18 May 2021 16:55:50 +0200 Subject: [PATCH] Print fully parsed augmentation config --- .../coqui_stt_training/util/augmentations.py | 54 +++++++++++++++++-- training/coqui_stt_training/util/config.py | 1 + 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/training/coqui_stt_training/util/augmentations.py b/training/coqui_stt_training/util/augmentations.py index 55e03cc2..0206b1aa 100644 --- a/training/coqui_stt_training/util/augmentations.py +++ b/training/coqui_stt_training/util/augmentations.py @@ -309,6 +309,9 @@ class Overlay(SampleAugmentation): self.queue = None self.enqueue_process = None + def __repr__(self): + return f"Overlay(source={self.source!r}, p={self.probability!r}, snr={self.snr!r}, layers={self.layers!r})" + def start(self, buffering=BUFFER_SIZE): self.queue = Queue( max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())) @@ -369,6 +372,9 @@ class Codec(SampleAugmentation): super(Codec, self).__init__(p) self.bitrate = int_range(bitrate) + def __repr__(self): + return f"Codec(p={self.probability!r}, bitrate={self.bitrate!r})" + def apply(self, sample, clock=0.0): bitrate = pick_value_from_range(self.bitrate, clock=clock) sample.change_audio_type( @@ -387,6 +393,9 @@ class Reverb(SampleAugmentation): self.delay = float_range(delay) self.decay = float_range(decay) + def __repr__(self): + return f"Reverb(p={self.probability!r}, delay={self.delay!r}, decay={self.decay!r})" + def apply(self, sample, clock=0.0): sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) audio = np.array(sample.audio, dtype=np.float64) @@ -421,6 +430,9 @@ class Resample(SampleAugmentation): super(Resample, self).__init__(p) self.rate = int_range(rate) + def __repr__(self): + return f"Resample(p={self.probability!r}, rate={self.rate!r})" + def apply(self, sample, clock=0.0): sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) rate = pick_value_from_range(self.rate, clock=clock) @@ -438,6 +450,9 @@ class NormalizeSampleRate(SampleAugmentation): super().__init__(p=1.0) self.rate = rate + def __repr__(self): + return f"NormalizeSampleRate(rate={self.rate!r})" + def apply(self, sample, clock=0.0): if sample.audio_format.rate == self.rate: return @@ -460,6 +475,9 @@ class Volume(SampleAugmentation): super(Volume, self).__init__(p) self.target_dbfs = float_range(dbfs) + def __repr__(self): + return f"Volume(p={self.probability!r}, dbfs={self.target_dbfs!r})" + def apply(self, sample, clock=0.0): sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock) @@ -473,6 +491,9 @@ class Pitch(GraphAugmentation): super(Pitch, self).__init__(p, domain="spectrogram") self.pitch = float_range(pitch) + def __repr__(self): + return f"Pitch(p={self.probability!r}, pitch={self.pitch!r})" + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel @@ -513,6 +534,9 @@ class Tempo(GraphAugmentation): self.factor = float_range(factor) self.max_time = float(max_time) + def __repr__(self): + return f"Tempo(p={self.probability!r}, factor={self.factor!r}, max_time={self.max_time!r})" + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel @@ -536,12 +560,15 @@ class Tempo(GraphAugmentation): class Warp(GraphAugmentation): """See "Warp augmentation" in training documentation""" - def __init__(self, p=1.0, nt=1, nf=1, wt=0.1, wf=0.0): + def __init__(self, p=1.0, num_t=1, num_f=1, warp_t=0.1, warp_f=0.0): super(Warp, self).__init__(p, domain="spectrogram") - self.num_t = int_range(nt) - self.num_f = int_range(nf) - self.warp_t = float_range(wt) - self.warp_f = float_range(wf) + self.num_t = int_range(num_t) + self.num_f = int_range(num_f) + self.warp_t = float_range(warp_t) + self.warp_f = float_range(warp_f) + + def __repr__(self): + return f"Warp(p={self.probability!r}, num_t={self.num_t!r}, num_f={self.num_f!r}, warp_t={self.warp_t!r}, warp_f={self.warp_f!r})" def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel @@ -588,6 +615,11 @@ class FrequencyMask(GraphAugmentation): self.n = int_range(n) # pylint: disable=invalid-name self.size = int_range(size) + def __repr__(self): + return ( + f"FrequencyMask(p={self.probability!r}, n={self.n!r}, size={self.size!r})" + ) + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel @@ -627,6 +659,9 @@ class TimeMask(GraphAugmentation): self.n = int_range(n) # pylint: disable=invalid-name self.size = float_range(size) + def __repr__(self): + return f"TimeMask(p={self.probability!r}, domain={self.domain!r}, n={self.n!r}, size={self.size!r})" + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel @@ -678,6 +713,9 @@ class Dropout(GraphAugmentation): super(Dropout, self).__init__(p, domain=domain) self.rate = float_range(rate) + def __repr__(self): + return f"Dropout(p={self.probability!r}, domain={self.domain!r}, rate={self.rate!r})" + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel @@ -700,6 +738,9 @@ class Add(GraphAugmentation): super(Add, self).__init__(p, domain=domain) self.stddev = float_range(stddev) + def __repr__(self): + return f"Add(p={self.probability!r}, domain={self.domain!r}, stddev={self.stddev!r})" + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel @@ -717,6 +758,9 @@ class Multiply(GraphAugmentation): super(Multiply, self).__init__(p, domain=domain) self.stddev = float_range(stddev) + def __repr__(self): + return f"Multiply(p={self.probability!r}, domain={self.domain!r}, stddev={self.stddev!r})" + def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index b6b20291..a873f26f 100755 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -553,6 +553,7 @@ def initialize_globals(): # Augmentations c.augmentations = parse_augmentations(c.augment) + print(f"Parsed augmentations from flags: {c.augmentations}") if c.augmentations and c.feature_cache and c.cache_for_epochs == 0: log_warn( "Due to current feature-cache settings the exact same sample augmentations of the first "