Print fully parsed augmentation config
This commit is contained in:
parent
5114362f6d
commit
d83630fef4
|
@ -309,6 +309,9 @@ class Overlay(SampleAugmentation):
|
||||||
self.queue = None
|
self.queue = None
|
||||||
self.enqueue_process = None
|
self.enqueue_process = None
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Overlay(source={self.source!r}, p={self.probability!r}, snr={self.snr!r}, layers={self.layers!r})"
|
||||||
|
|
||||||
def start(self, buffering=BUFFER_SIZE):
|
def start(self, buffering=BUFFER_SIZE):
|
||||||
self.queue = Queue(
|
self.queue = Queue(
|
||||||
max(1, math.floor(self.probability * self.layers[1] * os.cpu_count()))
|
max(1, math.floor(self.probability * self.layers[1] * os.cpu_count()))
|
||||||
|
@ -369,6 +372,9 @@ class Codec(SampleAugmentation):
|
||||||
super(Codec, self).__init__(p)
|
super(Codec, self).__init__(p)
|
||||||
self.bitrate = int_range(bitrate)
|
self.bitrate = int_range(bitrate)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Codec(p={self.probability!r}, bitrate={self.bitrate!r})"
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
bitrate = pick_value_from_range(self.bitrate, clock=clock)
|
bitrate = pick_value_from_range(self.bitrate, clock=clock)
|
||||||
sample.change_audio_type(
|
sample.change_audio_type(
|
||||||
|
@ -387,6 +393,9 @@ class Reverb(SampleAugmentation):
|
||||||
self.delay = float_range(delay)
|
self.delay = float_range(delay)
|
||||||
self.decay = float_range(decay)
|
self.decay = float_range(decay)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Reverb(p={self.probability!r}, delay={self.delay!r}, decay={self.decay!r})"
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||||
audio = np.array(sample.audio, dtype=np.float64)
|
audio = np.array(sample.audio, dtype=np.float64)
|
||||||
|
@ -421,6 +430,9 @@ class Resample(SampleAugmentation):
|
||||||
super(Resample, self).__init__(p)
|
super(Resample, self).__init__(p)
|
||||||
self.rate = int_range(rate)
|
self.rate = int_range(rate)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Resample(p={self.probability!r}, rate={self.rate!r})"
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||||
rate = pick_value_from_range(self.rate, clock=clock)
|
rate = pick_value_from_range(self.rate, clock=clock)
|
||||||
|
@ -438,6 +450,9 @@ class NormalizeSampleRate(SampleAugmentation):
|
||||||
super().__init__(p=1.0)
|
super().__init__(p=1.0)
|
||||||
self.rate = rate
|
self.rate = rate
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"NormalizeSampleRate(rate={self.rate!r})"
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
if sample.audio_format.rate == self.rate:
|
if sample.audio_format.rate == self.rate:
|
||||||
return
|
return
|
||||||
|
@ -460,6 +475,9 @@ class Volume(SampleAugmentation):
|
||||||
super(Volume, self).__init__(p)
|
super(Volume, self).__init__(p)
|
||||||
self.target_dbfs = float_range(dbfs)
|
self.target_dbfs = float_range(dbfs)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Volume(p={self.probability!r}, dbfs={self.target_dbfs!r})"
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||||
target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
|
target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
|
||||||
|
@ -473,6 +491,9 @@ class Pitch(GraphAugmentation):
|
||||||
super(Pitch, self).__init__(p, domain="spectrogram")
|
super(Pitch, self).__init__(p, domain="spectrogram")
|
||||||
self.pitch = float_range(pitch)
|
self.pitch = float_range(pitch)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Pitch(p={self.probability!r}, pitch={self.pitch!r})"
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -513,6 +534,9 @@ class Tempo(GraphAugmentation):
|
||||||
self.factor = float_range(factor)
|
self.factor = float_range(factor)
|
||||||
self.max_time = float(max_time)
|
self.max_time = float(max_time)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Tempo(p={self.probability!r}, factor={self.factor!r}, max_time={self.max_time!r})"
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -536,12 +560,15 @@ class Tempo(GraphAugmentation):
|
||||||
class Warp(GraphAugmentation):
|
class Warp(GraphAugmentation):
|
||||||
"""See "Warp augmentation" in training documentation"""
|
"""See "Warp augmentation" in training documentation"""
|
||||||
|
|
||||||
def __init__(self, p=1.0, nt=1, nf=1, wt=0.1, wf=0.0):
|
def __init__(self, p=1.0, num_t=1, num_f=1, warp_t=0.1, warp_f=0.0):
|
||||||
super(Warp, self).__init__(p, domain="spectrogram")
|
super(Warp, self).__init__(p, domain="spectrogram")
|
||||||
self.num_t = int_range(nt)
|
self.num_t = int_range(num_t)
|
||||||
self.num_f = int_range(nf)
|
self.num_f = int_range(num_f)
|
||||||
self.warp_t = float_range(wt)
|
self.warp_t = float_range(warp_t)
|
||||||
self.warp_f = float_range(wf)
|
self.warp_f = float_range(warp_f)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Warp(p={self.probability!r}, num_t={self.num_t!r}, num_f={self.num_f!r}, warp_t={self.warp_t!r}, warp_f={self.warp_f!r})"
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
@ -588,6 +615,11 @@ class FrequencyMask(GraphAugmentation):
|
||||||
self.n = int_range(n) # pylint: disable=invalid-name
|
self.n = int_range(n) # pylint: disable=invalid-name
|
||||||
self.size = int_range(size)
|
self.size = int_range(size)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f"FrequencyMask(p={self.probability!r}, n={self.n!r}, size={self.size!r})"
|
||||||
|
)
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -627,6 +659,9 @@ class TimeMask(GraphAugmentation):
|
||||||
self.n = int_range(n) # pylint: disable=invalid-name
|
self.n = int_range(n) # pylint: disable=invalid-name
|
||||||
self.size = float_range(size)
|
self.size = float_range(size)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"TimeMask(p={self.probability!r}, domain={self.domain!r}, n={self.n!r}, size={self.size!r})"
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -678,6 +713,9 @@ class Dropout(GraphAugmentation):
|
||||||
super(Dropout, self).__init__(p, domain=domain)
|
super(Dropout, self).__init__(p, domain=domain)
|
||||||
self.rate = float_range(rate)
|
self.rate = float_range(rate)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Dropout(p={self.probability!r}, domain={self.domain!r}, rate={self.rate!r})"
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -700,6 +738,9 @@ class Add(GraphAugmentation):
|
||||||
super(Add, self).__init__(p, domain=domain)
|
super(Add, self).__init__(p, domain=domain)
|
||||||
self.stddev = float_range(stddev)
|
self.stddev = float_range(stddev)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Add(p={self.probability!r}, domain={self.domain!r}, stddev={self.stddev!r})"
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -717,6 +758,9 @@ class Multiply(GraphAugmentation):
|
||||||
super(Multiply, self).__init__(p, domain=domain)
|
super(Multiply, self).__init__(p, domain=domain)
|
||||||
self.stddev = float_range(stddev)
|
self.stddev = float_range(stddev)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Multiply(p={self.probability!r}, domain={self.domain!r}, stddev={self.stddev!r})"
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
|
|
@ -553,6 +553,7 @@ def initialize_globals():
|
||||||
|
|
||||||
# Augmentations
|
# Augmentations
|
||||||
c.augmentations = parse_augmentations(c.augment)
|
c.augmentations = parse_augmentations(c.augment)
|
||||||
|
print(f"Parsed augmentations from flags: {c.augmentations}")
|
||||||
if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
|
if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
|
||||||
log_warn(
|
log_warn(
|
||||||
"Due to current feature-cache settings the exact same sample augmentations of the first "
|
"Due to current feature-cache settings the exact same sample augmentations of the first "
|
||||||
|
|
Loading…
Reference in New Issue