diff --git a/training/deepspeech_training/util/augmentations.py b/training/deepspeech_training/util/augmentations.py index 941c17f2..0934fbd5 100644 --- a/training/deepspeech_training/util/augmentations.py +++ b/training/deepspeech_training/util/augmentations.py @@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context): AUGMENTATION_CONTEXT = preparation_context +def _load_and_augment_sample(timed_sample, context=None): + sample, clock = timed_sample + realized_sample = sample.unpack() + return _augment_sample((realized_sample, clock), context) + + def _augment_sample(timed_sample, context=None): context = AUGMENTATION_CONTEXT if context is None else context sample, clock = timed_sample @@ -213,12 +219,12 @@ def apply_sample_augmentations(samples, context = AugmentationContext(audio_type, augmentations) if process_ahead == 0: for timed_sample in timed_samples(): - yield _augment_sample(timed_sample, context=context) + yield _load_and_augment_sample(timed_sample, context=context) else: with LimitingPool(process_ahead=process_ahead, initializer=_init_augmentation_worker, initargs=(context,)) as pool: - yield from pool.imap(_augment_sample, timed_samples()) + yield from pool.imap(_load_and_augment_sample, timed_samples()) finally: for augmentation in augmentations: augmentation.stop() diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index d9856484..23b0422b 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -60,6 +60,27 @@ class LabeledSample(Sample): self.transcript = transcript +class PackedSample: + """ + A wrapper that we can carry around in an iterator and pass to a child process in order to + have the child process do the loading/unpacking of the sample, allowing for parallel file + I/O. + """ + def __init__(self, filename, audio_type, label): + self.filename = filename + self.audio_type = audio_type + self.label = label + + def unpack(self): + print("Unpacking sample: %s" % self.filename) + with open_remote(self.filename, 'rb') as audio_file: + data = audio_file.read() + if self.label is None: + s = Sample(self.audio_type, data, sample_id=self.filename) + s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename) + print("unpacked!") + return s + def load_sample(filename, label=None): """ Loads audio-file as a (labeled or unlabeled) sample @@ -70,21 +91,20 @@ def load_sample(filename, label=None): Filename of the audio-file to load as sample label : str Label (transcript) of the sample. - If None: return util.audio.Sample instance - Otherwise: return util.sample_collections.LabeledSample instance + If None: returned result.unpack() will return util.audio.Sample instance + Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance Returns ------- - util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance + util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return + util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance """ + print("loading sample!") ext = os.path.splitext(filename)[1].lower() audio_type = get_audio_type_from_extension(ext) if audio_type is None: raise ValueError('Unknown audio type extension "{}"'.format(ext)) - with open_remote(filename, 'rb') as audio_file: - if label is None: - return Sample(audio_type, audio_file.read(), sample_id=filename) - return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename) + return PackedSample(filename, audio_type, label) class DirectSDBWriter: