Perform data loading I/O within worker process rather than main process by wrapping Sample
This commit is contained in:
parent
fc0b495643
commit
be39d3354d
@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context):
|
||||
AUGMENTATION_CONTEXT = preparation_context
|
||||
|
||||
|
||||
def _load_and_augment_sample(timed_sample, context=None):
|
||||
sample, clock = timed_sample
|
||||
realized_sample = sample.unpack()
|
||||
return _augment_sample((realized_sample, clock), context)
|
||||
|
||||
|
||||
def _augment_sample(timed_sample, context=None):
|
||||
context = AUGMENTATION_CONTEXT if context is None else context
|
||||
sample, clock = timed_sample
|
||||
@ -213,12 +219,12 @@ def apply_sample_augmentations(samples,
|
||||
context = AugmentationContext(audio_type, augmentations)
|
||||
if process_ahead == 0:
|
||||
for timed_sample in timed_samples():
|
||||
yield _augment_sample(timed_sample, context=context)
|
||||
yield _load_and_augment_sample(timed_sample, context=context)
|
||||
else:
|
||||
with LimitingPool(process_ahead=process_ahead,
|
||||
initializer=_init_augmentation_worker,
|
||||
initargs=(context,)) as pool:
|
||||
yield from pool.imap(_augment_sample, timed_samples())
|
||||
yield from pool.imap(_load_and_augment_sample, timed_samples())
|
||||
finally:
|
||||
for augmentation in augmentations:
|
||||
augmentation.stop()
|
||||
|
@ -60,6 +60,27 @@ class LabeledSample(Sample):
|
||||
self.transcript = transcript
|
||||
|
||||
|
||||
class PackedSample:
|
||||
"""
|
||||
A wrapper that we can carry around in an iterator and pass to a child process in order to
|
||||
have the child process do the loading/unpacking of the sample, allowing for parallel file
|
||||
I/O.
|
||||
"""
|
||||
def __init__(self, filename, audio_type, label):
|
||||
self.filename = filename
|
||||
self.audio_type = audio_type
|
||||
self.label = label
|
||||
|
||||
def unpack(self):
|
||||
print("Unpacking sample: %s" % self.filename)
|
||||
with open_remote(self.filename, 'rb') as audio_file:
|
||||
data = audio_file.read()
|
||||
if self.label is None:
|
||||
s = Sample(self.audio_type, data, sample_id=self.filename)
|
||||
s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
|
||||
print("unpacked!")
|
||||
return s
|
||||
|
||||
def load_sample(filename, label=None):
|
||||
"""
|
||||
Loads audio-file as a (labeled or unlabeled) sample
|
||||
@ -70,21 +91,20 @@ def load_sample(filename, label=None):
|
||||
Filename of the audio-file to load as sample
|
||||
label : str
|
||||
Label (transcript) of the sample.
|
||||
If None: return util.audio.Sample instance
|
||||
Otherwise: return util.sample_collections.LabeledSample instance
|
||||
If None: returned result.unpack() will return util.audio.Sample instance
|
||||
Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
|
||||
util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return
|
||||
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
|
||||
"""
|
||||
print("loading sample!")
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
audio_type = get_audio_type_from_extension(ext)
|
||||
if audio_type is None:
|
||||
raise ValueError('Unknown audio type extension "{}"'.format(ext))
|
||||
with open_remote(filename, 'rb') as audio_file:
|
||||
if label is None:
|
||||
return Sample(audio_type, audio_file.read(), sample_id=filename)
|
||||
return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename)
|
||||
return PackedSample(filename, audio_type, label)
|
||||
|
||||
|
||||
class DirectSDBWriter:
|
||||
|
Loading…
x
Reference in New Issue
Block a user