Renamed prepare_samples to augment_samples

This commit is contained in:
Tilman Kamp 2020-05-14 16:50:18 +02:00
parent 64e14886b8
commit a5303ccca6
3 changed files with 14 additions and 14 deletions

View File

@ -10,7 +10,7 @@ import random
import argparse
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, prepare_samples
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, augment_samples
def get_samples_in_play_order():
@ -39,7 +39,7 @@ def get_samples_in_play_order():
def play_collection():
samples = get_samples_in_play_order()
samples = prepare_samples(samples,
samples = augment_samples(samples,
audio_type=AUDIO_TYPE_PCM,
augmentation_specs=CLI_ARGS.augment,
process_ahead=0,

View File

@ -13,7 +13,7 @@ from .text import text_to_char_array
from .flags import FLAGS
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
from .sample_collections import samples_from_sources, prepare_samples
from .sample_collections import samples_from_sources, augment_samples
from .helpers import remember_exception, MEGABYTE
@ -119,7 +119,7 @@ def create_dataset(sources,
buffering=1 * MEGABYTE):
def generate_values():
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
samples = prepare_samples(samples,
samples = augment_samples(samples,
repetitions=repetitions,
augmentation_specs=augmentation_specs,
buffering=buffering,

View File

@ -429,16 +429,16 @@ class PreparationContext:
self.augmentations = augmentations
PREPARATION_CONTEXT = None
AUGMENTATION_CONTEXT = None
def _init_preparation_worker(preparation_context):
global PREPARATION_CONTEXT # pylint: disable=global-statement
PREPARATION_CONTEXT = preparation_context
def _init_augmentation_worker(preparation_context):
global AUGMENTATION_CONTEXT # pylint: disable=global-statement
AUGMENTATION_CONTEXT = preparation_context
def _prepare_sample(timed_sample, context=None):
context = PREPARATION_CONTEXT if context is None else context
def _augment_sample(timed_sample, context=None):
context = AUGMENTATION_CONTEXT if context is None else context
sample, clock = timed_sample
for augmentation in context.augmentations:
if random.random() < augmentation.probability:
@ -447,7 +447,7 @@ def _prepare_sample(timed_sample, context=None):
return sample
def prepare_samples(samples,
def augment_samples(samples,
audio_type=AUDIO_TYPE_NP,
augmentation_specs=None,
buffering=BUFFER_SIZE,
@ -497,12 +497,12 @@ def prepare_samples(samples,
context = PreparationContext(audio_type, augmentations)
if process_ahead == 0:
for timed_sample in timed_samples():
yield _prepare_sample(timed_sample, context=context)
yield _augment_sample(timed_sample, context=context)
else:
with LimitingPool(process_ahead=process_ahead,
initializer=_init_preparation_worker,
initializer=_init_augmentation_worker,
initargs=(context,)) as pool:
yield from pool.imap(_prepare_sample, timed_samples())
yield from pool.imap(_augment_sample, timed_samples())
finally:
for augmentation in augmentations:
call_if_exists(augmentation, 'stop')