Break dependency cycle between augmentation and config

Reuben Morais 2021-05-19 13:28:43 +02:00
parent d83630fef4
commit eab6d3f5d9
4 changed files with 6 additions and 10 deletions
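
Background on the fix (a minimal sketch, not code from this commit; the module layout and values below are hypothetical stand-ins): a circular import arises when two modules need each other at load time. Here, the config module imports parse_augmentations from the augmentation module at the top level, so the augmentation module cannot import config at the top level in return. Moving the remaining config import into a function body defers it to call time, when both modules are fully initialized:

# config.py -- hypothetical stand-in for the training config module
from augmentations import parse_augmentations  # import-time edge: config -> augmentations

class Config:
    audio_sample_rate = 16000  # assumed example values
    feature_win_step = 20

# augmentations.py -- hypothetical stand-in
# A top-level `from config import Config` here would close the cycle:
# importing config starts loading this module, which would then re-enter
# the half-initialized config module before Config exists -> ImportError.

def parse_augmentations(spec):
    return spec.split(",") if spec else []

def units_per_ms(domain):
    # Deferred import: runs at call time, after config has finished loading.
    from config import Config  # pylint: disable=import-outside-toplevel
    return (
        Config.audio_sample_rate / 1000.0
        if domain == "signal"
        else 1.0 / Config.feature_win_step
    )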

@@ -33,6 +33,7 @@ AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
 def build_data_set():
     audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
     augmentations = parse_augmentations(CLI_ARGS.augment)
+    print(f"Parsed augmentations from flags: {augmentations}")
     if any(not isinstance(a, SampleAugmentation) for a in augmentations):
         print(
             "Warning: Some of the specified augmentations will not get applied, as this tool only supports "

@@ -52,6 +52,7 @@ def get_samples_in_play_order():
 
 def play_collection():
     augmentations = parse_augmentations(CLI_ARGS.augment)
+    print(f"Parsed augmentations from flags: {augmentations}")
     if any(not isinstance(a, SampleAugmentation) for a in augmentations):
         print("Warning: Some of the augmentations cannot be simulated by this command.")
     samples = get_samples_in_play_order()

@@ -15,7 +15,6 @@ from .audio import (
     max_dbfs,
     normalize_audio,
 )
-from .config import log_info
 from .helpers import (
     MEGABYTE,
     LimitingPool,
@@ -76,10 +75,10 @@ class GraphAugmentation(Augmentation):
         return tensor
 
     def units_per_ms(self):
-        from .flags import FLAGS  # pylint: disable=import-outside-toplevel
+        from .config import Config  # pylint: disable=import-outside-toplevel
         return (
-            FLAGS.audio_sample_rate / 1000.0
+            Config.audio_sample_rate / 1000.0
             if self.domain == "signal"
             else 1.0 / Config.feature_win_step
         )
 
@@ -123,11 +122,6 @@ def parse_augmentation(augmentation_spec):
             kwargs[pair[0]] = pair[1]
         else:
             raise ValueError("Unable to parse augmentation value assignment")
-    log_info(
-        "Processed augmentation type: [{}] with parameter settings: {}".format(
-            augmentation_cls.__name__, kwargs
-        )
-    )
     return augmentation_cls(*args, **kwargs)

@@ -555,7 +555,7 @@ def initialize_globals():
     c.augmentations = parse_augmentations(c.augment)
     print(f"Parsed augmentations from flags: {c.augmentations}")
     if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
-        log_warn(
+        print(
             "Due to current feature-cache settings the exact same sample augmentations of the first "
             "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
             "You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
@@ -568,7 +568,7 @@ def initialize_globals():
 
     # Caching
     if c.cache_for_epochs == 1:
-        log_warn(
+        print(
             "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
         )
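
Note on the logging swap (editorial, not part of the diff): log_info and log_warn are helpers defined in the config module, as the removed "from .config import log_info" above shows. Calling them from augmentations.py kept an import-time edge back to config alive, so the commit replaces them with plain print there and, presumably for consistency, inside config.py itself. After this change the only remaining dependency from the augmentation module on config is the deferred, call-time import inside units_per_ms.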