Break dependency cycle between augmentation and config
This commit is contained in:
parent
d83630fef4
commit
eab6d3f5d9
|
@ -33,6 +33,7 @@ AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
|
||||||
def build_data_set():
|
def build_data_set():
|
||||||
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
|
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
|
||||||
augmentations = parse_augmentations(CLI_ARGS.augment)
|
augmentations = parse_augmentations(CLI_ARGS.augment)
|
||||||
|
print(f"Parsed augmentations from flags: {augmentations}")
|
||||||
if any(not isinstance(a, SampleAugmentation) for a in augmentations):
|
if any(not isinstance(a, SampleAugmentation) for a in augmentations):
|
||||||
print(
|
print(
|
||||||
"Warning: Some of the specified augmentations will not get applied, as this tool only supports "
|
"Warning: Some of the specified augmentations will not get applied, as this tool only supports "
|
||||||
|
|
|
@ -52,6 +52,7 @@ def get_samples_in_play_order():
|
||||||
|
|
||||||
def play_collection():
|
def play_collection():
|
||||||
augmentations = parse_augmentations(CLI_ARGS.augment)
|
augmentations = parse_augmentations(CLI_ARGS.augment)
|
||||||
|
print(f"Parsed augmentations from flags: {augmentations}")
|
||||||
if any(not isinstance(a, SampleAugmentation) for a in augmentations):
|
if any(not isinstance(a, SampleAugmentation) for a in augmentations):
|
||||||
print("Warning: Some of the augmentations cannot be simulated by this command.")
|
print("Warning: Some of the augmentations cannot be simulated by this command.")
|
||||||
samples = get_samples_in_play_order()
|
samples = get_samples_in_play_order()
|
||||||
|
|
|
@ -15,7 +15,6 @@ from .audio import (
|
||||||
max_dbfs,
|
max_dbfs,
|
||||||
normalize_audio,
|
normalize_audio,
|
||||||
)
|
)
|
||||||
from .config import log_info
|
|
||||||
from .helpers import (
|
from .helpers import (
|
||||||
MEGABYTE,
|
MEGABYTE,
|
||||||
LimitingPool,
|
LimitingPool,
|
||||||
|
@ -76,10 +75,10 @@ class GraphAugmentation(Augmentation):
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def units_per_ms(self):
|
def units_per_ms(self):
|
||||||
from .flags import FLAGS # pylint: disable=import-outside-toplevel
|
from .config import Config # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
return (
|
return (
|
||||||
FLAGS.audio_sample_rate / 1000.0
|
Config.audio_sample_rate / 1000.0
|
||||||
if self.domain == "signal"
|
if self.domain == "signal"
|
||||||
else 1.0 / Config.feature_win_step
|
else 1.0 / Config.feature_win_step
|
||||||
)
|
)
|
||||||
|
@ -123,11 +122,6 @@ def parse_augmentation(augmentation_spec):
|
||||||
kwargs[pair[0]] = pair[1]
|
kwargs[pair[0]] = pair[1]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unable to parse augmentation value assignment")
|
raise ValueError("Unable to parse augmentation value assignment")
|
||||||
log_info(
|
|
||||||
"Processed augmentation type: [{}] with parameter settings: {}".format(
|
|
||||||
augmentation_cls.__name__, kwargs
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return augmentation_cls(*args, **kwargs)
|
return augmentation_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -555,7 +555,7 @@ def initialize_globals():
|
||||||
c.augmentations = parse_augmentations(c.augment)
|
c.augmentations = parse_augmentations(c.augment)
|
||||||
print(f"Parsed augmentations from flags: {c.augmentations}")
|
print(f"Parsed augmentations from flags: {c.augmentations}")
|
||||||
if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
|
if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
|
||||||
log_warn(
|
print(
|
||||||
"Due to current feature-cache settings the exact same sample augmentations of the first "
|
"Due to current feature-cache settings the exact same sample augmentations of the first "
|
||||||
"epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
|
"epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
|
||||||
"You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
|
"You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
|
||||||
|
@ -568,7 +568,7 @@ def initialize_globals():
|
||||||
|
|
||||||
# Caching
|
# Caching
|
||||||
if c.cache_for_epochs == 1:
|
if c.cache_for_epochs == 1:
|
||||||
log_warn(
|
print(
|
||||||
"--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
|
"--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue