Fix #2830 - Support for unlabeled samples
This commit is contained in:
parent
5740d64e6e
commit
41da7b2870
|
@ -26,8 +26,8 @@ AUDIO_TYPE_LOOKUP = {
|
||||||
|
|
||||||
def build_sdb():
|
def build_sdb():
|
||||||
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
|
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
|
||||||
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer:
|
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
|
||||||
samples = samples_from_files(CLI_ARGS.sources)
|
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
|
||||||
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
|
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
|
||||||
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
|
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
|
||||||
sdb_writer.add(sample)
|
sdb_writer.add(sample)
|
||||||
|
@ -36,13 +36,18 @@ def build_sdb():
|
||||||
def handle_args():
|
def handle_args():
|
||||||
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
|
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
|
||||||
'from DeepSpeech CSV files and other SDB files')
|
'from DeepSpeech CSV files and other SDB files')
|
||||||
parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - '
|
parser.add_argument('sources', nargs='+',
|
||||||
'Note: For getting a correctly ordered target SDB, source SDBs have '
|
help='Source CSV and/or SDB files - '
|
||||||
'to have their samples already ordered from shortest to longest.')
|
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
|
||||||
|
'already ordered from shortest to longest.')
|
||||||
parser.add_argument('target', help='SDB file to create')
|
parser.add_argument('target', help='SDB file to create')
|
||||||
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
|
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
|
||||||
help='Audio representation inside target SDB')
|
help='Audio representation inside target SDB')
|
||||||
parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers')
|
parser.add_argument('--workers', type=int, default=None,
|
||||||
|
help='Number of encoding SDB workers')
|
||||||
|
parser.add_argument('--unlabeled', action='store_true',
|
||||||
|
help='If to build an SDB with unlabeled (audio only) samples - '
|
||||||
|
'typically used for building noise augmentation corpora')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||||
import random
|
import random
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from util.sample_collections import samples_from_file
|
from util.sample_collections import samples_from_file, LabeledSample
|
||||||
from util.audio import AUDIO_TYPE_PCM
|
from util.audio import AUDIO_TYPE_PCM
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ def play_sample(samples, index):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
sample = samples[index]
|
sample = samples[index]
|
||||||
print('Sample "{}"'.format(sample.sample_id))
|
print('Sample "{}"'.format(sample.sample_id))
|
||||||
|
if isinstance(sample, LabeledSample):
|
||||||
print(' "{}"'.format(sample.transcript))
|
print(' "{}"'.format(sample.transcript))
|
||||||
sample.change_audio_type(AUDIO_TYPE_PCM)
|
sample.change_audio_type(AUDIO_TYPE_PCM)
|
||||||
rate, channels, width = sample.audio_format
|
rate, channels, width = sample.audio_format
|
||||||
|
|
|
@ -26,31 +26,45 @@ OPUS_CHUNK_LEN_SIZE = 2
|
||||||
|
|
||||||
|
|
||||||
class Sample:
|
class Sample:
|
||||||
"""Represents in-memory audio data of a certain (convertible) representation.
|
|
||||||
Attributes:
|
|
||||||
audio_type (str): See `__init__`.
|
|
||||||
audio_format (tuple:(int, int, int)): See `__init__`.
|
|
||||||
audio (obj): Audio data represented as indicated by `audio_type`
|
|
||||||
duration (float): Audio duration of the sample in seconds
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, audio_type, raw_data, audio_format=None):
|
Represents in-memory audio data of a certain (convertible) representation.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
audio_type : str
|
||||||
|
See `__init__`.
|
||||||
|
audio_format : tuple:(int, int, int)
|
||||||
|
See `__init__`.
|
||||||
|
audio : binary
|
||||||
|
Audio data represented as indicated by `audio_type`
|
||||||
|
duration : float
|
||||||
|
Audio duration of the sample in seconds
|
||||||
"""
|
"""
|
||||||
Creates a Sample from a raw audio representation.
|
def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None):
|
||||||
:param audio_type: Audio data representation type
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_type : str
|
||||||
|
Audio data representation type
|
||||||
Supported types:
|
Supported types:
|
||||||
- AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
|
- util.audio.AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
|
||||||
wrapped by a custom container format (used in SDBs)
|
wrapped by a custom container format (used in SDBs)
|
||||||
- AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
|
- util.audio.AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
|
||||||
- AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
|
- util.audio.AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
|
||||||
- AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
|
- util.audio.AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
|
||||||
:param raw_data: Audio data in the form of the provided representation type (see audio_type).
|
raw_data : binary
|
||||||
For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray.
|
Audio data in the form of the provided representation type (see audio_type).
|
||||||
:param audio_format: Tuple of sample-rate, number of channels and sample-width.
|
For types util.audio.AUDIO_TYPE_OPUS or util.audio.AUDIO_TYPE_WAV data can also be passed as a bytearray.
|
||||||
Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP,
|
audio_format : tuple
|
||||||
|
Tuple of sample-rate, number of channels and sample-width.
|
||||||
|
Required in case of audio_type = util.audio.AUDIO_TYPE_PCM or util.audio.AUDIO_TYPE_NP,
|
||||||
as this information cannot be derived from raw audio data.
|
as this information cannot be derived from raw audio data.
|
||||||
|
sample_id : str
|
||||||
|
Tracking ID - should indicate sample's origin as precisely as possible
|
||||||
"""
|
"""
|
||||||
self.audio_type = audio_type
|
self.audio_type = audio_type
|
||||||
self.audio_format = audio_format
|
self.audio_format = audio_format
|
||||||
|
self.sample_id = sample_id
|
||||||
if audio_type in SERIALIZABLE_AUDIO_TYPES:
|
if audio_type in SERIALIZABLE_AUDIO_TYPES:
|
||||||
self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
|
self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
|
||||||
self.duration = read_duration(audio_type, self.audio)
|
self.duration = read_duration(audio_type, self.audio)
|
||||||
|
@ -68,7 +82,11 @@ class Sample:
|
||||||
def change_audio_type(self, new_audio_type):
|
def change_audio_type(self, new_audio_type):
|
||||||
"""
|
"""
|
||||||
In-place conversion of audio data into a different representation.
|
In-place conversion of audio data into a different representation.
|
||||||
:param new_audio_type: New audio-type - see `__init__`.
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
new_audio_type : str
|
||||||
|
New audio-type - see `__init__`.
|
||||||
Not supported: Converting from AUDIO_TYPE_NP into any other type.
|
Not supported: Converting from AUDIO_TYPE_NP into any other type.
|
||||||
"""
|
"""
|
||||||
if self.audio_type == new_audio_type:
|
if self.audio_type == new_audio_type:
|
||||||
|
|
|
@ -116,7 +116,7 @@ def create_dataset(sources,
|
||||||
process_ahead=None,
|
process_ahead=None,
|
||||||
buffering=1 * MEGABYTE):
|
buffering=1 * MEGABYTE):
|
||||||
def generate_values():
|
def generate_values():
|
||||||
samples = samples_from_files(sources, buffering=buffering)
|
samples = samples_from_files(sources, buffering=buffering, labeled=True)
|
||||||
for sample in change_audio_types(samples,
|
for sample in change_audio_types(samples,
|
||||||
AUDIO_TYPE_NP,
|
AUDIO_TYPE_NP,
|
||||||
process_ahead=2 * batch_size if process_ahead is None else process_ahead):
|
process_ahead=2 * batch_size if process_ahead is None else process_ahead):
|
||||||
|
|
|
@ -29,23 +29,45 @@ class LabeledSample(Sample):
|
||||||
Derived from util.audio.Sample and used by sample collection readers and writers."""
|
Derived from util.audio.Sample and used by sample collection readers and writers."""
|
||||||
def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
|
def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
|
||||||
"""
|
"""
|
||||||
Creates an in-memory speech sample together with a transcript of the utterance (label).
|
Parameters
|
||||||
:param audio_type: See util.audio.Sample.__init__ .
|
----------
|
||||||
:param raw_data: See util.audio.Sample.__init__ .
|
audio_type : str
|
||||||
:param transcript: Transcript of the sample's utterance
|
See util.audio.Sample.__init__ .
|
||||||
:param audio_format: See util.audio.Sample.__init__ .
|
raw_data : binary
|
||||||
:param sample_id: Tracking ID - typically assigned by collection readers
|
See util.audio.Sample.__init__ .
|
||||||
|
transcript : str
|
||||||
|
Transcript of the sample's utterance
|
||||||
|
audio_format : tuple
|
||||||
|
See util.audio.Sample.__init__ .
|
||||||
|
sample_id : str
|
||||||
|
Tracking ID - should indicate sample's origin as precisely as possible.
|
||||||
|
It is typically assigned by collection readers.
|
||||||
"""
|
"""
|
||||||
super().__init__(audio_type, raw_data, audio_format=audio_format)
|
super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id)
|
||||||
self.sample_id = sample_id
|
|
||||||
self.transcript = transcript
|
self.transcript = transcript
|
||||||
|
|
||||||
|
|
||||||
class DirectSDBWriter:
|
class DirectSDBWriter:
|
||||||
"""Sample collection writer for creating a Sample DB (SDB) file"""
|
"""Sample collection writer for creating a Sample DB (SDB) file"""
|
||||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None):
|
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None, labeled=True):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sdb_filename : str
|
||||||
|
Path to the SDB file to write
|
||||||
|
buffering : int
|
||||||
|
Write-buffer size to use while writing the SDB file
|
||||||
|
audio_type : str
|
||||||
|
See util.audio.Sample.__init__ .
|
||||||
|
id_prefix : str
|
||||||
|
Prefix for IDs of written samples - defaults to sdb_filename
|
||||||
|
labeled : bool or None
|
||||||
|
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
|
||||||
|
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
|
||||||
|
"""
|
||||||
self.sdb_filename = sdb_filename
|
self.sdb_filename = sdb_filename
|
||||||
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
||||||
|
self.labeled = labeled
|
||||||
if audio_type not in SERIALIZABLE_AUDIO_TYPES:
|
if audio_type not in SERIALIZABLE_AUDIO_TYPES:
|
||||||
raise ValueError('Audio type "{}" not supported'.format(audio_type))
|
raise ValueError('Audio type "{}" not supported'.format(audio_type))
|
||||||
self.audio_type = audio_type
|
self.audio_type = audio_type
|
||||||
|
@ -55,12 +77,10 @@ class DirectSDBWriter:
|
||||||
|
|
||||||
self.sdb_file.write(MAGIC)
|
self.sdb_file.write(MAGIC)
|
||||||
|
|
||||||
meta_data = {
|
schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}]
|
||||||
SCHEMA_KEY: [
|
if self.labeled:
|
||||||
{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type},
|
schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT})
|
||||||
{CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
|
meta_data = {SCHEMA_KEY: schema_entries}
|
||||||
]
|
|
||||||
}
|
|
||||||
meta_data = json.dumps(meta_data).encode()
|
meta_data = json.dumps(meta_data).encode()
|
||||||
self.write_big_int(len(meta_data))
|
self.write_big_int(len(meta_data))
|
||||||
self.sdb_file.write(meta_data)
|
self.sdb_file.write(meta_data)
|
||||||
|
@ -83,10 +103,14 @@ class DirectSDBWriter:
|
||||||
sample.change_audio_type(self.audio_type)
|
sample.change_audio_type(self.audio_type)
|
||||||
opus = sample.audio.getbuffer()
|
opus = sample.audio.getbuffer()
|
||||||
opus_len = to_bytes(len(opus))
|
opus_len = to_bytes(len(opus))
|
||||||
|
if self.labeled:
|
||||||
transcript = sample.transcript.encode()
|
transcript = sample.transcript.encode()
|
||||||
transcript_len = to_bytes(len(transcript))
|
transcript_len = to_bytes(len(transcript))
|
||||||
entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
|
entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
|
||||||
buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
|
buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
|
||||||
|
else:
|
||||||
|
entry_len = to_bytes(len(opus_len) + len(opus))
|
||||||
|
buffer = b''.join([entry_len, opus_len, opus])
|
||||||
self.offsets.append(self.sdb_file.tell())
|
self.offsets.append(self.sdb_file.tell())
|
||||||
self.sdb_file.write(buffer)
|
self.sdb_file.write(buffer)
|
||||||
sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
|
sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
|
||||||
|
@ -120,7 +144,22 @@ class DirectSDBWriter:
|
||||||
|
|
||||||
class SDB: # pylint: disable=too-many-instance-attributes
|
class SDB: # pylint: disable=too-many-instance-attributes
|
||||||
"""Sample collection reader for reading a Sample DB (SDB) file"""
|
"""Sample collection reader for reading a Sample DB (SDB) file"""
|
||||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None):
|
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sdb_filename : str
|
||||||
|
Path to the SDB file to read samples from
|
||||||
|
buffering : int
|
||||||
|
Read-buffer size to use while reading the SDB file
|
||||||
|
id_prefix : str
|
||||||
|
Prefix for IDs of read samples - defaults to sdb_filename
|
||||||
|
labeled : bool or None
|
||||||
|
If True: Reads util.sample_collections.LabeledSample instances. Fails, if SDB file provides no transcripts.
|
||||||
|
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||||
|
If None: Automatically determines if SDB schema has transcripts
|
||||||
|
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||||
|
"""
|
||||||
self.sdb_filename = sdb_filename
|
self.sdb_filename = sdb_filename
|
||||||
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
||||||
self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
|
self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
|
||||||
|
@ -139,10 +178,14 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
||||||
self.speech_index = speech_columns[0]
|
self.speech_index = speech_columns[0]
|
||||||
self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
|
self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
|
||||||
|
|
||||||
|
self.transcript_index = None
|
||||||
|
if labeled is not False:
|
||||||
transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
|
transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
|
||||||
if not transcript_columns:
|
if transcript_columns:
|
||||||
raise RuntimeError('No transcript data (missing in schema)')
|
|
||||||
self.transcript_index = transcript_columns[0]
|
self.transcript_index = transcript_columns[0]
|
||||||
|
else:
|
||||||
|
if labeled is True:
|
||||||
|
raise RuntimeError('No transcript data (missing in schema)')
|
||||||
|
|
||||||
sample_chunk_len = self.read_big_int()
|
sample_chunk_len = self.read_big_int()
|
||||||
self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
|
self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
|
||||||
|
@ -194,9 +237,12 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
||||||
return tuple(column_data)
|
return tuple(column_data)
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
|
sample_id = '{}:{}'.format(self.id_prefix, i)
|
||||||
|
if self.transcript_index is None:
|
||||||
|
[audio_data] = self.read_row(i, self.speech_index)
|
||||||
|
return Sample(self.audio_type, audio_data, sample_id=sample_id)
|
||||||
audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
|
audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
|
||||||
transcript = transcript.decode()
|
transcript = transcript.decode()
|
||||||
sample_id = '{}:{}'.format(self.id_prefix, i)
|
|
||||||
return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)
|
return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
@ -215,24 +261,50 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
||||||
|
|
||||||
|
|
||||||
class CSV:
|
class CSV:
|
||||||
"""Sample collection reader for reading a DeepSpeech CSV file"""
|
"""Sample collection reader for reading a DeepSpeech CSV file
|
||||||
def __init__(self, csv_filename):
|
Automatically orders samples by CSV column wav_filesize (if available)."""
|
||||||
|
def __init__(self, csv_filename, labeled=None):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
csv_filename : str
|
||||||
|
Path to the CSV file containing sample audio paths and transcripts
|
||||||
|
labeled : bool or None
|
||||||
|
If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column.
|
||||||
|
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||||
|
If None: Automatically determines if CSV file has a transcript column
|
||||||
|
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||||
|
"""
|
||||||
self.csv_filename = csv_filename
|
self.csv_filename = csv_filename
|
||||||
|
self.labeled = labeled
|
||||||
self.rows = []
|
self.rows = []
|
||||||
csv_dir = Path(csv_filename).parent
|
csv_dir = Path(csv_filename).parent
|
||||||
with open(csv_filename, 'r', encoding='utf8') as csv_file:
|
with open(csv_filename, 'r', encoding='utf8') as csv_file:
|
||||||
reader = csv.DictReader(csv_file)
|
reader = csv.DictReader(csv_file)
|
||||||
|
if 'transcript' in reader.fieldnames:
|
||||||
|
if self.labeled is None:
|
||||||
|
self.labeled = True
|
||||||
|
elif self.labeled:
|
||||||
|
raise RuntimeError('No transcript data (missing CSV column)')
|
||||||
for row in reader:
|
for row in reader:
|
||||||
wav_filename = Path(row['wav_filename'])
|
wav_filename = Path(row['wav_filename'])
|
||||||
if not wav_filename.is_absolute():
|
if not wav_filename.is_absolute():
|
||||||
wav_filename = csv_dir / wav_filename
|
wav_filename = csv_dir / wav_filename
|
||||||
self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript']))
|
wav_filename = str(wav_filename)
|
||||||
|
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
|
||||||
|
if self.labeled:
|
||||||
|
self.rows.append((wav_filename, wav_filesize, row['transcript']))
|
||||||
|
else:
|
||||||
|
self.rows.append((wav_filename, wav_filesize))
|
||||||
self.rows.sort(key=lambda r: r[1])
|
self.rows.sort(key=lambda r: r[1])
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
wav_filename, _, transcript = self.rows[i]
|
row = self.rows[i]
|
||||||
|
wav_filename = row[0]
|
||||||
with open(wav_filename, 'rb') as wav_file:
|
with open(wav_filename, 'rb') as wav_file:
|
||||||
return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename)
|
if self.labeled:
|
||||||
|
return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), row[2], sample_id=wav_filename)
|
||||||
|
return Sample(AUDIO_TYPE_WAV, wav_file.read(), sample_id=wav_filename)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
for i in range(len(self.rows)):
|
for i in range(len(self.rows)):
|
||||||
|
@ -242,21 +314,52 @@ class CSV:
|
||||||
return len(self.rows)
|
return len(self.rows)
|
||||||
|
|
||||||
|
|
||||||
def samples_from_file(filename, buffering=BUFFER_SIZE):
|
def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None):
|
||||||
"""Returns an iterable of LabeledSample objects loaded from a file."""
|
"""
|
||||||
|
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
|
||||||
|
loaded from a sample source file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename : str
|
||||||
|
Path to the sample source file (SDB or CSV)
|
||||||
|
buffering : int
|
||||||
|
Read-buffer size to use while reading files
|
||||||
|
labeled : bool or None
|
||||||
|
If True: Reads LabeledSample instances. Fails, if source provides no transcripts.
|
||||||
|
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||||
|
If None: Automatically determines if source provides transcripts
|
||||||
|
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||||
|
"""
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
if ext == '.sdb':
|
if ext == '.sdb':
|
||||||
return SDB(filename, buffering=buffering)
|
return SDB(filename, buffering=buffering, labeled=labeled)
|
||||||
if ext == '.csv':
|
if ext == '.csv':
|
||||||
return CSV(filename)
|
return CSV(filename, labeled=labeled)
|
||||||
raise ValueError('Unknown file type: "{}"'.format(ext))
|
raise ValueError('Unknown file type: "{}"'.format(ext))
|
||||||
|
|
||||||
|
|
||||||
def samples_from_files(filenames, buffering=BUFFER_SIZE):
|
def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None):
|
||||||
"""Returns an iterable of LabeledSample objects from a list of files."""
|
"""
|
||||||
|
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
|
||||||
|
loaded from a collection of sample source files.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filenames : list of str
|
||||||
|
Paths to sample source files (SDBs or CSVs)
|
||||||
|
buffering : int
|
||||||
|
Read-buffer size to use while reading files
|
||||||
|
labeled : bool or None
|
||||||
|
If True: Reads LabeledSample instances. Fails, if not all sources provide transcripts.
|
||||||
|
If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
|
||||||
|
If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
|
||||||
|
util.audio.Sample instances from sources with no transcripts.
|
||||||
|
"""
|
||||||
|
filenames = list(filenames)
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError('No files')
|
raise ValueError('No files')
|
||||||
if len(filenames) == 1:
|
if len(filenames) == 1:
|
||||||
return samples_from_file(filenames[0], buffering=buffering)
|
return samples_from_file(filenames[0], buffering=buffering, labeled=labeled)
|
||||||
cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
|
cols = list(map(partial(samples_from_file, buffering=buffering, labeled=labeled), filenames))
|
||||||
return Interleaved(*cols, key=lambda s: s.duration)
|
return Interleaved(*cols, key=lambda s: s.duration)
|
||||||
|
|
Loading…
Reference in New Issue