Renamed prepare_samples to augment_samples
This commit is contained in:
parent
64e14886b8
commit
a5303ccca6
@ -10,7 +10,7 @@ import random
|
||||
import argparse
|
||||
|
||||
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
|
||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, prepare_samples
|
||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, augment_samples
|
||||
|
||||
|
||||
def get_samples_in_play_order():
|
||||
@ -39,7 +39,7 @@ def get_samples_in_play_order():
|
||||
|
||||
def play_collection():
|
||||
samples = get_samples_in_play_order()
|
||||
samples = prepare_samples(samples,
|
||||
samples = augment_samples(samples,
|
||||
audio_type=AUDIO_TYPE_PCM,
|
||||
augmentation_specs=CLI_ARGS.augment,
|
||||
process_ahead=0,
|
||||
|
@ -13,7 +13,7 @@ from .text import text_to_char_array
|
||||
from .flags import FLAGS
|
||||
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
|
||||
from .sample_collections import samples_from_sources, prepare_samples
|
||||
from .sample_collections import samples_from_sources, augment_samples
|
||||
from .helpers import remember_exception, MEGABYTE
|
||||
|
||||
|
||||
@ -119,7 +119,7 @@ def create_dataset(sources,
|
||||
buffering=1 * MEGABYTE):
|
||||
def generate_values():
|
||||
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
|
||||
samples = prepare_samples(samples,
|
||||
samples = augment_samples(samples,
|
||||
repetitions=repetitions,
|
||||
augmentation_specs=augmentation_specs,
|
||||
buffering=buffering,
|
||||
|
@ -429,16 +429,16 @@ class PreparationContext:
|
||||
self.augmentations = augmentations
|
||||
|
||||
|
||||
PREPARATION_CONTEXT = None
|
||||
AUGMENTATION_CONTEXT = None
|
||||
|
||||
|
||||
def _init_preparation_worker(preparation_context):
|
||||
global PREPARATION_CONTEXT # pylint: disable=global-statement
|
||||
PREPARATION_CONTEXT = preparation_context
|
||||
def _init_augmentation_worker(preparation_context):
|
||||
global AUGMENTATION_CONTEXT # pylint: disable=global-statement
|
||||
AUGMENTATION_CONTEXT = preparation_context
|
||||
|
||||
|
||||
def _prepare_sample(timed_sample, context=None):
|
||||
context = PREPARATION_CONTEXT if context is None else context
|
||||
def _augment_sample(timed_sample, context=None):
|
||||
context = AUGMENTATION_CONTEXT if context is None else context
|
||||
sample, clock = timed_sample
|
||||
for augmentation in context.augmentations:
|
||||
if random.random() < augmentation.probability:
|
||||
@ -447,7 +447,7 @@ def _prepare_sample(timed_sample, context=None):
|
||||
return sample
|
||||
|
||||
|
||||
def prepare_samples(samples,
|
||||
def augment_samples(samples,
|
||||
audio_type=AUDIO_TYPE_NP,
|
||||
augmentation_specs=None,
|
||||
buffering=BUFFER_SIZE,
|
||||
@ -497,12 +497,12 @@ def prepare_samples(samples,
|
||||
context = PreparationContext(audio_type, augmentations)
|
||||
if process_ahead == 0:
|
||||
for timed_sample in timed_samples():
|
||||
yield _prepare_sample(timed_sample, context=context)
|
||||
yield _augment_sample(timed_sample, context=context)
|
||||
else:
|
||||
with LimitingPool(process_ahead=process_ahead,
|
||||
initializer=_init_preparation_worker,
|
||||
initializer=_init_augmentation_worker,
|
||||
initargs=(context,)) as pool:
|
||||
yield from pool.imap(_prepare_sample, timed_samples())
|
||||
yield from pool.imap(_augment_sample, timed_samples())
|
||||
finally:
|
||||
for augmentation in augmentations:
|
||||
call_if_exists(augmentation, 'stop')
|
||||
|
Loading…
x
Reference in New Issue
Block a user