* Redo remote I/O changes once more; this time without messing with taskcluster * Add bin changes * Fix merge-induced issue? * For the interleaved case with multiple collections, unpack audio on the fly To reproduce the previous failure rm data/smoke_test/ldc93s1.csv rm data/smoke_test/ldc93s1.sdb rm -rf /tmp/ldc93s1_cache_sdb_csv rm -rf /tmp/ckpt_sdb_csv rm -rf /tmp/train_sdb_csv ./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 16000 python -u DeepSpeech.py --noshow_progressbar --noearly_stop --train_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --train_batch_size 1 --feature_cache /tmp/ldc93s1_cache_sdb_csv --dev_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --dev_batch_size 1 --test_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --test_batch_size 1 --n_hidden 100 --epochs 109 --max_to_keep 1 --checkpoint_dir /tmp/ckpt_sdb_csv --learning_rate 0.001 --dropout_rate 0.05 --export_dir /tmp/train_sdb_csv --scorer_path data/smoke_test/pruned_lm.scorer --audio_sample_rate 16000 * Attempt to preserve length information with a wrapper around `map()`… this gets pretty python-y * Call the right `__next__()` * Properly implement the rest of the map wrappers here…… * Fix trailing whitespace situation and other linter complaints * Remove data accidentally checked in * Fix overlay augmentations * Wavs must be open in rb mode if we're passing in an external file pointer -- this confused me * Lint whitespace * Revert "Fix trailing whitespace situation and other linter complaints" This reverts commit c3c45397a2f98e9b00d00c18c4ced4fc52475032. * Fix linter issue but without such an aggressive diff * Move unpack_maybe into sample_collections * Use unpack_maybe in place of duplicate lambda * Fix confusing comment * Add clarifying comment for on-the-fly unpacking
631 lines
25 KiB
Python
631 lines
25 KiB
Python
# -*- coding: utf-8 -*-
|
|
import os
|
|
import io
|
|
import csv
|
|
import json
|
|
import tarfile
|
|
|
|
from pathlib import Path
|
|
from functools import partial
|
|
|
|
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
|
|
from .audio import (
|
|
Sample,
|
|
DEFAULT_FORMAT,
|
|
AUDIO_TYPE_PCM,
|
|
AUDIO_TYPE_OPUS,
|
|
SERIALIZABLE_AUDIO_TYPES,
|
|
get_audio_type_from_extension,
|
|
write_wav
|
|
)
|
|
from .io import open_remote, is_remote_path
|
|
|
|
# Byte order passed to int.to_bytes() / int.from_bytes() for every integer stored in an SDB file
BIG_ENDIAN = 'big'
# Size in bytes of a regular integer (entry lengths and per-column chunk lengths)
INT_SIZE = 4
# Size in bytes of a big integer (chunk lengths, sample counts and sample offsets)
BIGINT_SIZE = 2 * INT_SIZE
# Byte sequence every SDB file starts with; the reader rejects files that do not begin with it
MAGIC = b'SAMPLEDB'

# Default buffer size used when opening SDB files for reading or writing
BUFFER_SIZE = 1 * MEGABYTE
# Buffer size used instead of BUFFER_SIZE when an SDB file is read in reverse order
REVERSE_BUFFER_SIZE = 16 * KILOBYTE
# NOTE(review): not referenced in the visible part of this module - presumably used by importers; confirm
CACHE_SIZE = 1 * GIGABYTE

# Keys and values used in the JSON meta-data header of an SDB file
SCHEMA_KEY = 'schema'
CONTENT_KEY = 'content'
MIME_TYPE_KEY = 'mime-type'
MIME_TYPE_TEXT = 'text/plain'
CONTENT_TYPE_SPEECH = 'speech'
CONTENT_TYPE_TRANSCRIPT = 'transcript'
|
|
|
|
|
|
class LabeledSample(Sample):
    """Audio sample held in memory together with the transcript of its utterance.

    Extends util.audio.Sample; sample collection readers and writers exchange instances of it."""
    def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
        """
        Parameters
        ----------
        audio_type : str
            Forwarded unchanged - see util.audio.Sample.__init__ .
        raw_data : binary
            Forwarded unchanged - see util.audio.Sample.__init__ .
        transcript : str
            Text that is spoken in the sample
        audio_format : tuple
            Forwarded unchanged - see util.audio.Sample.__init__ .
        sample_id : str
            Tracking ID that should identify the origin of the sample as precisely as possible.
            Collection readers are the ones that usually assign it.
        """
        super(LabeledSample, self).__init__(audio_type,
                                            raw_data,
                                            audio_format=audio_format,
                                            sample_id=sample_id)
        self.transcript = transcript
