From 41da7b287014f259e7c526471692bf25af70c9e0 Mon Sep 17 00:00:00 2001 From: Tilman Kamp <5991088+tilmankamp@users.noreply.github.com> Date: Mon, 23 Mar 2020 18:34:07 +0100 Subject: [PATCH] Fix #2830 - Support for unlabeled samples --- bin/build_sdb.py | 17 ++-- bin/play.py | 5 +- util/audio.py | 54 +++++++---- util/feeding.py | 2 +- util/sample_collections.py | 179 +++++++++++++++++++++++++++++-------- 5 files changed, 192 insertions(+), 65 deletions(-) diff --git a/bin/build_sdb.py b/bin/build_sdb.py index b4912972..b5fa8d35 100755 --- a/bin/build_sdb.py +++ b/bin/build_sdb.py @@ -26,8 +26,8 @@ AUDIO_TYPE_LOOKUP = { def build_sdb(): audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] - with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer: - samples = samples_from_files(CLI_ARGS.sources) + with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer: + samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR) for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)): sdb_writer.add(sample) @@ -36,13 +36,18 @@ def build_sdb(): def handle_args(): parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) ' 'from DeepSpeech CSV files and other SDB files') - parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - ' - 'Note: For getting a correctly ordered target SDB, source SDBs have ' - 'to have their samples already ordered from shortest to longest.') + parser.add_argument('sources', nargs='+', + help='Source CSV and/or SDB files - ' + 'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples ' + 'already ordered from shortest to longest.') parser.add_argument('target', help='SDB file to create') parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(), help='Audio representation inside target SDB') - parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers') + parser.add_argument('--workers', type=int, default=None, + help='Number of encoding SDB workers') + parser.add_argument('--unlabeled', action='store_true', + help='If to build an SDB with unlabeled (audio only) samples - ' + 'typically used for building noise augmentation corpora') return parser.parse_args() diff --git a/bin/play.py b/bin/play.py index 55da4bc5..180d0b00 100755 --- a/bin/play.py +++ b/bin/play.py @@ -14,7 +14,7 @@ sys.path.insert(1, os.path.join(sys.path[0], '..')) import random import argparse -from util.sample_collections import samples_from_file +from util.sample_collections import samples_from_file, LabeledSample from util.audio import AUDIO_TYPE_PCM @@ -28,7 +28,8 @@ def play_sample(samples, index): sys.exit(1) sample = samples[index] print('Sample "{}"'.format(sample.sample_id)) - print(' "{}"'.format(sample.transcript)) + if isinstance(sample, LabeledSample): + print(' "{}"'.format(sample.transcript)) sample.change_audio_type(AUDIO_TYPE_PCM) rate, channels, width = sample.audio_format wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate) diff --git a/util/audio.py b/util/audio.py index e2469e8a..9c6ed94e 100644 --- a/util/audio.py +++ b/util/audio.py @@ -26,31 +26,45 @@ OPUS_CHUNK_LEN_SIZE = 2 class Sample: - """Represents in-memory audio data of a certain (convertible) representation. - Attributes: - audio_type (str): See `__init__`. - audio_format (tuple:(int, int, int)): See `__init__`. - audio (obj): Audio data represented as indicated by `audio_type` - duration (float): Audio duration of the sample in seconds """ - def __init__(self, audio_type, raw_data, audio_format=None): + Represents in-memory audio data of a certain (convertible) representation. + + Attributes + ---------- + audio_type : str + See `__init__`. + audio_format : tuple:(int, int, int) + See `__init__`. + audio : binary + Audio data represented as indicated by `audio_type` + duration : float + Audio duration of the sample in seconds + """ + def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None): """ - Creates a Sample from a raw audio representation. - :param audio_type: Audio data representation type + Parameters + ---------- + audio_type : str + Audio data representation type Supported types: - - AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio + - util.audio.AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio wrapped by a custom container format (used in SDBs) - - AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file - - AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header) - - AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding - :param raw_data: Audio data in the form of the provided representation type (see audio_type). - For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray. - :param audio_format: Tuple of sample-rate, number of channels and sample-width. - Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP, + - util.audio.AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file + - util.audio.AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header) + - util.audio.AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding + raw_data : binary + Audio data in the form of the provided representation type (see audio_type). + For types util.audio.AUDIO_TYPE_OPUS or util.audio.AUDIO_TYPE_WAV data can also be passed as a bytearray. + audio_format : tuple + Tuple of sample-rate, number of channels and sample-width. + Required in case of audio_type = util.audio.AUDIO_TYPE_PCM or util.audio.AUDIO_TYPE_NP, as this information cannot be derived from raw audio data. + sample_id : str + Tracking ID - should indicate sample's origin as precisely as possible """ self.audio_type = audio_type self.audio_format = audio_format + self.sample_id = sample_id if audio_type in SERIALIZABLE_AUDIO_TYPES: self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) self.duration = read_duration(audio_type, self.audio) @@ -68,7 +82,11 @@ class Sample: def change_audio_type(self, new_audio_type): """ In-place conversion of audio data into a different representation. - :param new_audio_type: New audio-type - see `__init__`. + + Parameters + ---------- + new_audio_type : str + New audio-type - see `__init__`. Not supported: Converting from AUDIO_TYPE_NP into any other type. """ if self.audio_type == new_audio_type: diff --git a/util/feeding.py b/util/feeding.py index 93c5699b..09a0904c 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -116,7 +116,7 @@ def create_dataset(sources, process_ahead=None, buffering=1 * MEGABYTE): def generate_values(): - samples = samples_from_files(sources, buffering=buffering) + samples = samples_from_files(sources, buffering=buffering, labeled=True) for sample in change_audio_types(samples, AUDIO_TYPE_NP, process_ahead=2 * batch_size if process_ahead is None else process_ahead): diff --git a/util/sample_collections.py b/util/sample_collections.py index c1e99dc1..7009db18 100644 --- a/util/sample_collections.py +++ b/util/sample_collections.py @@ -29,23 +29,45 @@ class LabeledSample(Sample): Derived from util.audio.Sample and used by sample collection readers and writers.""" def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None): """ - Creates an in-memory speech sample together with a transcript of the utterance (label). - :param audio_type: See util.audio.Sample.__init__ . - :param raw_data: See util.audio.Sample.__init__ . - :param transcript: Transcript of the sample's utterance - :param audio_format: See util.audio.Sample.__init__ . - :param sample_id: Tracking ID - typically assigned by collection readers + Parameters + ---------- + audio_type : str + See util.audio.Sample.__init__ . + raw_data : binary + See util.audio.Sample.__init__ . + transcript : str + Transcript of the sample's utterance + audio_format : tuple + See util.audio.Sample.__init__ . + sample_id : str + Tracking ID - should indicate sample's origin as precisely as possible. + It is typically assigned by collection readers. """ - super().__init__(audio_type, raw_data, audio_format=audio_format) - self.sample_id = sample_id + super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id) self.transcript = transcript class DirectSDBWriter: """Sample collection writer for creating a Sample DB (SDB) file""" - def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None): + def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None, labeled=True): + """ + Parameters + ---------- + sdb_filename : str + Path to the SDB file to write + buffering : int + Write-buffer size to use while writing the SDB file + audio_type : str + See util.audio.Sample.__init__ . + id_prefix : str + Prefix for IDs of written samples - defaults to sdb_filename + labeled : bool or None + If True: Writes labeled samples (util.sample_collections.LabeledSample) only. + If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. + """ self.sdb_filename = sdb_filename self.id_prefix = sdb_filename if id_prefix is None else id_prefix + self.labeled = labeled if audio_type not in SERIALIZABLE_AUDIO_TYPES: raise ValueError('Audio type "{}" not supported'.format(audio_type)) self.audio_type = audio_type @@ -55,12 +77,10 @@ class DirectSDBWriter: self.sdb_file.write(MAGIC) - meta_data = { - SCHEMA_KEY: [ - {CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}, - {CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT} - ] - } + schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}] + if self.labeled: + schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}) + meta_data = {SCHEMA_KEY: schema_entries} meta_data = json.dumps(meta_data).encode() self.write_big_int(len(meta_data)) self.sdb_file.write(meta_data) @@ -83,10 +103,14 @@ class DirectSDBWriter: sample.change_audio_type(self.audio_type) opus = sample.audio.getbuffer() opus_len = to_bytes(len(opus)) - transcript = sample.transcript.encode() - transcript_len = to_bytes(len(transcript)) - entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript)) - buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript]) + if self.labeled: + transcript = sample.transcript.encode() + transcript_len = to_bytes(len(transcript)) + entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript)) + buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript]) + else: + entry_len = to_bytes(len(opus_len) + len(opus)) + buffer = b''.join([entry_len, opus_len, opus]) self.offsets.append(self.sdb_file.tell()) self.sdb_file.write(buffer) sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples) @@ -120,7 +144,22 @@ class DirectSDBWriter: class SDB: # pylint: disable=too-many-instance-attributes """Sample collection reader for reading a Sample DB (SDB) file""" - def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None): + def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True): + """ + Parameters + ---------- + sdb_filename : str + Path to the SDB file to read samples from + buffering : int + Read-buffer size to use while reading the SDB file + id_prefix : str + Prefix for IDs of read samples - defaults to sdb_filename + labeled : bool or None + If True: Reads util.sample_collections.LabeledSample instances. Fails, if SDB file provides no transcripts. + If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. + If None: Automatically determines if SDB schema has transcripts + (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + """ self.sdb_filename = sdb_filename self.id_prefix = sdb_filename if id_prefix is None else id_prefix self.sdb_file = open(sdb_filename, 'rb', buffering=buffering) @@ -139,10 +178,14 @@ class SDB: # pylint: disable=too-many-instance-attributes self.speech_index = speech_columns[0] self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY] - transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT) - if not transcript_columns: - raise RuntimeError('No transcript data (missing in schema)') - self.transcript_index = transcript_columns[0] + self.transcript_index = None + if labeled is not False: + transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT) + if transcript_columns: + self.transcript_index = transcript_columns[0] + else: + if labeled is True: + raise RuntimeError('No transcript data (missing in schema)') sample_chunk_len = self.read_big_int() self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1) @@ -194,9 +237,12 @@ class SDB: # pylint: disable=too-many-instance-attributes return tuple(column_data) def __getitem__(self, i): + sample_id = '{}:{}'.format(self.id_prefix, i) + if self.transcript_index is None: + [audio_data] = self.read_row(i, self.speech_index) + return Sample(self.audio_type, audio_data, sample_id=sample_id) audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index) transcript = transcript.decode() - sample_id = '{}:{}'.format(self.id_prefix, i) return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id) def __iter__(self): @@ -215,24 +261,50 @@ class SDB: # pylint: disable=too-many-instance-attributes class CSV: - """Sample collection reader for reading a DeepSpeech CSV file""" - def __init__(self, csv_filename): + """Sample collection reader for reading a DeepSpeech CSV file + Automatically orders samples by CSV column wav_filesize (if available).""" + def __init__(self, csv_filename, labeled=None): + """ + Parameters + ---------- + csv_filename : str + Path to the CSV file containing sample audio paths and transcripts + labeled : bool or None + If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column. + If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. + If None: Automatically determines if CSV file has a transcript column + (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + """ self.csv_filename = csv_filename + self.labeled = labeled self.rows = [] csv_dir = Path(csv_filename).parent with open(csv_filename, 'r', encoding='utf8') as csv_file: reader = csv.DictReader(csv_file) + if 'transcript' in reader.fieldnames: + if self.labeled is None: + self.labeled = True + elif self.labeled: + raise RuntimeError('No transcript data (missing CSV column)') for row in reader: wav_filename = Path(row['wav_filename']) if not wav_filename.is_absolute(): wav_filename = csv_dir / wav_filename - self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript'])) + wav_filename = str(wav_filename) + wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0 + if self.labeled: + self.rows.append((wav_filename, wav_filesize, row['transcript'])) + else: + self.rows.append((wav_filename, wav_filesize)) self.rows.sort(key=lambda r: r[1]) def __getitem__(self, i): - wav_filename, _, transcript = self.rows[i] + row = self.rows[i] + wav_filename = row[0] with open(wav_filename, 'rb') as wav_file: - return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename) + if self.labeled: + return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), row[2], sample_id=wav_filename) + return Sample(AUDIO_TYPE_WAV, wav_file.read(), sample_id=wav_filename) def __iter__(self): for i in range(len(self.rows)): @@ -242,21 +314,52 @@ class CSV: return len(self.rows) -def samples_from_file(filename, buffering=BUFFER_SIZE): - """Returns an iterable of LabeledSample objects loaded from a file.""" +def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None): + """ + Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances + loaded from a sample source file. + + Parameters + ---------- + filename : str + Path to the sample source file (SDB or CSV) + buffering : int + Read-buffer size to use while reading files + labeled : bool or None + If True: Reads LabeledSample instances. Fails, if source provides no transcripts. + If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. + If None: Automatically determines if source provides transcripts + (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + """ ext = os.path.splitext(filename)[1].lower() if ext == '.sdb': - return SDB(filename, buffering=buffering) + return SDB(filename, buffering=buffering, labeled=labeled) if ext == '.csv': - return CSV(filename) + return CSV(filename, labeled=labeled) raise ValueError('Unknown file type: "{}"'.format(ext)) -def samples_from_files(filenames, buffering=BUFFER_SIZE): - """Returns an iterable of LabeledSample objects from a list of files.""" +def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None): + """ + Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances + loaded from a collection of sample source files. + + Parameters + ---------- + filenames : list of str + Paths to sample source files (SDBs or CSVs) + buffering : int + Read-buffer size to use while reading files + labeled : bool or None + If True: Reads LabeledSample instances. Fails, if not all sources provide transcripts. + If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances. + If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and + util.audio.Sample instances from sources with no transcripts. + """ + filenames = list(filenames) if len(filenames) == 0: raise ValueError('No files') if len(filenames) == 1: - return samples_from_file(filenames[0], buffering=buffering) - cols = list(map(partial(samples_from_file, buffering=buffering), filenames)) + return samples_from_file(filenames[0], buffering=buffering, labeled=labeled) + cols = list(map(partial(samples_from_file, buffering=buffering, labeled=labeled), filenames)) return Interleaved(*cols, key=lambda s: s.duration)