Use class name and review cleanup
This commit is contained in:
parent
f19ecbdd93
commit
9a708328e7
@ -90,7 +90,7 @@ def parse_augmentation(augmentation_spec):
|
|||||||
kwargs[pair[0]] = pair[1]
|
kwargs[pair[0]] = pair[1]
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unable to parse augmentation value assignment')
|
raise ValueError('Unable to parse augmentation value assignment')
|
||||||
log_info('Processed augmentation type: [{}] with parameter settings: {}'.format(augmentation_cls().name, kwargs))
|
log_info('Processed augmentation type: [{}] with parameter settings: {}'.format(augmentation_cls.__name__, kwargs))
|
||||||
return augmentation_cls(*args, **kwargs)
|
return augmentation_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -258,7 +258,6 @@ class Overlay(SampleAugmentation):
|
|||||||
self.current_sample = None
|
self.current_sample = None
|
||||||
self.queue = None
|
self.queue = None
|
||||||
self.enqueue_process = None
|
self.enqueue_process = None
|
||||||
self.name = "Overlay"
|
|
||||||
|
|
||||||
def start(self, buffering=BUFFER_SIZE):
|
def start(self, buffering=BUFFER_SIZE):
|
||||||
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())))
|
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())))
|
||||||
@ -310,7 +309,6 @@ class Codec(SampleAugmentation):
|
|||||||
def __init__(self, p=1.0, bitrate=3200):
|
def __init__(self, p=1.0, bitrate=3200):
|
||||||
super(Codec, self).__init__(p)
|
super(Codec, self).__init__(p)
|
||||||
self.bitrate = int_range(bitrate)
|
self.bitrate = int_range(bitrate)
|
||||||
self.name = "Codec"
|
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
bitrate = pick_value_from_range(self.bitrate, clock=clock)
|
bitrate = pick_value_from_range(self.bitrate, clock=clock)
|
||||||
@ -324,7 +322,6 @@ class Reverb(SampleAugmentation):
|
|||||||
super(Reverb, self).__init__(p)
|
super(Reverb, self).__init__(p)
|
||||||
self.delay = float_range(delay)
|
self.delay = float_range(delay)
|
||||||
self.decay = float_range(decay)
|
self.decay = float_range(decay)
|
||||||
self.name = "Reverb"
|
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||||
@ -354,7 +351,6 @@ class Resample(SampleAugmentation):
|
|||||||
def __init__(self, p=1.0, rate=8000):
|
def __init__(self, p=1.0, rate=8000):
|
||||||
super(Resample, self).__init__(p)
|
super(Resample, self).__init__(p)
|
||||||
self.rate = int_range(rate)
|
self.rate = int_range(rate)
|
||||||
self.name = "Resample"
|
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||||
@ -368,7 +364,6 @@ class NormalizeSampleRate(SampleAugmentation):
|
|||||||
def __init__(self, rate):
|
def __init__(self, rate):
|
||||||
super().__init__(p=1.0)
|
super().__init__(p=1.0)
|
||||||
self.rate = rate
|
self.rate = rate
|
||||||
self.name = "Normalize Sample Rate"
|
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
if sample.audio_format.rate == self.rate:
|
if sample.audio_format.rate == self.rate:
|
||||||
@ -384,7 +379,6 @@ class Volume(SampleAugmentation):
|
|||||||
def __init__(self, p=1.0, dbfs=3.0103):
|
def __init__(self, p=1.0, dbfs=3.0103):
|
||||||
super(Volume, self).__init__(p)
|
super(Volume, self).__init__(p)
|
||||||
self.target_dbfs = float_range(dbfs)
|
self.target_dbfs = float_range(dbfs)
|
||||||
self.name = "Volume"
|
|
||||||
|
|
||||||
def apply(self, sample, clock=0.0):
|
def apply(self, sample, clock=0.0):
|
||||||
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
|
||||||
@ -397,7 +391,6 @@ class Pitch(GraphAugmentation):
|
|||||||
def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)):
|
def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)):
|
||||||
super(Pitch, self).__init__(p, domain='spectrogram')
|
super(Pitch, self).__init__(p, domain='spectrogram')
|
||||||
self.pitch = float_range(pitch)
|
self.pitch = float_range(pitch)
|
||||||
self.name = "Pitch"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
@ -426,7 +419,6 @@ class Tempo(GraphAugmentation):
|
|||||||
super(Tempo, self).__init__(p, domain='spectrogram')
|
super(Tempo, self).__init__(p, domain='spectrogram')
|
||||||
self.factor = float_range(factor)
|
self.factor = float_range(factor)
|
||||||
self.max_time = float(max_time)
|
self.max_time = float(max_time)
|
||||||
self.name = "Tempo"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
@ -449,7 +441,6 @@ class Warp(GraphAugmentation):
|
|||||||
self.num_f = int_range(nf)
|
self.num_f = int_range(nf)
|
||||||
self.warp_t = float_range(wt)
|
self.warp_t = float_range(wt)
|
||||||
self.warp_f = float_range(wf)
|
self.warp_f = float_range(wf)
|
||||||
self.name = "Warp"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
@ -477,7 +468,6 @@ class FrequencyMask(GraphAugmentation):
|
|||||||
super(FrequencyMask, self).__init__(p, domain='spectrogram')
|
super(FrequencyMask, self).__init__(p, domain='spectrogram')
|
||||||
self.n = int_range(n) # pylint: disable=invalid-name
|
self.n = int_range(n) # pylint: disable=invalid-name
|
||||||
self.size = int_range(size)
|
self.size = int_range(size)
|
||||||
self.name = "Frequency Mask"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
@ -504,7 +494,6 @@ class TimeMask(GraphAugmentation):
|
|||||||
super(TimeMask, self).__init__(p, domain=domain)
|
super(TimeMask, self).__init__(p, domain=domain)
|
||||||
self.n = int_range(n) # pylint: disable=invalid-name
|
self.n = int_range(n) # pylint: disable=invalid-name
|
||||||
self.size = float_range(size)
|
self.size = float_range(size)
|
||||||
self.name = "Time Mask"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
@ -534,7 +523,6 @@ class Dropout(GraphAugmentation):
|
|||||||
def __init__(self, p=1.0, domain='spectrogram', rate=0.05):
|
def __init__(self, p=1.0, domain='spectrogram', rate=0.05):
|
||||||
super(Dropout, self).__init__(p, domain=domain)
|
super(Dropout, self).__init__(p, domain=domain)
|
||||||
self.rate = float_range(rate)
|
self.rate = float_range(rate)
|
||||||
self.name = "Dropout"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
@ -553,7 +541,6 @@ class Add(GraphAugmentation):
|
|||||||
def __init__(self, p=1.0, domain='features', stddev=5):
|
def __init__(self, p=1.0, domain='features', stddev=5):
|
||||||
super(Add, self).__init__(p, domain=domain)
|
super(Add, self).__init__(p, domain=domain)
|
||||||
self.stddev = float_range(stddev)
|
self.stddev = float_range(stddev)
|
||||||
self.name = "Add"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
@ -567,7 +554,6 @@ class Multiply(GraphAugmentation):
|
|||||||
def __init__(self, p=1.0, domain='features', stddev=5):
|
def __init__(self, p=1.0, domain='features', stddev=5):
|
||||||
super(Multiply, self).__init__(p, domain=domain)
|
super(Multiply, self).__init__(p, domain=domain)
|
||||||
self.stddev = float_range(stddev)
|
self.stddev = float_range(stddev)
|
||||||
self.name = "Multiply"
|
|
||||||
|
|
||||||
def apply(self, tensor, transcript=None, clock=0.0):
|
def apply(self, tensor, transcript=None, clock=0.0):
|
||||||
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
||||||
|
@ -167,39 +167,37 @@ def get_value_range(value, target_type):
|
|||||||
This function converts all possible supplied values for augmentation
|
This function converts all possible supplied values for augmentation
|
||||||
into the [start,end,r] ValueRange type. The expected inputs are of the form:
|
into the [start,end,r] ValueRange type. The expected inputs are of the form:
|
||||||
|
|
||||||
<value>
|
<number>
|
||||||
<value>~<r>
|
<number>~<number>
|
||||||
<value>:<value>~<r>
|
<number>:<number>~<number>
|
||||||
|
|
||||||
Any "missing" values are filled so that ValueRange always includes [value,value,r].
|
Any "missing" values are filled so that ValueRange always includes [start,end,r].
|
||||||
"""
|
"""
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if '~' in value:
|
if '~' in value:
|
||||||
parts = value.split('~')
|
parts = value.split('~')
|
||||||
if len(parts) == 2:
|
if len(parts) != 2:
|
||||||
value = parts[0]
|
|
||||||
r = target_type(parts[1])
|
|
||||||
elif len(parts) > 2:
|
|
||||||
raise ValueError('Cannot parse value range')
|
raise ValueError('Cannot parse value range')
|
||||||
|
value = parts[0]
|
||||||
|
r = parts[1]
|
||||||
else:
|
else:
|
||||||
# if no <r> supplied, use 0
|
r = 0 # if no <r> supplied, use 0
|
||||||
r = target_type(0)
|
|
||||||
parts = value.split(':')
|
parts = value.split(':')
|
||||||
if len(parts) > 2:
|
if len(parts) == 1:
|
||||||
|
parts.append(parts[0]) # only one <value> given, so double it
|
||||||
|
if len(parts) != 2:
|
||||||
raise ValueError('Cannot parse value range')
|
raise ValueError('Cannot parse value range')
|
||||||
elif len(parts) == 1:
|
return ValueRange(target_type(parts[0]), target_type(parts[1]), target_type(r))
|
||||||
# only one "<value>" supplied
|
|
||||||
parts.append(parts[0])
|
|
||||||
return ValueRange(target_type(parts[0]), target_type(parts[1]), r)
|
|
||||||
if isinstance(value, tuple):
|
if isinstance(value, tuple):
|
||||||
if len(value) == 2:
|
if len(value) == 2:
|
||||||
return ValueRange(target_type(value[0]), target_type(value[1]), 0)
|
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(0))
|
||||||
elif len(value) == 3:
|
if len(value) == 3:
|
||||||
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2]))
|
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2]))
|
||||||
else:
|
else:
|
||||||
raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
|
raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
|
||||||
elif isinstance(value, int) or isinstance(value, float):
|
if isinstance(value, int) or isinstance(value, float):
|
||||||
return ValueRange(target_type(value), target_type(value), 0)
|
return ValueRange(target_type(value), target_type(value), target_type(0))
|
||||||
|
raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
|
||||||
|
|
||||||
|
|
||||||
def int_range(value):
|
def int_range(value):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user