|
|
|
|
|
|
class PackedSample:
    """
    A wrapper that we can carry around in an iterator and pass to a child process in order to
    have the child process do the loading/unpacking of the sample, allowing for parallel file
    I/O.
    """
    def __init__(self, filename, audio_type, label):
        """
        Parameters
        ----------
        filename : str
            Path of the audio file; it is opened through open_remote() when unpack() is called
        audio_type : str
            Audio type handed to the sample constructor on unpack()
        label : str or None
            Transcript of the sample, or None for an unlabeled sample
        """
        self.filename = filename
        self.audio_type = audio_type
        self.label = label

    def unpack(self):
        """
        Reads the audio file and wraps its content into a sample object.

        Returns
        -------
        util.audio.Sample if label is None, else util.sample_collections.LabeledSample.
        In both cases the filename is used as sample ID.
        """
        with open_remote(self.filename, 'rb') as audio_file:
            data = audio_file.read()
        # Unlabeled samples have to stay plain Sample instances: wrapping them into a
        # LabeledSample would attach a transcript of None to them.
        if self.label is None:
            return Sample(self.audio_type, data, sample_id=self.filename)
        return LabeledSample(self.audio_type, data, self.label, sample_id=self.filename)
|
|
|
|
|
|
def unpack_maybe(sample):
    """
    Returns the in-memory form of the supplied sample: objects that offer an unpack() method
    (e.g. a PackedSample) are unpacked - which loads them from disk or the network -, everything
    else is handed back untouched.
    """
    return sample.unpack() if hasattr(sample, 'unpack') else sample
|
|
|
|
|
|
def load_sample(filename, label=None):
    """
    Prepares an audio-file for being loaded as a (labeled or unlabeled) sample

    Parameters
    ----------
    filename : str
        Filename of the audio-file; its (lower-cased) extension determines the audio type
    label : str
        Label (transcript) of the sample.
        If None: unpack() of the result is meant to return a util.audio.Sample instance
        Otherwise: unpack() of the result returns a util.sample_collections.LabeledSample instance

    Returns
    -------
    util.sample_collections.PackedSample, a wrapper object whose unpack() method performs the
    actual loading of the audio data

    Raises
    ------
    ValueError
        If no audio type is known for the extension of the file
    """
    _, extension = os.path.splitext(filename)
    extension = extension.lower()
    detected_type = get_audio_type_from_extension(extension)
    if detected_type is not None:
        return PackedSample(filename, detected_type, label)
    raise ValueError('Unknown audio type extension "{}"'.format(extension))
|
|
|
|
|
|
class DirectSDBWriter:
    """Sample collection writer that produces a Sample DB (SDB) file.

    File layout as written by this class: MAGIC, length of the JSON meta-data (big int), JSON meta-data,
    length of the sample chunk (big int), number of samples (big int), the sample entries,
    length of the index chunk (big int), number of samples (big int), one offset (big int) per sample."""
    def __init__(self,
                 sdb_filename,
                 buffering=BUFFER_SIZE,
                 audio_type=AUDIO_TYPE_OPUS,
                 bitrate=None,
                 id_prefix=None,
                 labeled=True):
        """
        Parameters
        ----------
        sdb_filename : str
            Path to the SDB file that should be written
        buffering : int
            Size of the write-buffer used for the SDB file
        audio_type : str
            Audio type the samples get converted to - has to be one of SERIALIZABLE_AUDIO_TYPES.
            See util.audio.Sample.__init__ .
        bitrate : int
            Bitrate for sample-compression in case of lossy audio_type (e.g. AUDIO_TYPE_OPUS)
        id_prefix : str
            Prefix for IDs of written samples - defaults to sdb_filename
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.

        Raises
        ------
        ValueError
            If audio_type is not serializable
        """
        self.sdb_filename = sdb_filename
        self.id_prefix = id_prefix if id_prefix is not None else sdb_filename
        self.labeled = labeled
        if audio_type not in SERIALIZABLE_AUDIO_TYPES:
            raise ValueError('Audio type "{}" not supported'.format(audio_type))
        self.audio_type = audio_type
        self.bitrate = bitrate
        self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering)
        self.offsets = []
        self.num_samples = 0

        self.sdb_file.write(MAGIC)

        # The schema lists the columns of every entry: speech first, then (optionally) the transcript
        columns = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}]
        if self.labeled:
            columns += [{CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}]
        header = json.dumps({SCHEMA_KEY: columns}).encode()
        self.write_big_int(len(header))
        self.sdb_file.write(header)

        # Leave room for sample chunk length and number of samples - close() fills them in
        self.offset_samples = self.sdb_file.tell()
        self.sdb_file.seek(2 * BIGINT_SIZE, io.SEEK_CUR)

    def write_int(self, n):
        """Writes n as regular (INT_SIZE bytes) integer at the current file position."""
        encoded = n.to_bytes(INT_SIZE, BIG_ENDIAN)
        return self.sdb_file.write(encoded)

    def write_big_int(self, n):
        """Writes n as big (BIGINT_SIZE bytes) integer at the current file position."""
        encoded = n.to_bytes(BIGINT_SIZE, BIG_ENDIAN)
        return self.sdb_file.write(encoded)

    def __enter__(self):
        return self

    def add(self, sample):
        """
        Converts the sample to the audio type of the writer and appends it to the SDB file.
        The sample's sample_id gets replaced by "<id_prefix>:<index of the sample>".

        Returns
        -------
        The newly assigned sample ID
        """
        sample.change_audio_type(self.audio_type, bitrate=self.bitrate)
        fields = [sample.audio.getbuffer()]
        if self.labeled:
            fields.append(sample.transcript.encode())
        # Every column is stored as its length followed by its data
        chunks = []
        for field in fields:
            chunks.append(len(field).to_bytes(INT_SIZE, BIG_ENDIAN))
            chunks.append(field)
        # The entry is prefixed by the total length of all its columns
        entry_len = sum(len(chunk) for chunk in chunks)
        record = b''.join([entry_len.to_bytes(INT_SIZE, BIG_ENDIAN)] + chunks)
        self.offsets.append(self.sdb_file.tell())
        self.sdb_file.write(record)
        sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
        self.num_samples += 1
        return sample.sample_id

    def close(self):
        """Completes the sample chunk header, writes the offset index and closes the file.
        Does nothing if the writer got closed already."""
        if self.sdb_file is None:
            return
        index_start = self.sdb_file.tell()
        # Fill in the header of the sample chunk that __init__ left room for
        self.sdb_file.seek(self.offset_samples)
        self.write_big_int(index_start - self.offset_samples - BIGINT_SIZE)
        self.write_big_int(self.num_samples)

        # Write the index behind a gap for its length, which is only known afterwards
        self.sdb_file.seek(index_start + BIGINT_SIZE)
        self.write_big_int(self.num_samples)
        for sample_offset in self.offsets:
            self.write_big_int(sample_offset)
        index_end = self.sdb_file.tell()
        self.sdb_file.seek(index_start)
        self.write_big_int(index_end - index_start - BIGINT_SIZE)
        self.sdb_file.close()
        self.sdb_file = None

    def __len__(self):
        return len(self.offsets)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
|
|
|
|
|
class SDB:  # pylint: disable=too-many-instance-attributes
    """Sample collection reader for reading a Sample DB (SDB) file"""
    def __init__(self,
                 sdb_filename,
                 buffering=BUFFER_SIZE,
                 id_prefix=None,
                 labeled=True,
                 reverse=False):
        """
        Parameters
        ----------
        sdb_filename : str
            Path to the SDB file to read samples from
        buffering : int
            Read-ahead buffer size to use while reading the SDB file in normal order. Fixed to 16kB if in reverse-mode.
        id_prefix : str
            Prefix for IDs of read samples - defaults to sdb_filename
        labeled : bool or None
            If True: Reads util.sample_collections.LabeledSample instances. Fails, if SDB file provides no transcripts.
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
            If None: Automatically determines if SDB schema has transcripts
            (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
        reverse : bool
            If the order of the samples should be reversed

        Raises
        ------
        RuntimeError
            If the file is no SDB file, has no schema, provides no speech column or
            (with labeled=True) provides no transcript column
        """
        # Has to exist before anything can fail, as close() is also called by __del__
        self.sdb_file = None
        self.sdb_filename = sdb_filename
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
        self.offsets = []
        self.meta = None
        self.schema = None
        self.speech_index = None
        self.audio_type = None
        self.transcript_index = None
        self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
        try:
            self._read_header(labeled)
            self._read_index(reverse)
        except Exception:
            # Do not leave the file open if it turns out to be unusable
            self.close()
            raise

    def _read_header(self, labeled):
        """Checks the magic bytes, loads the meta-data and picks the speech and transcript columns."""
        if self.sdb_file.read(len(MAGIC)) != MAGIC:
            raise RuntimeError('No Sample Database')
        meta_chunk_len = self.read_big_int()
        self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode())
        if SCHEMA_KEY not in self.meta:
            raise RuntimeError('Missing schema')
        self.schema = self.meta[SCHEMA_KEY]

        speech_columns = self.find_columns(content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES)
        if not speech_columns:
            raise RuntimeError('No speech data (missing in schema)')
        self.speech_index = speech_columns[0]
        self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]

        self.transcript_index = None
        if labeled is not False:
            transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
            if transcript_columns:
                self.transcript_index = transcript_columns[0]
            elif labeled is True:
                raise RuntimeError('No transcript data (missing in schema)')

    def _read_index(self, reverse):
        """Skips the sample chunk and loads the sample offsets of the index chunk that follows it."""
        sample_chunk_len = self.read_big_int()
        # Jump over the samples and the length field of the index chunk
        self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
        num_samples = self.read_big_int()
        self.offsets = [self.read_big_int() for _ in range(num_samples)]
        if reverse:
            self.offsets.reverse()

    def read_int(self):
        """Reads a regular (INT_SIZE bytes) integer from the current file position."""
        return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)

    def read_big_int(self):
        """Reads a big (BIGINT_SIZE bytes) integer from the current file position."""
        return int.from_bytes(self.sdb_file.read(BIGINT_SIZE), BIG_ENDIAN)

    def find_columns(self, content=None, mime_type=None):
        """
        Looks up schema columns by content type and/or MIME type.

        Parameters
        ----------
        content : str or list of str
            Required value (or list of accepted values) of the column's content field
        mime_type : str or list of str
            Required value (or list of accepted values) of the column's mime-type field

        Returns
        -------
        List of the indices of all schema columns that fulfill every provided criterion

        Raises
        ------
        ValueError
            If neither content nor mime_type is provided
        """
        criteria = []
        if content is not None:
            criteria.append((CONTENT_KEY, content))
        if mime_type is not None:
            criteria.append((MIME_TYPE_KEY, mime_type))
        if len(criteria) == 0:
            raise ValueError('At least one of "content" or "mime-type" has to be provided')
        matches = []
        for index, column in enumerate(self.schema):
            matched = 0
            for field, value in criteria:
                if column[field] == value or (isinstance(value, list) and column[field] in value):
                    matched += 1
            if matched == len(criteria):
                matches.append(index)
        return matches

    def read_row(self, row_index, *columns):
        """
        Reads the raw data of the requested columns of one sample entry.

        Parameters
        ----------
        row_index : int
            Index of the sample (position within the - possibly reversed - offset list)
        columns : int
            Schema indices of the columns to read

        Returns
        -------
        Tuple of the binary column contents in the order the columns got requested in

        Raises
        ------
        ValueError
            If row_index is out of range
        """
        columns = list(columns)
        column_data = [None] * len(columns)
        found = 0
        if not 0 <= row_index < len(self.offsets):
            raise ValueError('Wrong sample index: {} - has to be between 0 and {}'
                             .format(row_index, len(self.offsets) - 1))
        # Skip the entry length - columns are walked through one by one
        self.sdb_file.seek(self.offsets[row_index] + INT_SIZE)
        for index in range(len(self.schema)):
            chunk_len = self.read_int()
            if index in columns:
                column_data[columns.index(index)] = self.sdb_file.read(chunk_len)
                found += 1
                if found == len(columns):
                    # Everything requested got read - no need to touch the remaining columns
                    return tuple(column_data)
            else:
                self.sdb_file.seek(chunk_len, 1)
        return tuple(column_data)

    def __getitem__(self, i):
        sample_id = '{}:{}'.format(self.id_prefix, i)
        if self.transcript_index is None:
            [audio_data] = self.read_row(i, self.speech_index)
            return Sample(self.audio_type, audio_data, sample_id=sample_id)
        audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
        transcript = transcript.decode()
        return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)

    def __iter__(self):
        for i in range(len(self.offsets)):
            yield self[i]

    def __len__(self):
        return len(self.offsets)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Closes the SDB file. Safe to call on instances whose construction failed."""
        # getattr: __del__ also runs on objects whose __init__ never got to assign sdb_file
        sdb_file = getattr(self, 'sdb_file', None)
        if sdb_file is not None:
            sdb_file.close()

    def __del__(self):
        self.close()
|
|
|
|
|
|
class CSVWriter:  # pylint: disable=too-many-instance-attributes
    """Sample collection writer that produces a CSV data-set plus the WAV files it references"""
    def __init__(self,
                 csv_filename,
                 absolute_paths=False,
                 labeled=True):
        """
        Parameters
        ----------
        csv_filename : str
            Path to the CSV file to write.
            A directory named like the CSV file (without extension) gets created next to it for the samples.
            Construction fails if that directory already exists.
        absolute_paths : bool
            If paths in CSV file should be absolute instead of relative to the CSV file's parent directory.
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.

        Currently only works with local files (not gs:// or hdfs://...)

        Raises
        ------
        RuntimeError
            If the directory for the samples already exists
        """
        self.csv_filename = Path(csv_filename)
        self.csv_base_dir = self.csv_filename.parent.resolve().absolute()
        self.set_name = self.csv_filename.stem
        self.csv_dir = self.csv_base_dir / self.set_name
        if self.csv_dir.exists():
            raise RuntimeError('"{}" already existing'.format(self.csv_dir))
        os.mkdir(str(self.csv_dir))
        self.absolute_paths = absolute_paths
        self.labeled = labeled
        columns = ['wav_filename', 'wav_filesize'] + (['transcript'] if labeled else [])
        self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='')
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=columns)
        self.csv_writer.writeheader()
        self.counter = 0

    def __enter__(self):
        return self

    def add(self, sample):
        """
        Converts the sample to PCM, writes it as numbered WAV file into the sample directory and
        appends its row to the CSV file. The sample's sample_id gets replaced by the path of
        the WAV file relative to the CSV file's directory.

        Returns
        -------
        The newly assigned sample ID
        """
        wav_path = self.csv_dir / 'sample{0:08d}.wav'.format(self.counter)
        self.counter += 1
        sample.change_audio_type(AUDIO_TYPE_PCM)
        write_wav(str(wav_path), sample.audio, audio_format=sample.audio_format)
        sample.sample_id = str(wav_path.relative_to(self.csv_base_dir))
        if self.absolute_paths:
            listed_path = str(wav_path.absolute())
        else:
            listed_path = sample.sample_id
        record = {'wav_filename': listed_path, 'wav_filesize': wav_path.stat().st_size}
        if self.labeled:
            record['transcript'] = sample.transcript
        self.csv_writer.writerow(record)
        return sample.sample_id

    def close(self):
        """Closes the CSV file."""
        if self.csv_file:
            self.csv_file.close()

    def __len__(self):
        return self.counter

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
|
|
|
|
|
class TarWriter:  # pylint: disable=too-many-instance-attributes
    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file."""
    def __init__(self,
                 tar_filename,
                 gz=False,
                 labeled=True,
                 include=None):
        """
        Parameters
        ----------
        tar_filename : str
            Path to the tar file to write.
        gz : bool
            If to compress tar file with gzip.
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
        include : str[]
            List of files to include into tar root.

        Currently only works with local files (not gs:// or hdfs://...)
        """
        self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
        samples_dir = tarfile.TarInfo('samples')
        samples_dir.type = tarfile.DIRTYPE
        self.tar.addfile(samples_dir)
        if include:
            for include_path in include:
                self.tar.add(include_path, recursive=False, arcname=Path(include_path).name)
        fieldnames = ['wav_filename', 'wav_filesize']
        self.labeled = labeled
        if labeled:
            fieldnames.append('transcript')
        # The CSV is collected in memory and only added to the archive on close()
        self.csv_file = io.StringIO()
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        self.csv_writer.writeheader()
        self.counter = 0

    def __enter__(self):
        return self

    def add(self, sample):
        """
        Converts the sample to PCM, adds it as numbered WAV file to the "samples" directory of the
        archive and appends its row to the in-memory CSV.

        Returns
        -------
        Archive path of the written WAV file
        """
        sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter)
        self.counter += 1
        sample.change_audio_type(AUDIO_TYPE_PCM)
        sample_file = io.BytesIO()
        write_wav(sample_file, sample.audio, audio_format=sample.audio_format)
        sample_size = sample_file.tell()
        sample_file.seek(0)
        sample_tar = tarfile.TarInfo(sample_filename)
        sample_tar.size = sample_size
        self.tar.addfile(sample_tar, sample_file)
        row = {
            'wav_filename': sample_filename,
            'wav_filesize': sample_size
        }
        if self.labeled:
            row['transcript'] = sample.transcript
        self.csv_writer.writerow(row)
        return sample_filename

    def close(self):
        """Adds the collected CSV as "samples.csv" to the archive and closes it.
        Does nothing if the writer got closed already."""
        tar, csv_file = self.tar, self.csv_file
        if tar is None:
            return
        # Reset first, so a repeated close() can never write into the closed archive
        self.tar = None
        self.csv_file = None
        try:
            if csv_file is not None:
                # The member size has to be the number of encoded bytes: the character count of
                # the text buffer is smaller as soon as a transcript contains non-ASCII characters,
                # which would truncate the CSV inside the archive.
                payload = csv_file.getvalue().encode('utf8')
                csv_tar = tarfile.TarInfo('samples.csv')
                csv_tar.size = len(payload)
                tar.addfile(csv_tar, io.BytesIO(payload))
        finally:
            tar.close()

    def __len__(self):
        return self.counter

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
|
|
|
|
|
class SampleList:
    """Base class of sample collections whose sample paths are kept in an in-memory list."""
    def __init__(self, samples, labeled=True, reverse=False):
        """
        Parameters
        ----------
        samples : iterable of tuples of the form (sample_filename, filesize [, transcript])
            The samples get ordered by file-size; the transcript is required if labeled=True
        labeled : bool or None
            If True: Reads LabeledSample instances.
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
        reverse : bool
            If the samples should be ordered from biggest to smallest file-size instead
        """
        self.labeled = labeled
        self.samples = sorted(samples, key=lambda spec: spec[1], reverse=reverse)

    def __getitem__(self, i):
        spec = self.samples[i]
        transcript = spec[2] if self.labeled else None
        return load_sample(spec[0], label=transcript)

    def __len__(self):
        return len(self.samples)
|
|
|
|
|
|
class CSV(SampleList):
    """Sample collection reader for reading a DeepSpeech CSV file
    Automatically orders samples by CSV column wav_filesize (if available)."""
    def __init__(self, csv_filename, labeled=None, reverse=False):
        """
        Parameters
        ----------
        csv_filename : str
            Path to the CSV file containing sample audio paths and transcripts
        labeled : bool or None
            If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column.
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
            If None: Automatically determines if CSV file has a transcript column
            (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
        reverse : bool
            If the order of the samples should be reversed

        Raises
        ------
        RuntimeError
            If labeled=True and the CSV file has no transcript column
        """
        rows = []
        with open_remote(csv_filename, 'r', encoding='utf8') as csv_file:
            reader = csv.DictReader(csv_file)
            # DictReader.fieldnames is None for a completely empty file
            fieldnames = reader.fieldnames or []
            if 'transcript' in fieldnames:
                if labeled is None:
                    labeled = True
            elif labeled:
                raise RuntimeError('No transcript data (missing CSV column)')
            # Relative sample paths are resolved against the directory of the CSV file
            csv_dir = Path(csv_filename).parent
            for row in reader:
                listed_filename = row['wav_filename']
                if Path(listed_filename).is_absolute() or is_remote_path(listed_filename):
                    # Taken verbatim: Pathlib otherwise removes a / from filenames like hdfs://
                    wav_filename = listed_filename
                else:
                    wav_filename = str(csv_dir / listed_filename)
                wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
                if labeled:
                    rows.append((wav_filename, wav_filesize, row['transcript']))
                else:
                    rows.append((wav_filename, wav_filesize))
        super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)
