Renamed prepare_samples to augment_samples
This commit is contained in:
		
							parent
							
								
									64e14886b8
								
							
						
					
					
						commit
						a5303ccca6
					
				@ -10,7 +10,7 @@ import random
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
 | 
			
		||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, prepare_samples
 | 
			
		||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, augment_samples
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_samples_in_play_order():
 | 
			
		||||
@ -39,7 +39,7 @@ def get_samples_in_play_order():
 | 
			
		||||
 | 
			
		||||
def play_collection():
 | 
			
		||||
    samples = get_samples_in_play_order()
 | 
			
		||||
    samples = prepare_samples(samples,
 | 
			
		||||
    samples = augment_samples(samples,
 | 
			
		||||
                              audio_type=AUDIO_TYPE_PCM,
 | 
			
		||||
                              augmentation_specs=CLI_ARGS.augment,
 | 
			
		||||
                              process_ahead=0,
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,7 @@ from .text import text_to_char_array
 | 
			
		||||
from .flags import FLAGS
 | 
			
		||||
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
 | 
			
		||||
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
 | 
			
		||||
from .sample_collections import samples_from_sources, prepare_samples
 | 
			
		||||
from .sample_collections import samples_from_sources, augment_samples
 | 
			
		||||
from .helpers import remember_exception, MEGABYTE
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -119,7 +119,7 @@ def create_dataset(sources,
 | 
			
		||||
                   buffering=1 * MEGABYTE):
 | 
			
		||||
    def generate_values():
 | 
			
		||||
        samples = samples_from_sources(sources, buffering=buffering, labeled=True)
 | 
			
		||||
        samples = prepare_samples(samples,
 | 
			
		||||
        samples = augment_samples(samples,
 | 
			
		||||
                                  repetitions=repetitions,
 | 
			
		||||
                                  augmentation_specs=augmentation_specs,
 | 
			
		||||
                                  buffering=buffering,
 | 
			
		||||
 | 
			
		||||
@ -429,16 +429,16 @@ class PreparationContext:
 | 
			
		||||
        self.augmentations = augmentations
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
PREPARATION_CONTEXT = None
 | 
			
		||||
AUGMENTATION_CONTEXT = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _init_preparation_worker(preparation_context):
 | 
			
		||||
    global PREPARATION_CONTEXT  # pylint: disable=global-statement
 | 
			
		||||
    PREPARATION_CONTEXT = preparation_context
 | 
			
		||||
def _init_augmentation_worker(preparation_context):
 | 
			
		||||
    global AUGMENTATION_CONTEXT  # pylint: disable=global-statement
 | 
			
		||||
    AUGMENTATION_CONTEXT = preparation_context
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _prepare_sample(timed_sample, context=None):
 | 
			
		||||
    context = PREPARATION_CONTEXT if context is None else context
 | 
			
		||||
def _augment_sample(timed_sample, context=None):
 | 
			
		||||
    context = AUGMENTATION_CONTEXT if context is None else context
 | 
			
		||||
    sample, clock = timed_sample
 | 
			
		||||
    for augmentation in context.augmentations:
 | 
			
		||||
        if random.random() < augmentation.probability:
 | 
			
		||||
@ -447,7 +447,7 @@ def _prepare_sample(timed_sample, context=None):
 | 
			
		||||
    return sample
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_samples(samples,
 | 
			
		||||
def augment_samples(samples,
 | 
			
		||||
                    audio_type=AUDIO_TYPE_NP,
 | 
			
		||||
                    augmentation_specs=None,
 | 
			
		||||
                    buffering=BUFFER_SIZE,
 | 
			
		||||
@ -497,12 +497,12 @@ def prepare_samples(samples,
 | 
			
		||||
        context = PreparationContext(audio_type, augmentations)
 | 
			
		||||
        if process_ahead == 0:
 | 
			
		||||
            for timed_sample in timed_samples():
 | 
			
		||||
                yield _prepare_sample(timed_sample, context=context)
 | 
			
		||||
                yield _augment_sample(timed_sample, context=context)
 | 
			
		||||
        else:
 | 
			
		||||
            with LimitingPool(process_ahead=process_ahead,
 | 
			
		||||
                              initializer=_init_preparation_worker,
 | 
			
		||||
                              initializer=_init_augmentation_worker,
 | 
			
		||||
                              initargs=(context,)) as pool:
 | 
			
		||||
                yield from pool.imap(_prepare_sample, timed_samples())
 | 
			
		||||
                yield from pool.imap(_augment_sample, timed_samples())
 | 
			
		||||
    finally:
 | 
			
		||||
        for augmentation in augmentations:
 | 
			
		||||
            call_if_exists(augmentation, 'stop')
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user