Renamed prepare_samples to augment_samples
This commit is contained in:
parent
64e14886b8
commit
a5303ccca6
@ -10,7 +10,7 @@ import random
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
|
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
|
||||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, prepare_samples
|
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, augment_samples
|
||||||
|
|
||||||
|
|
||||||
def get_samples_in_play_order():
|
def get_samples_in_play_order():
|
||||||
@ -39,7 +39,7 @@ def get_samples_in_play_order():
|
|||||||
|
|
||||||
def play_collection():
|
def play_collection():
|
||||||
samples = get_samples_in_play_order()
|
samples = get_samples_in_play_order()
|
||||||
samples = prepare_samples(samples,
|
samples = augment_samples(samples,
|
||||||
audio_type=AUDIO_TYPE_PCM,
|
audio_type=AUDIO_TYPE_PCM,
|
||||||
augmentation_specs=CLI_ARGS.augment,
|
augmentation_specs=CLI_ARGS.augment,
|
||||||
process_ahead=0,
|
process_ahead=0,
|
||||||
|
@ -13,7 +13,7 @@ from .text import text_to_char_array
|
|||||||
from .flags import FLAGS
|
from .flags import FLAGS
|
||||||
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||||
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
|
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
|
||||||
from .sample_collections import samples_from_sources, prepare_samples
|
from .sample_collections import samples_from_sources, augment_samples
|
||||||
from .helpers import remember_exception, MEGABYTE
|
from .helpers import remember_exception, MEGABYTE
|
||||||
|
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ def create_dataset(sources,
|
|||||||
buffering=1 * MEGABYTE):
|
buffering=1 * MEGABYTE):
|
||||||
def generate_values():
|
def generate_values():
|
||||||
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
|
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
|
||||||
samples = prepare_samples(samples,
|
samples = augment_samples(samples,
|
||||||
repetitions=repetitions,
|
repetitions=repetitions,
|
||||||
augmentation_specs=augmentation_specs,
|
augmentation_specs=augmentation_specs,
|
||||||
buffering=buffering,
|
buffering=buffering,
|
||||||
|
@ -429,16 +429,16 @@ class PreparationContext:
|
|||||||
self.augmentations = augmentations
|
self.augmentations = augmentations
|
||||||
|
|
||||||
|
|
||||||
PREPARATION_CONTEXT = None
|
AUGMENTATION_CONTEXT = None
|
||||||
|
|
||||||
|
|
||||||
def _init_preparation_worker(preparation_context):
|
def _init_augmentation_worker(preparation_context):
|
||||||
global PREPARATION_CONTEXT # pylint: disable=global-statement
|
global AUGMENTATION_CONTEXT # pylint: disable=global-statement
|
||||||
PREPARATION_CONTEXT = preparation_context
|
AUGMENTATION_CONTEXT = preparation_context
|
||||||
|
|
||||||
|
|
||||||
def _prepare_sample(timed_sample, context=None):
|
def _augment_sample(timed_sample, context=None):
|
||||||
context = PREPARATION_CONTEXT if context is None else context
|
context = AUGMENTATION_CONTEXT if context is None else context
|
||||||
sample, clock = timed_sample
|
sample, clock = timed_sample
|
||||||
for augmentation in context.augmentations:
|
for augmentation in context.augmentations:
|
||||||
if random.random() < augmentation.probability:
|
if random.random() < augmentation.probability:
|
||||||
@ -447,7 +447,7 @@ def _prepare_sample(timed_sample, context=None):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def prepare_samples(samples,
|
def augment_samples(samples,
|
||||||
audio_type=AUDIO_TYPE_NP,
|
audio_type=AUDIO_TYPE_NP,
|
||||||
augmentation_specs=None,
|
augmentation_specs=None,
|
||||||
buffering=BUFFER_SIZE,
|
buffering=BUFFER_SIZE,
|
||||||
@ -497,12 +497,12 @@ def prepare_samples(samples,
|
|||||||
context = PreparationContext(audio_type, augmentations)
|
context = PreparationContext(audio_type, augmentations)
|
||||||
if process_ahead == 0:
|
if process_ahead == 0:
|
||||||
for timed_sample in timed_samples():
|
for timed_sample in timed_samples():
|
||||||
yield _prepare_sample(timed_sample, context=context)
|
yield _augment_sample(timed_sample, context=context)
|
||||||
else:
|
else:
|
||||||
with LimitingPool(process_ahead=process_ahead,
|
with LimitingPool(process_ahead=process_ahead,
|
||||||
initializer=_init_preparation_worker,
|
initializer=_init_augmentation_worker,
|
||||||
initargs=(context,)) as pool:
|
initargs=(context,)) as pool:
|
||||||
yield from pool.imap(_prepare_sample, timed_samples())
|
yield from pool.imap(_augment_sample, timed_samples())
|
||||||
finally:
|
finally:
|
||||||
for augmentation in augmentations:
|
for augmentation in augmentations:
|
||||||
call_if_exists(augmentation, 'stop')
|
call_if_exists(augmentation, 'stop')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user