Merge pull request #2849 from tilmankamp/unlabeled-samples

Fix #2830 - Support for unlabeled samples
This commit is contained in:
Tilman Kamp 2020-03-24 18:20:22 +01:00 committed by GitHub
commit 8088b574fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 192 additions and 65 deletions

View File

@ -26,8 +26,8 @@ AUDIO_TYPE_LOOKUP = {
def build_sdb():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources)
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
sdb_writer.add(sample)
@ -36,13 +36,18 @@ def build_sdb():
def handle_args():
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
'from DeepSpeech CSV files and other SDB files')
parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - '
'Note: For getting a correctly ordered target SDB, source SDBs have '
'to have their samples already ordered from shortest to longest.')
parser.add_argument('sources', nargs='+',
help='Source CSV and/or SDB files - '
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
'already ordered from shortest to longest.')
parser.add_argument('target', help='SDB file to create')
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
help='Audio representation inside target SDB')
parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers')
parser.add_argument('--workers', type=int, default=None,
help='Number of encoding SDB workers')
parser.add_argument('--unlabeled', action='store_true',
help='If to build an SDB with unlabeled (audio only) samples - '
'typically used for building noise augmentation corpora')
return parser.parse_args()

View File

@ -14,7 +14,7 @@ sys.path.insert(1, os.path.join(sys.path[0], '..'))
import random
import argparse
from util.sample_collections import samples_from_file
from util.sample_collections import samples_from_file, LabeledSample
from util.audio import AUDIO_TYPE_PCM
@ -28,7 +28,8 @@ def play_sample(samples, index):
sys.exit(1)
sample = samples[index]
print('Sample "{}"'.format(sample.sample_id))
print(' "{}"'.format(sample.transcript))
if isinstance(sample, LabeledSample):
print(' "{}"'.format(sample.transcript))
sample.change_audio_type(AUDIO_TYPE_PCM)
rate, channels, width = sample.audio_format
wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate)

View File

@ -26,31 +26,45 @@ OPUS_CHUNK_LEN_SIZE = 2
class Sample:
"""Represents in-memory audio data of a certain (convertible) representation.
Attributes:
audio_type (str): See `__init__`.
audio_format (tuple:(int, int, int)): See `__init__`.
audio (obj): Audio data represented as indicated by `audio_type`
duration (float): Audio duration of the sample in seconds
"""
def __init__(self, audio_type, raw_data, audio_format=None):
Represents in-memory audio data of a certain (convertible) representation.
Attributes
----------
audio_type : str
See `__init__`.
audio_format : tuple:(int, int, int)
See `__init__`.
audio : binary
Audio data represented as indicated by `audio_type`
duration : float
Audio duration of the sample in seconds
"""
def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None):
"""
Creates a Sample from a raw audio representation.
:param audio_type: Audio data representation type
Parameters
----------
audio_type : str
Audio data representation type
Supported types:
- AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
- util.audio.AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
wrapped by a custom container format (used in SDBs)
- AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
- AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
- AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
:param raw_data: Audio data in the form of the provided representation type (see audio_type).
For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray.
:param audio_format: Tuple of sample-rate, number of channels and sample-width.
Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP,
- util.audio.AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
- util.audio.AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
- util.audio.AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
raw_data : binary
Audio data in the form of the provided representation type (see audio_type).
For types util.audio.AUDIO_TYPE_OPUS or util.audio.AUDIO_TYPE_WAV data can also be passed as a bytearray.
audio_format : tuple
Tuple of sample-rate, number of channels and sample-width.
Required in case of audio_type = util.audio.AUDIO_TYPE_PCM or util.audio.AUDIO_TYPE_NP,
as this information cannot be derived from raw audio data.
sample_id : str
Tracking ID - should indicate sample's origin as precisely as possible
"""
self.audio_type = audio_type
self.audio_format = audio_format
self.sample_id = sample_id
if audio_type in SERIALIZABLE_AUDIO_TYPES:
self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
self.duration = read_duration(audio_type, self.audio)
@ -68,7 +82,11 @@ class Sample:
def change_audio_type(self, new_audio_type):
"""
In-place conversion of audio data into a different representation.
:param new_audio_type: New audio-type - see `__init__`.
Parameters
----------
new_audio_type : str
New audio-type - see `__init__`.
Not supported: Converting from AUDIO_TYPE_NP into any other type.
"""
if self.audio_type == new_audio_type:

View File

@ -116,7 +116,7 @@ def create_dataset(sources,
process_ahead=None,
buffering=1 * MEGABYTE):
def generate_values():
samples = samples_from_files(sources, buffering=buffering)
samples = samples_from_files(sources, buffering=buffering, labeled=True)
for sample in change_audio_types(samples,
AUDIO_TYPE_NP,
process_ahead=2 * batch_size if process_ahead is None else process_ahead):

View File

@ -29,23 +29,45 @@ class LabeledSample(Sample):
Derived from util.audio.Sample and used by sample collection readers and writers."""
def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
"""
Creates an in-memory speech sample together with a transcript of the utterance (label).
:param audio_type: See util.audio.Sample.__init__ .
:param raw_data: See util.audio.Sample.__init__ .
:param transcript: Transcript of the sample's utterance
:param audio_format: See util.audio.Sample.__init__ .
:param sample_id: Tracking ID - typically assigned by collection readers
Parameters
----------
audio_type : str
See util.audio.Sample.__init__ .
raw_data : binary
See util.audio.Sample.__init__ .
transcript : str
Transcript of the sample's utterance
audio_format : tuple
See util.audio.Sample.__init__ .
sample_id : str
Tracking ID - should indicate sample's origin as precisely as possible.
It is typically assigned by collection readers.
"""
super().__init__(audio_type, raw_data, audio_format=audio_format)
self.sample_id = sample_id
super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id)
self.transcript = transcript
class DirectSDBWriter:
"""Sample collection writer for creating a Sample DB (SDB) file"""
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None):
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None, labeled=True):
"""
Parameters
----------
sdb_filename : str
Path to the SDB file to write
buffering : int
Write-buffer size to use while writing the SDB file
audio_type : str
See util.audio.Sample.__init__ .
id_prefix : str
Prefix for IDs of written samples - defaults to sdb_filename
labeled : bool or None
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
"""
self.sdb_filename = sdb_filename
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
self.labeled = labeled
if audio_type not in SERIALIZABLE_AUDIO_TYPES:
raise ValueError('Audio type "{}" not supported'.format(audio_type))
self.audio_type = audio_type
@ -55,12 +77,10 @@ class DirectSDBWriter:
self.sdb_file.write(MAGIC)
meta_data = {
SCHEMA_KEY: [
{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type},
{CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
]
}
schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}]
if self.labeled:
schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT})
meta_data = {SCHEMA_KEY: schema_entries}
meta_data = json.dumps(meta_data).encode()
self.write_big_int(len(meta_data))
self.sdb_file.write(meta_data)
@ -83,10 +103,14 @@ class DirectSDBWriter:
sample.change_audio_type(self.audio_type)
opus = sample.audio.getbuffer()
opus_len = to_bytes(len(opus))
transcript = sample.transcript.encode()
transcript_len = to_bytes(len(transcript))
entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
if self.labeled:
transcript = sample.transcript.encode()
transcript_len = to_bytes(len(transcript))
entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
else:
entry_len = to_bytes(len(opus_len) + len(opus))
buffer = b''.join([entry_len, opus_len, opus])
self.offsets.append(self.sdb_file.tell())
self.sdb_file.write(buffer)
sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
@ -120,7 +144,22 @@ class DirectSDBWriter:
class SDB: # pylint: disable=too-many-instance-attributes
"""Sample collection reader for reading a Sample DB (SDB) file"""
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None):
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
"""
Parameters
----------
sdb_filename : str
Path to the SDB file to read samples from
buffering : int
Read-buffer size to use while reading the SDB file
id_prefix : str
Prefix for IDs of read samples - defaults to sdb_filename
labeled : bool or None
If True: Reads util.sample_collections.LabeledSample instances. Fails, if SDB file provides no transcripts.
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
If None: Automatically determines if SDB schema has transcripts
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
"""
self.sdb_filename = sdb_filename
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
@ -139,10 +178,14 @@ class SDB: # pylint: disable=too-many-instance-attributes
self.speech_index = speech_columns[0]
self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
if not transcript_columns:
raise RuntimeError('No transcript data (missing in schema)')
self.transcript_index = transcript_columns[0]
self.transcript_index = None
if labeled is not False:
transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
if transcript_columns:
self.transcript_index = transcript_columns[0]
else:
if labeled is True:
raise RuntimeError('No transcript data (missing in schema)')
sample_chunk_len = self.read_big_int()
self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
@ -194,9 +237,12 @@ class SDB: # pylint: disable=too-many-instance-attributes
return tuple(column_data)
def __getitem__(self, i):
sample_id = '{}:{}'.format(self.id_prefix, i)
if self.transcript_index is None:
[audio_data] = self.read_row(i, self.speech_index)
return Sample(self.audio_type, audio_data, sample_id=sample_id)
audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
transcript = transcript.decode()
sample_id = '{}:{}'.format(self.id_prefix, i)
return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)
def __iter__(self):
@ -215,24 +261,50 @@ class SDB: # pylint: disable=too-many-instance-attributes
class CSV:
"""Sample collection reader for reading a DeepSpeech CSV file"""
def __init__(self, csv_filename):
"""Sample collection reader for reading a DeepSpeech CSV file
Automatically orders samples by CSV column wav_filesize (if available)."""
def __init__(self, csv_filename, labeled=None):
"""
Parameters
----------
csv_filename : str
Path to the CSV file containing sample audio paths and transcripts
labeled : bool or None
If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column.
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
If None: Automatically determines if CSV file has a transcript column
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
"""
self.csv_filename = csv_filename
self.labeled = labeled
self.rows = []
csv_dir = Path(csv_filename).parent
with open(csv_filename, 'r', encoding='utf8') as csv_file:
reader = csv.DictReader(csv_file)
if 'transcript' in reader.fieldnames:
if self.labeled is None:
self.labeled = True
elif self.labeled:
raise RuntimeError('No transcript data (missing CSV column)')
for row in reader:
wav_filename = Path(row['wav_filename'])
if not wav_filename.is_absolute():
wav_filename = csv_dir / wav_filename
self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript']))
wav_filename = str(wav_filename)
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
if self.labeled:
self.rows.append((wav_filename, wav_filesize, row['transcript']))
else:
self.rows.append((wav_filename, wav_filesize))
self.rows.sort(key=lambda r: r[1])
def __getitem__(self, i):
wav_filename, _, transcript = self.rows[i]
row = self.rows[i]
wav_filename = row[0]
with open(wav_filename, 'rb') as wav_file:
return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename)
if self.labeled:
return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), row[2], sample_id=wav_filename)
return Sample(AUDIO_TYPE_WAV, wav_file.read(), sample_id=wav_filename)
def __iter__(self):
for i in range(len(self.rows)):
@ -242,21 +314,52 @@ class CSV:
return len(self.rows)
def samples_from_file(filename, buffering=BUFFER_SIZE):
"""Returns an iterable of LabeledSample objects loaded from a file."""
def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None):
"""
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
loaded from a sample source file.
Parameters
----------
filename : str
Path to the sample source file (SDB or CSV)
buffering : int
Read-buffer size to use while reading files
labeled : bool or None
If True: Reads LabeledSample instances. Fails, if source provides no transcripts.
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
If None: Automatically determines if source provides transcripts
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
"""
ext = os.path.splitext(filename)[1].lower()
if ext == '.sdb':
return SDB(filename, buffering=buffering)
return SDB(filename, buffering=buffering, labeled=labeled)
if ext == '.csv':
return CSV(filename)
return CSV(filename, labeled=labeled)
raise ValueError('Unknown file type: "{}"'.format(ext))
def samples_from_files(filenames, buffering=BUFFER_SIZE):
"""Returns an iterable of LabeledSample objects from a list of files."""
def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None):
"""
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
loaded from a collection of sample source files.
Parameters
----------
filenames : list of str
Paths to sample source files (SDBs or CSVs)
buffering : int
Read-buffer size to use while reading files
labeled : bool or None
If True: Reads LabeledSample instances. Fails, if not all sources provide transcripts.
If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
util.audio.Sample instances from sources with no transcripts.
"""
filenames = list(filenames)
if len(filenames) == 0:
raise ValueError('No files')
if len(filenames) == 1:
return samples_from_file(filenames[0], buffering=buffering)
cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
return samples_from_file(filenames[0], buffering=buffering, labeled=labeled)
cols = list(map(partial(samples_from_file, buffering=buffering, labeled=labeled), filenames))
return Interleaved(*cols, key=lambda s: s.duration)