# STT/training/coqui_stt_training/util/sample_collections.py
# -*- coding: utf-8 -*-
import csv
import io
import json
import os
import tarfile
from functools import partial
from pathlib import Path
from .audio import (
    AUDIO_TYPE_OPUS,
    AUDIO_TYPE_PCM,
    SERIALIZABLE_AUDIO_TYPES,
    Sample,
    get_loadable_audio_type_from_extension,
    write_wav,
)
from .helpers import GIGABYTE, KILOBYTE, MEGABYTE, Interleaved, LenMap
from .io import is_remote_path, open_remote
BIG_ENDIAN = "big"
INT_SIZE = 4
BIGINT_SIZE = 2 * INT_SIZE
MAGIC = b"SAMPLEDB"
BUFFER_SIZE = 1 * MEGABYTE
REVERSE_BUFFER_SIZE = 16 * KILOBYTE
CACHE_SIZE = 1 * GIGABYTE
SCHEMA_KEY = "schema"
CONTENT_KEY = "content"
MIME_TYPE_KEY = "mime-type"
MIME_TYPE_TEXT = "text/plain"
CONTENT_TYPE_SPEECH = "speech"
CONTENT_TYPE_TRANSCRIPT = "transcript"
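
# On-disk SDB layout, as implied by DirectSDBWriter and SDB below
# (integers are big-endian; INT = 4 bytes, BIGINT = 8 bytes):
#
#   MAGIC "SAMPLEDB"
#   BIGINT meta-chunk length | JSON metadata (carries the schema)
#   BIGINT sample-chunk length | BIGINT sample count
#     per sample: INT entry length | INT audio length | audio bytes
#                 [INT transcript length | transcript bytes]
#   BIGINT index-chunk length | BIGINT sample count | one BIGINT offset per sample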


class LabeledSample(Sample):
    """In-memory labeled audio sample representing an utterance.
    Derived from util.audio.Sample and used by sample collection readers and writers."""

    def __init__(
        self, audio_type, raw_data, transcript, audio_format=None, sample_id=None
    ):
        """
        Parameters
        ----------
        audio_type : str
            See util.audio.Sample.__init__ .
        raw_data : binary
            See util.audio.Sample.__init__ .
        transcript : str
            Transcript of the sample's utterance
        audio_format : tuple
            See util.audio.Sample.__init__ .
        sample_id : str
            Tracking ID - should indicate the sample's origin as precisely as possible.
            It is typically assigned by collection readers.
        """
        super().__init__(
            audio_type, raw_data, audio_format=audio_format, sample_id=sample_id
        )
        self.transcript = transcript


class PackedSample:
    """
    A wrapper that we can carry around in an iterator and pass to a child process in order to
    have the child process do the loading/unpacking of the sample, allowing for parallel file
    I/O.
    """

    def __init__(self, filename, audio_type, label):
        self.filename = filename
        self.audio_type = audio_type
        self.label = label

    def unpack(self):
        with open_remote(self.filename, "rb") as audio_file:
            data = audio_file.read()
        if self.label is None:
            s = Sample(self.audio_type, data, sample_id=self.filename)
        else:
            s = LabeledSample(
                self.audio_type, data, self.label, sample_id=self.filename
            )
        return s


def unpack_maybe(sample):
    """
    Loads the supplied sample from disk (or the network) if the audio isn't loaded into memory already.
    """
    if hasattr(sample, "unpack"):
        realized_sample = sample.unpack()
    else:
        realized_sample = sample
    return realized_sample


def load_sample(filename, label=None):
    """
    Loads an audio file as a (labeled or unlabeled) sample.

    Parameters
    ----------
    filename : str
        Filename of the audio file to load as a sample
    label : str
        Label (transcript) of the sample.
        If None: the returned result.unpack() will return a util.audio.Sample instance
        Otherwise: the returned result.unpack() will return a util.sample_collections.LabeledSample instance

    Returns
    -------
    util.sample_collections.PackedSample, a wrapper object, on which calling unpack() will return a
    util.audio.Sample instance if label is None, else a util.sample_collections.LabeledSample instance
    """
    ext = os.path.splitext(filename)[1].lower()
    audio_type = get_loadable_audio_type_from_extension(ext)
    if audio_type is None:
        raise ValueError('Unknown audio type extension "{}"'.format(ext))
    return PackedSample(filename, audio_type, label)
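
# Minimal usage sketch (hypothetical path): load_sample only builds the cheap
# PackedSample wrapper; the actual file I/O happens in unpack()/unpack_maybe(),
# which is what allows the loading to be deferred to worker processes.
#
#     packed = load_sample("clip0001.wav", label="hello world")
#     sample = unpack_maybe(packed)  # -> LabeledSample with audio bytes loaded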


class DirectSDBWriter:
    """Sample collection writer for creating a Sample DB (SDB) file"""

    def __init__(
        self,
        sdb_filename,
        buffering=BUFFER_SIZE,
        audio_type=AUDIO_TYPE_OPUS,
        bitrate=None,
        id_prefix=None,
        labeled=True,
    ):
        """
        Parameters
        ----------
        sdb_filename : str
            Path to the SDB file to write
        buffering : int
            Write-buffer size to use while writing the SDB file
        audio_type : str
            See util.audio.Sample.__init__ .
        bitrate : int
            Bitrate for sample compression in case of a lossy audio_type (e.g. AUDIO_TYPE_OPUS)
        id_prefix : str
            Prefix for IDs of written samples - defaults to sdb_filename
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
        """
        self.sdb_filename = sdb_filename
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
        self.labeled = labeled
        if audio_type not in SERIALIZABLE_AUDIO_TYPES:
            raise ValueError('Audio type "{}" not supported'.format(audio_type))
        self.audio_type = audio_type
        self.bitrate = bitrate
        self.sdb_file = open_remote(sdb_filename, "wb", buffering=buffering)
        self.offsets = []
        self.num_samples = 0
        self.sdb_file.write(MAGIC)
        schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}]
        if self.labeled:
            schema_entries.append(
                {CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
            )
        meta_data = {SCHEMA_KEY: schema_entries}
        meta_data = json.dumps(meta_data).encode()
        self.write_big_int(len(meta_data))
        self.sdb_file.write(meta_data)
        self.offset_samples = self.sdb_file.tell()
        # Reserve space for the sample-chunk length and sample count (back-filled on close).
        self.sdb_file.seek(2 * BIGINT_SIZE, 1)

    def write_int(self, n):
        return self.sdb_file.write(n.to_bytes(INT_SIZE, BIG_ENDIAN))

    def write_big_int(self, n):
        return self.sdb_file.write(n.to_bytes(BIGINT_SIZE, BIG_ENDIAN))

    def __enter__(self):
        return self

    def add(self, sample):
        def to_bytes(n):
            return n.to_bytes(INT_SIZE, BIG_ENDIAN)

        sample.change_audio_type(self.audio_type, bitrate=self.bitrate)
        opus = sample.audio.getbuffer()
        opus_len = to_bytes(len(opus))
        if self.labeled:
            transcript = sample.transcript.encode()
            transcript_len = to_bytes(len(transcript))
            entry_len = to_bytes(
                len(opus_len) + len(opus) + len(transcript_len) + len(transcript)
            )
            buffer = b"".join([entry_len, opus_len, opus, transcript_len, transcript])
        else:
            entry_len = to_bytes(len(opus_len) + len(opus))
            buffer = b"".join([entry_len, opus_len, opus])
        self.offsets.append(self.sdb_file.tell())
        self.sdb_file.write(buffer)
        sample.sample_id = "{}:{}".format(self.id_prefix, self.num_samples)
        self.num_samples += 1
        return sample.sample_id

    def close(self):
        if self.sdb_file is None:
            return
        offset_index = self.sdb_file.tell()
        # Back-fill the reserved sample-chunk header, then append the offset index.
        self.sdb_file.seek(self.offset_samples)
        self.write_big_int(offset_index - self.offset_samples - BIGINT_SIZE)
        self.write_big_int(self.num_samples)
        self.sdb_file.seek(offset_index + BIGINT_SIZE)
        self.write_big_int(self.num_samples)
        for offset in self.offsets:
            self.write_big_int(offset)
        offset_end = self.sdb_file.tell()
        self.sdb_file.seek(offset_index)
        self.write_big_int(offset_end - offset_index - BIGINT_SIZE)
        self.sdb_file.close()
        self.sdb_file = None

    def __len__(self):
        return len(self.offsets)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
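
# Minimal usage sketch (hypothetical paths): converting a CSV data-set into an
# SDB file. unpack_maybe realizes each PackedSample before it is re-encoded.
#
#     with DirectSDBWriter("train.sdb", audio_type=AUDIO_TYPE_OPUS) as writer:
#         for sample in samples_from_source("train.csv"):
#             writer.add(unpack_maybe(sample))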


class SDB:  # pylint: disable=too-many-instance-attributes
    """Sample collection reader for reading a Sample DB (SDB) file"""

    def __init__(
        self,
        sdb_filename,
        buffering=BUFFER_SIZE,
        id_prefix=None,
        labeled=True,
        reverse=False,
    ):
        """
        Parameters
        ----------
        sdb_filename : str
            Path to the SDB file to read samples from
        buffering : int
            Read-ahead buffer size to use while reading the SDB file in normal order. Fixed to 16 kB in reverse mode.
        id_prefix : str
            Prefix for IDs of read samples - defaults to sdb_filename
        labeled : bool or None
            If True: Reads util.sample_collections.LabeledSample instances. Fails if the SDB file provides no transcripts.
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
            If None: Automatically determines if the SDB schema has transcripts
            (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
        reverse : bool
            If the order of the samples should be reversed
        """
        self.sdb_filename = sdb_filename
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
        self.sdb_file = open_remote(
            sdb_filename, "rb", buffering=REVERSE_BUFFER_SIZE if reverse else buffering
        )
        self.offsets = []
        if self.sdb_file.read(len(MAGIC)) != MAGIC:
            raise RuntimeError("No Sample Database")
        meta_chunk_len = self.read_big_int()
        self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode())
        if SCHEMA_KEY not in self.meta:
            raise RuntimeError("Missing schema")
        self.schema = self.meta[SCHEMA_KEY]
        speech_columns = self.find_columns(
            content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES
        )
        if not speech_columns:
            raise RuntimeError("No speech data (missing in schema)")
        self.speech_index = speech_columns[0]
        self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
        self.transcript_index = None
        if labeled is not False:
            transcript_columns = self.find_columns(
                content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT
            )
            if transcript_columns:
                self.transcript_index = transcript_columns[0]
            elif labeled is True:
                raise RuntimeError("No transcript data (missing in schema)")
        # Skip the sample chunk and read the offset index at the end of the file.
        sample_chunk_len = self.read_big_int()
        self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
        num_samples = self.read_big_int()
        for _ in range(num_samples):
            self.offsets.append(self.read_big_int())
        if reverse:
            self.offsets.reverse()

    def read_int(self):
        return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)

    def read_big_int(self):
        return int.from_bytes(self.sdb_file.read(BIGINT_SIZE), BIG_ENDIAN)

    def find_columns(self, content=None, mime_type=None):
        criteria = []
        if content is not None:
            criteria.append((CONTENT_KEY, content))
        if mime_type is not None:
            criteria.append((MIME_TYPE_KEY, mime_type))
        if len(criteria) == 0:
            raise ValueError(
                'At least one of "content" or "mime-type" has to be provided'
            )
        matches = []
        for index, column in enumerate(self.schema):
            matched = 0
            for field, value in criteria:
                if column[field] == value or (
                    isinstance(value, list) and column[field] in value
                ):
                    matched += 1
            if matched == len(criteria):
                matches.append(index)
        return matches

    def read_row(self, row_index, *columns):
        columns = list(columns)
        column_data = [None] * len(columns)
        found = 0
        if not 0 <= row_index < len(self.offsets):
            raise ValueError(
                "Wrong sample index: {} - has to be between 0 and {}".format(
                    row_index, len(self.offsets) - 1
                )
            )
        self.sdb_file.seek(self.offsets[row_index] + INT_SIZE)
        for index in range(len(self.schema)):
            chunk_len = self.read_int()
            if index in columns:
                column_data[columns.index(index)] = self.sdb_file.read(chunk_len)
                found += 1
                if found == len(columns):
                    return tuple(column_data)
            else:
                self.sdb_file.seek(chunk_len, 1)
        return tuple(column_data)

    def __getitem__(self, i):
        sample_id = "{}:{}".format(self.id_prefix, i)
        if self.transcript_index is None:
            [audio_data] = self.read_row(i, self.speech_index)
            return Sample(self.audio_type, audio_data, sample_id=sample_id)
        audio_data, transcript = self.read_row(
            i, self.speech_index, self.transcript_index
        )
        transcript = transcript.decode()
        return LabeledSample(
            self.audio_type, audio_data, transcript, sample_id=sample_id
        )

    def __iter__(self):
        for i in range(len(self.offsets)):
            yield self[i]

    def __len__(self):
        return len(self.offsets)

    def close(self):
        if self.sdb_file is not None:
            self.sdb_file.close()

    def __del__(self):
        self.close()
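
# Minimal usage sketch (hypothetical path): an SDB behaves like a random-access,
# iterable sequence of samples.
#
#     sdb = SDB("train.sdb")
#     print(len(sdb), sdb[0].transcript)
#     for sample in sdb:
#         process(sample)  # 'process' is a placeholder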


class CSVWriter:  # pylint: disable=too-many-instance-attributes
    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples"""

    def __init__(self, csv_filename, absolute_paths=False, labeled=True):
        """
        Parameters
        ----------
        csv_filename : str
            Path to the CSV file to write.
            Will create a directory (CSV filename without extension) next to it and fail if it already exists.
        absolute_paths : bool
            If paths in the CSV file should be absolute instead of relative to the CSV file's parent directory.
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.

        Currently only works with local files (not gs:// or hdfs://...)
        """
        self.csv_filename = Path(csv_filename)
        self.csv_base_dir = self.csv_filename.parent.resolve().absolute()
        self.set_name = self.csv_filename.stem
        self.csv_dir = self.csv_base_dir / self.set_name
        if self.csv_dir.exists():
            raise RuntimeError('"{}" already exists'.format(self.csv_dir))
        os.mkdir(str(self.csv_dir))
        self.absolute_paths = absolute_paths
        fieldnames = ["wav_filename", "wav_filesize"]
        self.labeled = labeled
        if labeled:
            fieldnames.append("transcript")
        self.csv_file = open_remote(csv_filename, "w", encoding="utf-8", newline="")
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        self.csv_writer.writeheader()
        self.counter = 0

    def __enter__(self):
        return self

    def add(self, sample):
        sample_filename = self.csv_dir / "sample{0:08d}.wav".format(self.counter)
        self.counter += 1
        sample.change_audio_type(AUDIO_TYPE_PCM)
        write_wav(str(sample_filename), sample.audio, audio_format=sample.audio_format)
        sample.sample_id = str(sample_filename.relative_to(self.csv_base_dir))
        row = {
            "wav_filename": str(sample_filename.absolute())
            if self.absolute_paths
            else sample.sample_id,
            "wav_filesize": sample_filename.stat().st_size,
        }
        if self.labeled:
            row["transcript"] = sample.transcript
        self.csv_writer.writerow(row)
        return sample.sample_id

    def close(self):
        if self.csv_file:
            self.csv_file.close()

    def __len__(self):
        return self.counter

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
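
# Minimal usage sketch (hypothetical paths): creates dev.csv plus a dev/
# directory of numbered WAV files next to it.
#
#     with CSVWriter("dev.csv") as writer:
#         for sample in samples_from_source("dev.sdb"):
#             writer.add(unpack_maybe(sample))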


class TarWriter:  # pylint: disable=too-many-instance-attributes
    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file."""

    def __init__(self, tar_filename, gz=False, labeled=True, include=None):
        """
        Parameters
        ----------
        tar_filename : str
            Path to the tar file to write.
        gz : bool
            If the tar file should be gzip-compressed.
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
        include : list of str
            List of additional files to include in the tar root.

        Currently only works with local files (not gs:// or hdfs://...)
        """
        self.tar = tarfile.open(tar_filename, "w:gz" if gz else "w")
        samples_dir = tarfile.TarInfo("samples")
        samples_dir.type = tarfile.DIRTYPE
        self.tar.addfile(samples_dir)
        if include:
            for include_path in include:
                self.tar.add(
                    include_path, recursive=False, arcname=Path(include_path).name
                )
        fieldnames = ["wav_filename", "wav_filesize"]
        self.labeled = labeled
        if labeled:
            fieldnames.append("transcript")
        # The CSV index is buffered in memory and added to the tar on close().
        self.csv_file = io.StringIO()
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        self.csv_writer.writeheader()
        self.counter = 0

    def __enter__(self):
        return self

    def add(self, sample):
        sample_filename = "samples/sample{0:08d}.wav".format(self.counter)
        self.counter += 1
        sample.change_audio_type(AUDIO_TYPE_PCM)
        sample_file = io.BytesIO()
        write_wav(sample_file, sample.audio, audio_format=sample.audio_format)
        sample_size = sample_file.tell()
        sample_file.seek(0)
        sample_tar = tarfile.TarInfo(sample_filename)
        sample_tar.size = sample_size
        self.tar.addfile(sample_tar, sample_file)
        row = {"wav_filename": sample_filename, "wav_filesize": sample_size}
        if self.labeled:
            row["transcript"] = sample.transcript
        self.csv_writer.writerow(row)
        return sample_filename

    def close(self):
        if self.csv_file and self.tar:
            csv_tar = tarfile.TarInfo("samples.csv")
            csv_tar.size = self.csv_file.tell()
            self.csv_file.seek(0)
            self.tar.addfile(csv_tar, io.BytesIO(self.csv_file.read().encode("utf8")))
        if self.tar:
            self.tar.close()

    def __len__(self):
        return self.counter

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
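
# Minimal usage sketch (hypothetical paths): bundles the WAV samples and their
# samples.csv index into a single gzip-compressed archive.
#
#     with TarWriter("test.tgz", gz=True, include=["alphabet.txt"]) as writer:
#         for sample in samples_from_source("test.csv"):
#             writer.add(unpack_maybe(sample))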


class SampleList:
    """Sample collection base class with samples loaded from an in-memory list of file paths."""

    def __init__(self, samples, labeled=True, reverse=False):
        """
        Parameters
        ----------
        samples : iterable of tuples of the form (sample_filename, filesize [, transcript])
            File size is used for ordering the samples; a transcript has to be provided if labeled=True
        labeled : bool or None
            If True: Reads LabeledSample instances.
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
        reverse : bool
            If the order of the samples should be reversed
        """
        self.labeled = labeled
        self.samples = list(samples)
        # Sort by file size as a proxy for duration (shortest first by default).
        self.samples.sort(key=lambda r: r[1], reverse=reverse)

    def __getitem__(self, i):
        sample_spec = self.samples[i]
        return load_sample(
            sample_spec[0], label=sample_spec[2] if self.labeled else None
        )

    def __len__(self):
        return len(self.samples)


class CSV(SampleList):
    """Sample collection reader for reading a Coqui STT CSV file
    Automatically orders samples by the CSV column wav_filesize (if available)."""

    def __init__(self, csv_filename, labeled=None, reverse=False):
        """
        Parameters
        ----------
        csv_filename : str
            Path to the CSV file containing sample audio paths and transcripts
        labeled : bool or None
            If True: Reads LabeledSample instances. Fails if the CSV file has no transcript column.
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
            If None: Automatically determines if the CSV file has a transcript column
            (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
        reverse : bool
            If the order of the samples should be reversed
        """
        rows = []
        with open_remote(csv_filename, "r", encoding="utf8") as csv_file:
            reader = csv.DictReader(csv_file)
            if "transcript" in reader.fieldnames:
                if labeled is None:
                    labeled = True
            elif labeled:
                raise RuntimeError("No transcript data (missing CSV column)")
            for row in reader:
                wav_filename = Path(row["wav_filename"])
                if not wav_filename.is_absolute() and not is_remote_path(
                    row["wav_filename"]
                ):
                    wav_filename = Path(csv_filename).parent / wav_filename
                    wav_filename = str(wav_filename)
                else:
                    # Pathlib otherwise removes a / from filenames like hdfs://
                    wav_filename = row["wav_filename"]
                wav_filesize = int(row["wav_filesize"]) if "wav_filesize" in row else 0
                if labeled:
                    rows.append((wav_filename, wav_filesize, row["transcript"]))
                else:
                    rows.append((wav_filename, wav_filesize))
        super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)
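
# Minimal usage sketch (hypothetical path): CSV entries are PackedSample
# wrappers, so the audio is only read once a sample is unpacked.
#
#     csv_set = CSV("train.csv")
#     first = csv_set[0].unpack()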


def samples_from_source(
    sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False
):
    """
    Loads samples from a sample source file.

    Parameters
    ----------
    sample_source : str
        Path to the sample source file (SDB or CSV)
    buffering : int
        Read-buffer size to use while reading files
    labeled : bool or None
        If True: Reads LabeledSample instances. Fails if the source provides no transcripts.
        If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
        If None: Automatically determines if the source provides transcripts
        (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
    reverse : bool
        If the order of the samples should be reversed

    Returns
    -------
    iterable of util.sample_collections.LabeledSample or util.audio.Sample instances supporting len.
    """
    ext = os.path.splitext(sample_source)[1].lower()
    if ext == ".sdb":
        return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
    if ext == ".csv":
        return CSV(sample_source, labeled=labeled, reverse=reverse)
    raise ValueError('Unknown file type: "{}"'.format(ext))


def samples_from_sources(
    sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False
):
    """
    Loads and combines samples from a list of source files. Sources are combined in an interleaving way
    to keep the default sample order from shortest to longest.

    Note that when using distributed training, it is much faster to call this function with a single
    pre-sorted sample source, because this allows for parallelization of the file I/O. (If this function
    is called with multiple sources, the samples have to be unpacked on a single parent process to allow
    for reading their durations.)

    Parameters
    ----------
    sample_sources : list of str
        Paths to sample source files (SDBs or CSVs)
    buffering : int
        Read-buffer size to use while reading files
    labeled : bool or None
        If True: Reads LabeledSample instances. Fails if not all sources provide transcripts.
        If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
        If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
        util.audio.Sample instances from sources with no transcripts.
    reverse : bool
        If the order of the samples should be reversed

    Returns
    -------
    iterable of util.sample_collections.PackedSample if a single collection is provided, wrapping
    LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len;
    or LabeledSample / util.audio.Sample instances directly, if multiple collections are provided
    """
    sample_sources = list(sample_sources)
    if len(sample_sources) == 0:
        raise ValueError("No files")
    if len(sample_sources) == 1:
        return samples_from_source(
            sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse
        )
    # If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking
    # should be done lazily on the fly, so that it respects the LimitingPool logic used in the feeding code.
    cols = [
        LenMap(
            unpack_maybe,
            samples_from_source(
                source, buffering=buffering, labeled=labeled, reverse=reverse
            ),
        )
        for source in sample_sources
    ]
    return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
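
# Minimal usage sketch (hypothetical paths): with multiple sources the samples
# come back interleaved by duration, shortest first.
#
#     for sample in samples_from_sources(["a.csv", "b.sdb"]):
#         print(sample.sample_id, sample.duration)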