Perform data loading I/O within worker process rather than main process by wrapping Sample
This commit is contained in:
parent
fc0b495643
commit
be39d3354d
@ -150,6 +150,12 @@ def _init_augmentation_worker(preparation_context):
|
|||||||
AUGMENTATION_CONTEXT = preparation_context
|
AUGMENTATION_CONTEXT = preparation_context
|
||||||
|
|
||||||
|
|
||||||
|
def _load_and_augment_sample(timed_sample, context=None):
|
||||||
|
sample, clock = timed_sample
|
||||||
|
realized_sample = sample.unpack()
|
||||||
|
return _augment_sample((realized_sample, clock), context)
|
||||||
|
|
||||||
|
|
||||||
def _augment_sample(timed_sample, context=None):
|
def _augment_sample(timed_sample, context=None):
|
||||||
context = AUGMENTATION_CONTEXT if context is None else context
|
context = AUGMENTATION_CONTEXT if context is None else context
|
||||||
sample, clock = timed_sample
|
sample, clock = timed_sample
|
||||||
@ -213,12 +219,12 @@ def apply_sample_augmentations(samples,
|
|||||||
context = AugmentationContext(audio_type, augmentations)
|
context = AugmentationContext(audio_type, augmentations)
|
||||||
if process_ahead == 0:
|
if process_ahead == 0:
|
||||||
for timed_sample in timed_samples():
|
for timed_sample in timed_samples():
|
||||||
yield _augment_sample(timed_sample, context=context)
|
yield _load_and_augment_sample(timed_sample, context=context)
|
||||||
else:
|
else:
|
||||||
with LimitingPool(process_ahead=process_ahead,
|
with LimitingPool(process_ahead=process_ahead,
|
||||||
initializer=_init_augmentation_worker,
|
initializer=_init_augmentation_worker,
|
||||||
initargs=(context,)) as pool:
|
initargs=(context,)) as pool:
|
||||||
yield from pool.imap(_augment_sample, timed_samples())
|
yield from pool.imap(_load_and_augment_sample, timed_samples())
|
||||||
finally:
|
finally:
|
||||||
for augmentation in augmentations:
|
for augmentation in augmentations:
|
||||||
augmentation.stop()
|
augmentation.stop()
|
||||||
|
@ -60,6 +60,27 @@ class LabeledSample(Sample):
|
|||||||
self.transcript = transcript
|
self.transcript = transcript
|
||||||
|
|
||||||
|
|
||||||
|
class PackedSample:
|
||||||
|
"""
|
||||||
|
A wrapper that we can carry around in an iterator and pass to a child process in order to
|
||||||
|
have the child process do the loading/unpacking of the sample, allowing for parallel file
|
||||||
|
I/O.
|
||||||
|
"""
|
||||||
|
def __init__(self, filename, audio_type, label):
|
||||||
|
self.filename = filename
|
||||||
|
self.audio_type = audio_type
|
||||||
|
self.label = label
|
||||||
|
|
||||||
|
def unpack(self):
|
||||||
|
print("Unpacking sample: %s" % self.filename)
|
||||||
|
with open_remote(self.filename, 'rb') as audio_file:
|
||||||
|
data = audio_file.read()
|
||||||
|
if self.label is None:
|
||||||
|
s = Sample(self.audio_type, data, sample_id=self.filename)
|
||||||
|
s = LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
|
||||||
|
print("unpacked!")
|
||||||
|
return s
|
||||||
|
|
||||||
def load_sample(filename, label=None):
|
def load_sample(filename, label=None):
|
||||||
"""
|
"""
|
||||||
Loads audio-file as a (labeled or unlabeled) sample
|
Loads audio-file as a (labeled or unlabeled) sample
|
||||||
@ -70,21 +91,20 @@ def load_sample(filename, label=None):
|
|||||||
Filename of the audio-file to load as sample
|
Filename of the audio-file to load as sample
|
||||||
label : str
|
label : str
|
||||||
Label (transcript) of the sample.
|
Label (transcript) of the sample.
|
||||||
If None: return util.audio.Sample instance
|
If None: returned result.unpack() will return util.audio.Sample instance
|
||||||
Otherwise: return util.sample_collections.LabeledSample instance
|
Otherwise: returned result.unpack() util.sample_collections.LabeledSample instance
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return
|
||||||
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
|
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
|
||||||
"""
|
"""
|
||||||
|
print("loading sample!")
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
audio_type = get_audio_type_from_extension(ext)
|
audio_type = get_audio_type_from_extension(ext)
|
||||||
if audio_type is None:
|
if audio_type is None:
|
||||||
raise ValueError('Unknown audio type extension "{}"'.format(ext))
|
raise ValueError('Unknown audio type extension "{}"'.format(ext))
|
||||||
with open_remote(filename, 'rb') as audio_file:
|
return PackedSample(filename, audio_type, label)
|
||||||
if label is None:
|
|
||||||
return Sample(audio_type, audio_file.read(), sample_id=filename)
|
|
||||||
return LabeledSample(audio_type, audio_file.read(), label, sample_id=filename)
|
|
||||||
|
|
||||||
|
|
||||||
class DirectSDBWriter:
|
class DirectSDBWriter:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user