Perform data loading I/O within worker process rather than main process by wrapping Sample

This commit is contained in:
CatalinVoss 2020-11-12 21:46:39 -08:00
parent fc0b495643
commit be39d3354d
2 changed files with 35 additions and 9 deletions

View File

@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context):
AUGMENTATION_CONTEXT = preparation_context
def _load_and_augment_sample(timed_sample, context=None):
    """
    Realize a (possibly packed) sample and run it through the augmentation step.

    Intended as the per-item callable for the augmentation pool, so that the
    ``unpack()`` call (and whatever loading it performs) happens in the process
    that executes this function rather than in the process that produced the
    sample iterator.

    Parameters
    ----------
    timed_sample : tuple
        A ``(sample, clock)`` pair. ``sample`` is either a wrapper exposing an
        ``unpack()`` method that returns the realized sample, or an object
        without such a method, which is then used as-is.
    context : optional
        Passed through unchanged to ``_augment_sample``.

    Returns
    -------
    Whatever ``_augment_sample`` returns for the realized ``(sample, clock)`` pair.
    """
    sample, clock = timed_sample
    # Not every sample source hands out packed wrappers; only call unpack()
    # when the object actually provides it, so already-loaded samples do not
    # fail with AttributeError here.
    realized_sample = sample.unpack() if hasattr(sample, 'unpack') else sample
    return _augment_sample((realized_sample, clock), context)
def _augment_sample(timed_sample, context=None):
context = AUGMENTATION_CONTEXT if context is None else context
sample, clock = timed_sample
@ -213,12 +219,12 @@ def apply_sample_augmentations(samples,
context = AugmentationContext(audio_type, augmentations)
if process_ahead == 0:
for timed_sample in timed_samples():
yield _augment_sample(timed_sample, context=context)
yield _load_and_augment_sample(timed_sample, context=context)
else:
with LimitingPool(process_ahead=process_ahead,
initializer=_init_augmentation_worker,
initargs=(context,)) as pool:
yield from pool.imap(_augment_sample, timed_samples())
yield from pool.imap(_load_and_augment_sample, timed_samples())
finally:
for augmentation in augmentations:
augmentation.stop()

View File

@ -60,6 +60,27 @@ class LabeledSample(Sample):
self.transcript = transcript
class PackedSample:
    """
    A wrapper that we can carry around in an iterator and pass to a child process in order to
    have the child process do the loading/unpacking of the sample, allowing for parallel file
    I/O.

    Only the filename, audio type and label are stored; the audio data itself is
    not read until ``unpack()`` is called.
    """
    def __init__(self, filename, audio_type, label):
        """
        Parameters
        ----------
        filename : str
            Name of the audio file that ``unpack()`` will read.
        audio_type
            Audio type handed to the sample constructor on unpacking.
        label : str or None
            Label (transcript) of the sample, or None for an unlabeled sample.
        """
        self.filename = filename
        self.audio_type = audio_type
        self.label = label

    def unpack(self):
        """
        Read the audio file and build the actual sample object.

        Returns
        -------
        A ``Sample`` if ``label`` is None, otherwise a ``LabeledSample``. In both
        cases the filename is used as ``sample_id``.
        """
        # NOTE(review): the two print() calls look like leftover debug tracing;
        # kept so the emitted output is unchanged.
        print("Unpacking sample: %s" % self.filename)
        with open_remote(self.filename, 'rb') as audio_file:
            data = audio_file.read()
        if self.label is None:
            s = Sample(self.audio_type, data, sample_id=self.filename)
        else:
            # Previously this assignment ran unconditionally and replaced the
            # unlabeled Sample built above with a LabeledSample.
            s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
        print("unpacked!")
        return s
def load_sample(filename, label=None):
"""
Loads audio-file as a (labeled or unlabeled) sample
@ -70,21 +91,20 @@ def load_sample(filename, label=None):
Filename of the audio-file to load as sample
label : str
Label (transcript) of the sample.
If None: return util.audio.Sample instance
Otherwise: return util.sample_collections.LabeledSample instance
If None: returned result.unpack() will return util.audio.Sample instance
Otherwise: returned result.unpack() will return util.sample_collections.LabeledSample instance
Returns
-------
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
"""
print("loading sample!")
ext = os.path.splitext(filename)[1].lower()
audio_type = get_audio_type_from_extension(ext)
if audio_type is None:
raise ValueError('Unknown audio type extension "{}"'.format(ext))
with open_remote(filename, 'rb') as audio_file:
if label is None:
return Sample(audio_type, audio_file.read(), sample_id=filename)
return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename)
return PackedSample(filename, audio_type, label)
class DirectSDBWriter: