From a5303ccca6973656b5fe464a22d463e0845f03b1 Mon Sep 17 00:00:00 2001 From: Tilman Kamp <5991088+tilmankamp@users.noreply.github.com> Date: Thu, 14 May 2020 16:50:18 +0200 Subject: [PATCH] Renamed prepare_samples to augment_samples --- bin/play.py | 4 ++-- training/deepspeech_training/util/feeding.py | 4 ++-- .../util/sample_collections.py | 20 +++++++++---------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/bin/play.py b/bin/play.py index 8f985da9..7d19a790 100755 --- a/bin/play.py +++ b/bin/play.py @@ -10,7 +10,7 @@ import random import argparse from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV -from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, prepare_samples +from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, augment_samples def get_samples_in_play_order(): @@ -39,7 +39,7 @@ def get_samples_in_play_order(): def play_collection(): samples = get_samples_in_play_order() - samples = prepare_samples(samples, + samples = augment_samples(samples, audio_type=AUDIO_TYPE_PCM, augmentation_specs=CLI_ARGS.augment, process_ahead=0, diff --git a/training/deepspeech_training/util/feeding.py b/training/deepspeech_training/util/feeding.py index 1044c8f6..5cbd4833 100644 --- a/training/deepspeech_training/util/feeding.py +++ b/training/deepspeech_training/util/feeding.py @@ -13,7 +13,7 @@ from .text import text_to_char_array from .flags import FLAGS from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT -from .sample_collections import samples_from_sources, prepare_samples +from .sample_collections import samples_from_sources, augment_samples from .helpers import remember_exception, MEGABYTE @@ -119,7 +119,7 @@ def create_dataset(sources, buffering=1 * MEGABYTE): def generate_values(): samples = samples_from_sources(sources, buffering=buffering, labeled=True) - samples = prepare_samples(samples, + samples = augment_samples(samples, repetitions=repetitions, augmentation_specs=augmentation_specs, buffering=buffering, diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index 97dd0b6f..da9bcdf9 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -429,16 +429,16 @@ class PreparationContext: self.augmentations = augmentations -PREPARATION_CONTEXT = None +AUGMENTATION_CONTEXT = None -def _init_preparation_worker(preparation_context): - global PREPARATION_CONTEXT # pylint: disable=global-statement - PREPARATION_CONTEXT = preparation_context +def _init_augmentation_worker(preparation_context): + global AUGMENTATION_CONTEXT # pylint: disable=global-statement + AUGMENTATION_CONTEXT = preparation_context -def _prepare_sample(timed_sample, context=None): - context = PREPARATION_CONTEXT if context is None else context +def _augment_sample(timed_sample, context=None): + context = AUGMENTATION_CONTEXT if context is None else context sample, clock = timed_sample for augmentation in context.augmentations: if random.random() < augmentation.probability: @@ -447,7 +447,7 @@ def _prepare_sample(timed_sample, context=None): return sample -def prepare_samples(samples, +def augment_samples(samples, audio_type=AUDIO_TYPE_NP, augmentation_specs=None, buffering=BUFFER_SIZE, @@ -497,12 +497,12 @@ def prepare_samples(samples, context = PreparationContext(audio_type, augmentations) if process_ahead == 0: for timed_sample in timed_samples(): - yield _prepare_sample(timed_sample, context=context) + yield _augment_sample(timed_sample, context=context) else: with LimitingPool(process_ahead=process_ahead, - initializer=_init_preparation_worker, + initializer=_init_augmentation_worker, initargs=(context,)) as pool: - yield from pool.imap(_prepare_sample, timed_samples()) + yield from pool.imap(_augment_sample, timed_samples()) finally: for augmentation in augmentations: call_if_exists(augmentation, 'stop')