|
|
|
|
|
|
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False):
    """
    Opens a sample source file as sample collection. The collection type is picked by the
    (lower-cased) extension of the file.

    Parameters
    ----------
    sample_source : str
        Path to the sample source file (SDB or CSV)
    buffering : int
        Read-buffer size to use while reading files (only handed on to SDB collections)
    labeled : bool or None
        If True: Reads LabeledSample instances. Fails, if source provides no transcripts.
        If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
        If None: Automatically determines if source provides transcripts
        (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
    reverse : bool
        If the order of the samples should be reversed

    Returns
    -------
    util.sample_collections.SDB for ".sdb" files, util.sample_collections.CSV for ".csv" files -
    both support len.

    Raises
    ------
    ValueError
        If the extension is neither ".sdb" nor ".csv"
    """
    _, extension = os.path.splitext(sample_source)
    extension = extension.lower()
    if extension == '.csv':
        return CSV(sample_source, labeled=labeled, reverse=reverse)
    if extension == '.sdb':
        return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
    raise ValueError('Unknown file type: "{}"'.format(extension))
|
|
|
|
|
|
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False):
    """
    Loads samples from a list of source files and merges them. Merging is done in an interleaving
    way to keep default sample order from shortest to longest.

    Note that when using distributed training, it is much faster to call this function with single pre-
    sorted sample source, because this allows for parallelization of the file I/O. (If this function is
    called with multiple sources, the samples have to be unpacked on a single parent process to allow
    for reading their durations.)

    Parameters
    ----------
    sample_sources : list of str
        Paths to sample source files (SDBs or CSVs)
    buffering : int
        Read-buffer size to use while reading files
    labeled : bool or None
        If True: Reads LabeledSample instances. Fails, if not all sources provide transcripts.
        If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
        If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
        util.audio.Sample instances from sources with no transcripts.
    reverse : bool
        If the order of the samples should be reversed

    Returns
    -------
    For a single source: the collection returned by samples_from_source (its items can be
    util.sample_collections.PackedSample wrappers that still have to be unpacked).
    For multiple sources: an interleaved collection whose samples get unpacked while iterating.

    Raises
    ------
    ValueError
        If no source is provided
    """
    sources = list(sample_sources)
    if not sources:
        raise ValueError('No files')
    if len(sources) == 1:
        return samples_from_source(sources[0], buffering=buffering, labeled=labeled, reverse=reverse)

    # Interleaving by duration requires unpacked audio. The unpacking is done lazily on the fly,
    # so that it respects the LimitingPool logic used in the feeding code.
    collections = []
    for source in sources:
        collection = samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
        collections.append(LenMap(unpack_maybe, collection))

    return Interleaved(*collections, key=lambda sample: sample.duration, reverse=reverse)
|