Fix #2830 - Support for unlabeled samples
This commit is contained in:
parent
5740d64e6e
commit
41da7b2870
|
@ -26,8 +26,8 @@ AUDIO_TYPE_LOOKUP = {
|
|||
|
||||
def build_sdb():
|
||||
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
|
||||
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer:
|
||||
samples = samples_from_files(CLI_ARGS.sources)
|
||||
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
|
||||
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
|
||||
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
|
||||
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
|
||||
sdb_writer.add(sample)
|
||||
|
@ -36,13 +36,18 @@ def build_sdb():
|
|||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
|
||||
'from DeepSpeech CSV files and other SDB files')
|
||||
parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - '
|
||||
'Note: For getting a correctly ordered target SDB, source SDBs have '
|
||||
'to have their samples already ordered from shortest to longest.')
|
||||
parser.add_argument('sources', nargs='+',
|
||||
help='Source CSV and/or SDB files - '
|
||||
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
|
||||
'already ordered from shortest to longest.')
|
||||
parser.add_argument('target', help='SDB file to create')
|
||||
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
|
||||
help='Audio representation inside target SDB')
|
||||
parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers')
|
||||
parser.add_argument('--workers', type=int, default=None,
|
||||
help='Number of encoding SDB workers')
|
||||
parser.add_argument('--unlabeled', action='store_true',
|
||||
help='If to build an SDB with unlabeled (audio only) samples - '
|
||||
'typically used for building noise augmentation corpora')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|||
import random
|
||||
import argparse
|
||||
|
||||
from util.sample_collections import samples_from_file
|
||||
from util.sample_collections import samples_from_file, LabeledSample
|
||||
from util.audio import AUDIO_TYPE_PCM
|
||||
|
||||
|
||||
|
@ -28,6 +28,7 @@ def play_sample(samples, index):
|
|||
sys.exit(1)
|
||||
sample = samples[index]
|
||||
print('Sample "{}"'.format(sample.sample_id))
|
||||
if isinstance(sample, LabeledSample):
|
||||
print(' "{}"'.format(sample.transcript))
|
||||
sample.change_audio_type(AUDIO_TYPE_PCM)
|
||||
rate, channels, width = sample.audio_format
|
||||
|
|
|
@ -26,31 +26,45 @@ OPUS_CHUNK_LEN_SIZE = 2
|
|||
|
||||
|
||||
class Sample:
|
||||
"""Represents in-memory audio data of a certain (convertible) representation.
|
||||
Attributes:
|
||||
audio_type (str): See `__init__`.
|
||||
audio_format (tuple:(int, int, int)): See `__init__`.
|
||||
audio (obj): Audio data represented as indicated by `audio_type`
|
||||
duration (float): Audio duration of the sample in seconds
|
||||
"""
|
||||
def __init__(self, audio_type, raw_data, audio_format=None):
|
||||
Represents in-memory audio data of a certain (convertible) representation.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
audio_type : str
|
||||
See `__init__`.
|
||||
audio_format : tuple:(int, int, int)
|
||||
See `__init__`.
|
||||
audio : binary
|
||||
Audio data represented as indicated by `audio_type`
|
||||
duration : float
|
||||
Audio duration of the sample in seconds
|
||||
"""
|
||||
Creates a Sample from a raw audio representation.
|
||||
:param audio_type: Audio data representation type
|
||||
def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
audio_type : str
|
||||
Audio data representation type
|
||||
Supported types:
|
||||
- AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
|
||||
- util.audio.AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
|
||||
wrapped by a custom container format (used in SDBs)
|
||||
- AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
|
||||
- AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
|
||||
- AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
|
||||
:param raw_data: Audio data in the form of the provided representation type (see audio_type).
|
||||
For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray.
|
||||
:param audio_format: Tuple of sample-rate, number of channels and sample-width.
|
||||
Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP,
|
||||
- util.audio.AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
|
||||
- util.audio.AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
|
||||
- util.audio.AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
|
||||
raw_data : binary
|
||||
Audio data in the form of the provided representation type (see audio_type).
|
||||
For types util.audio.AUDIO_TYPE_OPUS or util.audio.AUDIO_TYPE_WAV data can also be passed as a bytearray.
|
||||
audio_format : tuple
|
||||
Tuple of sample-rate, number of channels and sample-width.
|
||||
Required in case of audio_type = util.audio.AUDIO_TYPE_PCM or util.audio.AUDIO_TYPE_NP,
|
||||
as this information cannot be derived from raw audio data.
|
||||
sample_id : str
|
||||
Tracking ID - should indicate sample's origin as precisely as possible
|
||||
"""
|
||||
self.audio_type = audio_type
|
||||
self.audio_format = audio_format
|
||||
self.sample_id = sample_id
|
||||
if audio_type in SERIALIZABLE_AUDIO_TYPES:
|
||||
self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
|
||||
self.duration = read_duration(audio_type, self.audio)
|
||||
|
@ -68,7 +82,11 @@ class Sample:
|
|||
def change_audio_type(self, new_audio_type):
|
||||
"""
|
||||
In-place conversion of audio data into a different representation.
|
||||
:param new_audio_type: New audio-type - see `__init__`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_audio_type : str
|
||||
New audio-type - see `__init__`.
|
||||
Not supported: Converting from AUDIO_TYPE_NP into any other type.
|
||||
"""
|
||||
if self.audio_type == new_audio_type:
|
||||
|
|
|
@ -116,7 +116,7 @@ def create_dataset(sources,
|
|||
process_ahead=None,
|
||||
buffering=1 * MEGABYTE):
|
||||
def generate_values():
|
||||
samples = samples_from_files(sources, buffering=buffering)
|
||||
samples = samples_from_files(sources, buffering=buffering, labeled=True)
|
||||
for sample in change_audio_types(samples,
|
||||
AUDIO_TYPE_NP,
|
||||
process_ahead=2 * batch_size if process_ahead is None else process_ahead):
|
||||
|
|
|
@ -29,23 +29,45 @@ class LabeledSample(Sample):
|
|||
Derived from util.audio.Sample and used by sample collection readers and writers."""
|
||||
def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
|
||||
"""
|
||||
Creates an in-memory speech sample together with a transcript of the utterance (label).
|
||||
:param audio_type: See util.audio.Sample.__init__ .
|
||||
:param raw_data: See util.audio.Sample.__init__ .
|
||||
:param transcript: Transcript of the sample's utterance
|
||||
:param audio_format: See util.audio.Sample.__init__ .
|
||||
:param sample_id: Tracking ID - typically assigned by collection readers
|
||||
Parameters
|
||||
----------
|
||||
audio_type : str
|
||||
See util.audio.Sample.__init__ .
|
||||
raw_data : binary
|
||||
See util.audio.Sample.__init__ .
|
||||
transcript : str
|
||||
Transcript of the sample's utterance
|
||||
audio_format : tuple
|
||||
See util.audio.Sample.__init__ .
|
||||
sample_id : str
|
||||
Tracking ID - should indicate sample's origin as precisely as possible.
|
||||
It is typically assigned by collection readers.
|
||||
"""
|
||||
super().__init__(audio_type, raw_data, audio_format=audio_format)
|
||||
self.sample_id = sample_id
|
||||
super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id)
|
||||
self.transcript = transcript
|
||||
|
||||
|
||||
class DirectSDBWriter:
|
||||
"""Sample collection writer for creating a Sample DB (SDB) file"""
|
||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None):
|
||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None, labeled=True):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sdb_filename : str
|
||||
Path to the SDB file to write
|
||||
buffering : int
|
||||
Write-buffer size to use while writing the SDB file
|
||||
audio_type : str
|
||||
See util.audio.Sample.__init__ .
|
||||
id_prefix : str
|
||||
Prefix for IDs of written samples - defaults to sdb_filename
|
||||
labeled : bool or None
|
||||
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
|
||||
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
|
||||
"""
|
||||
self.sdb_filename = sdb_filename
|
||||
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
||||
self.labeled = labeled
|
||||
if audio_type not in SERIALIZABLE_AUDIO_TYPES:
|
||||
raise ValueError('Audio type "{}" not supported'.format(audio_type))
|
||||
self.audio_type = audio_type
|
||||
|
@ -55,12 +77,10 @@ class DirectSDBWriter:
|
|||
|
||||
self.sdb_file.write(MAGIC)
|
||||
|
||||
meta_data = {
|
||||
SCHEMA_KEY: [
|
||||
{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type},
|
||||
{CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
|
||||
]
|
||||
}
|
||||
schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}]
|
||||
if self.labeled:
|
||||
schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT})
|
||||
meta_data = {SCHEMA_KEY: schema_entries}
|
||||
meta_data = json.dumps(meta_data).encode()
|
||||
self.write_big_int(len(meta_data))
|
||||
self.sdb_file.write(meta_data)
|
||||
|
@ -83,10 +103,14 @@ class DirectSDBWriter:
|
|||
sample.change_audio_type(self.audio_type)
|
||||
opus = sample.audio.getbuffer()
|
||||
opus_len = to_bytes(len(opus))
|
||||
if self.labeled:
|
||||
transcript = sample.transcript.encode()
|
||||
transcript_len = to_bytes(len(transcript))
|
||||
entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
|
||||
buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
|
||||
else:
|
||||
entry_len = to_bytes(len(opus_len) + len(opus))
|
||||
buffer = b''.join([entry_len, opus_len, opus])
|
||||
self.offsets.append(self.sdb_file.tell())
|
||||
self.sdb_file.write(buffer)
|
||||
sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
|
||||
|
@ -120,7 +144,22 @@ class DirectSDBWriter:
|
|||
|
||||
class SDB: # pylint: disable=too-many-instance-attributes
|
||||
"""Sample collection reader for reading a Sample DB (SDB) file"""
|
||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None):
|
||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sdb_filename : str
|
||||
Path to the SDB file to read samples from
|
||||
buffering : int
|
||||
Read-buffer size to use while reading the SDB file
|
||||
id_prefix : str
|
||||
Prefix for IDs of read samples - defaults to sdb_filename
|
||||
labeled : bool or None
|
||||
If True: Reads util.sample_collections.LabeledSample instances. Fails, if SDB file provides no transcripts.
|
||||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||
If None: Automatically determines if SDB schema has transcripts
|
||||
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||
"""
|
||||
self.sdb_filename = sdb_filename
|
||||
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
||||
self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
|
||||
|
@ -139,10 +178,14 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
|||
self.speech_index = speech_columns[0]
|
||||
self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
|
||||
|
||||
self.transcript_index = None
|
||||
if labeled is not False:
|
||||
transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
|
||||
if not transcript_columns:
|
||||
raise RuntimeError('No transcript data (missing in schema)')
|
||||
if transcript_columns:
|
||||
self.transcript_index = transcript_columns[0]
|
||||
else:
|
||||
if labeled is True:
|
||||
raise RuntimeError('No transcript data (missing in schema)')
|
||||
|
||||
sample_chunk_len = self.read_big_int()
|
||||
self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
|
||||
|
@ -194,9 +237,12 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
|||
return tuple(column_data)
|
||||
|
||||
def __getitem__(self, i):
|
||||
sample_id = '{}:{}'.format(self.id_prefix, i)
|
||||
if self.transcript_index is None:
|
||||
[audio_data] = self.read_row(i, self.speech_index)
|
||||
return Sample(self.audio_type, audio_data, sample_id=sample_id)
|
||||
audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
|
||||
transcript = transcript.decode()
|
||||
sample_id = '{}:{}'.format(self.id_prefix, i)
|
||||
return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -215,24 +261,50 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
|||
|
||||
|
||||
class CSV:
|
||||
"""Sample collection reader for reading a DeepSpeech CSV file"""
|
||||
def __init__(self, csv_filename):
|
||||
"""Sample collection reader for reading a DeepSpeech CSV file
|
||||
Automatically orders samples by CSV column wav_filesize (if available)."""
|
||||
def __init__(self, csv_filename, labeled=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
csv_filename : str
|
||||
Path to the CSV file containing sample audio paths and transcripts
|
||||
labeled : bool or None
|
||||
If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column.
|
||||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||
If None: Automatically determines if CSV file has a transcript column
|
||||
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||
"""
|
||||
self.csv_filename = csv_filename
|
||||
self.labeled = labeled
|
||||
self.rows = []
|
||||
csv_dir = Path(csv_filename).parent
|
||||
with open(csv_filename, 'r', encoding='utf8') as csv_file:
|
||||
reader = csv.DictReader(csv_file)
|
||||
if 'transcript' in reader.fieldnames:
|
||||
if self.labeled is None:
|
||||
self.labeled = True
|
||||
elif self.labeled:
|
||||
raise RuntimeError('No transcript data (missing CSV column)')
|
||||
for row in reader:
|
||||
wav_filename = Path(row['wav_filename'])
|
||||
if not wav_filename.is_absolute():
|
||||
wav_filename = csv_dir / wav_filename
|
||||
self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript']))
|
||||
wav_filename = str(wav_filename)
|
||||
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
|
||||
if self.labeled:
|
||||
self.rows.append((wav_filename, wav_filesize, row['transcript']))
|
||||
else:
|
||||
self.rows.append((wav_filename, wav_filesize))
|
||||
self.rows.sort(key=lambda r: r[1])
|
||||
|
||||
def __getitem__(self, i):
|
||||
wav_filename, _, transcript = self.rows[i]
|
||||
row = self.rows[i]
|
||||
wav_filename = row[0]
|
||||
with open(wav_filename, 'rb') as wav_file:
|
||||
return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename)
|
||||
if self.labeled:
|
||||
return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), row[2], sample_id=wav_filename)
|
||||
return Sample(AUDIO_TYPE_WAV, wav_file.read(), sample_id=wav_filename)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.rows)):
|
||||
|
@ -242,21 +314,52 @@ class CSV:
|
|||
return len(self.rows)
|
||||
|
||||
|
||||
def samples_from_file(filename, buffering=BUFFER_SIZE):
|
||||
"""Returns an iterable of LabeledSample objects loaded from a file."""
|
||||
def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None):
|
||||
"""
|
||||
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
|
||||
loaded from a sample source file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str
|
||||
Path to the sample source file (SDB or CSV)
|
||||
buffering : int
|
||||
Read-buffer size to use while reading files
|
||||
labeled : bool or None
|
||||
If True: Reads LabeledSample instances. Fails, if source provides no transcripts.
|
||||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||
If None: Automatically determines if source provides transcripts
|
||||
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||
"""
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext == '.sdb':
|
||||
return SDB(filename, buffering=buffering)
|
||||
return SDB(filename, buffering=buffering, labeled=labeled)
|
||||
if ext == '.csv':
|
||||
return CSV(filename)
|
||||
return CSV(filename, labeled=labeled)
|
||||
raise ValueError('Unknown file type: "{}"'.format(ext))
|
||||
|
||||
|
||||
def samples_from_files(filenames, buffering=BUFFER_SIZE):
|
||||
"""Returns an iterable of LabeledSample objects from a list of files."""
|
||||
def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None):
|
||||
"""
|
||||
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
|
||||
loaded from a collection of sample source files.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filenames : list of str
|
||||
Paths to sample source files (SDBs or CSVs)
|
||||
buffering : int
|
||||
Read-buffer size to use while reading files
|
||||
labeled : bool or None
|
||||
If True: Reads LabeledSample instances. Fails, if not all sources provide transcripts.
|
||||
If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
|
||||
If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
|
||||
util.audio.Sample instances from sources with no transcripts.
|
||||
"""
|
||||
filenames = list(filenames)
|
||||
if len(filenames) == 0:
|
||||
raise ValueError('No files')
|
||||
if len(filenames) == 1:
|
||||
return samples_from_file(filenames[0], buffering=buffering)
|
||||
cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
|
||||
return samples_from_file(filenames[0], buffering=buffering, labeled=labeled)
|
||||
cols = list(map(partial(samples_from_file, buffering=buffering, labeled=labeled), filenames))
|
||||
return Interleaved(*cols, key=lambda s: s.duration)
|
||||
|
|
Loading…
Reference in New Issue