Resolves #1565 - Limiting and reversing data-sets

Tilman Kamp 2020-07-23 16:59:12 +02:00
parent 38f6afdba8
commit 9a5d19d7c5
6 changed files with 61 additions and 21 deletions

View File

@@ -50,7 +50,11 @@ def evaluate(test_csvs, create_model):
     else:
         scorer = None

-    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
+    test_sets = [create_dataset([csv],
+                                batch_size=FLAGS.test_batch_size,
+                                train_phase=False,
+                                reverse=FLAGS.reverse_test,
+                                limit=FLAGS.limit_test) for csv in test_csvs]
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                  tfv1.data.get_output_shapes(test_sets[0]),
                                                  output_classes=tfv1.data.get_output_classes(test_sets[0]))

View File

@@ -417,6 +417,8 @@ def train():
                                train_phase=True,
                                exception_box=exception_box,
                                process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
+                               reverse=FLAGS.reverse_train,
+                               limit=FLAGS.limit_train,
                                buffering=FLAGS.read_buffer)
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
@@ -433,6 +435,8 @@ def train():
                                train_phase=False,
                                exception_box=exception_box,
                                process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                               reverse=FLAGS.reverse_dev,
+                               limit=FLAGS.limit_dev,
                                buffering=FLAGS.read_buffer) for source in dev_sources]
     dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
@@ -443,6 +447,8 @@ def train():
                                train_phase=False,
                                exception_box=exception_box,
                                process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                               reverse=FLAGS.reverse_dev,
+                               limit=FLAGS.limit_dev,
                                buffering=FLAGS.read_buffer) for source in metrics_sources]
     metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]

View File

@@ -90,6 +90,8 @@ def create_dataset(sources,
                    augmentations=None,
                    cache_path=None,
                    train_phase=False,
+                   reverse=False,
+                   limit=0,
                    exception_box=None,
                    process_ahead=None,
                    buffering=1 * MEGABYTE):
@@ -99,8 +101,10 @@ def create_dataset(sources,
         epoch = epoch_counter['epoch']
         if train_phase:
             epoch_counter['epoch'] += 1
-        samples = samples_from_sources(sources, buffering=buffering, labeled=True)
+        samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse)
         num_samples = len(samples)
+        if limit > 0:
+            num_samples = min(limit, num_samples)
         samples = apply_sample_augmentations(samples,
                                              augmentations,
                                              buffering=buffering,
@@ -108,6 +112,8 @@
                                              clock=epoch / epochs,
                                              final_clock=(epoch + 1) / epochs)
         for sample_index, sample in enumerate(samples):
+            if sample_index >= num_samples:
+                break
             clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
             transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
             transcript = to_sparse_tuple(transcript)
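
The added limit logic caps num_samples before iteration starts and then breaks out of the sample loop, so the same capped num_samples also keeps the epoch-progress clock consistent with the truncated set. A minimal runnable sketch of that pattern (the helper name limited is hypothetical, not part of the code base):

    # Cap a sample stream the way create_dataset() does: compute the cap up
    # front, then stop the loop once it is reached; limit == 0 means no limit.
    def limited(samples, limit=0):
        num_samples = len(samples)
        if limit > 0:
            num_samples = min(limit, num_samples)
        for sample_index, sample in enumerate(samples):
            if sample_index >= num_samples:
                break
            yield sample

    print(list(limited(range(10), limit=3)))  # prints [0, 1, 2]

Breaking out of enumerate() rather than slicing keeps the pattern usable on the lazily augmented sample stream, which is an iterable rather than a list.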

View File

@@ -74,6 +74,12 @@ def create_flags():
     f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
     f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')
+
+    # Sample order
+    f.DEFINE_boolean('reverse_train', False, 'whether to reverse sample order of the train set')
+    f.DEFINE_boolean('reverse_dev', False, 'whether to reverse sample order of the dev set')
+    f.DEFINE_boolean('reverse_test', False, 'whether to reverse sample order of the test set')

     # Checkpointing
     f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
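
The three new booleans sit next to the pre-existing limit_* integers, one pair per data set. A self-contained sketch of how such flags behave, assuming f wraps absl.flags as the DEFINE_* signatures suggest (flag names taken from above; the script itself is hypothetical):

    from absl import app, flags

    flags.DEFINE_boolean('reverse_train', False, 'whether to reverse sample order of the train set')
    flags.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
    FLAGS = flags.FLAGS

    def main(_):
        # absl booleans also accept the negated form --noreverse_train
        print(FLAGS.reverse_train, FLAGS.limit_train)

    if __name__ == '__main__':
        app.run(main)

Combined, e.g. --reverse_train --limit_train 1000, the run would use the 1000 longest samples instead of the 1000 shortest, since the limit is applied after the reversal.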

View File

@@ -65,13 +65,14 @@ class Interleaved:
     """Collection that lazily combines sorted collections in an interleaving fashion.
     During iteration the next smallest (or largest, when reversed) element from all the sorted collections is picked.
     The collections must support iter() and len()."""

-    def __init__(self, *iterables, key=lambda obj: obj):
+    def __init__(self, *iterables, key=lambda obj: obj, reverse=False):
         self.iterables = iterables
         self.key = key
+        self.reverse = reverse
         self.len = sum(map(len, iterables))

     def __iter__(self):
-        return heapq.merge(*self.iterables, key=self.key)
+        return heapq.merge(*self.iterables, key=self.key, reverse=self.reverse)

     def __len__(self):
         return self.len
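
The reverse flag is forwarded straight to heapq.merge, and that carries a subtlety: with reverse=True, heapq.merge yields largest-first but requires every input iterable to already be sorted in descending order. A short demonstration:

    import heapq

    asc_a, asc_b = [1, 4, 6], [2, 3, 5]
    print(list(heapq.merge(asc_a, asc_b)))                  # [1, 2, 3, 4, 5, 6]

    # For a reversed merge the inputs themselves must be descending:
    desc_a, desc_b = asc_a[::-1], asc_b[::-1]
    print(list(heapq.merge(desc_a, desc_b, reverse=True)))  # [6, 5, 4, 3, 2, 1]

This is why the SDB and CSV readers below reverse themselves at the source, while Interleaved merely passes the flag through.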

View File

@@ -6,7 +6,7 @@ import json
 from pathlib import Path
 from functools import partial

-from .helpers import MEGABYTE, GIGABYTE, Interleaved
+from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
 from .audio import (
     Sample,
     DEFAULT_FORMAT,
@@ -23,6 +23,7 @@ BIGINT_SIZE = 2 * INT_SIZE
 MAGIC = b'SAMPLEDB'

 BUFFER_SIZE = 1 * MEGABYTE
+REVERSE_BUFFER_SIZE = 16 * KILOBYTE
 CACHE_SIZE = 1 * GIGABYTE

 SCHEMA_KEY = 'schema'
@@ -189,14 +190,19 @@ class DirectSDBWriter:
 class SDB:  # pylint: disable=too-many-instance-attributes
     """Sample collection reader for reading a Sample DB (SDB) file"""
-    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
+    def __init__(self,
+                 sdb_filename,
+                 buffering=BUFFER_SIZE,
+                 id_prefix=None,
+                 labeled=True,
+                 reverse=False):
         """
         Parameters
         ----------
         sdb_filename : str
             Path to the SDB file to read samples from
         buffering : int
-            Read-buffer size to use while reading the SDB file
+            Read-ahead buffer size to use while reading the SDB file in normal order; fixed to 16 kB in reverse mode
         id_prefix : str
             Prefix for IDs of read samples - defaults to sdb_filename
         labeled : bool or None
@@ -207,7 +213,7 @@ class SDB:  # pylint: disable=too-many-instance-attributes
         """
         self.sdb_filename = sdb_filename
         self.id_prefix = sdb_filename if id_prefix is None else id_prefix
-        self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
+        self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
         self.offsets = []
         if self.sdb_file.read(len(MAGIC)) != MAGIC:
             raise RuntimeError('No Sample Database')
@@ -237,6 +243,8 @@ class SDB:  # pylint: disable=too-many-instance-attributes
         num_samples = self.read_big_int()
         for _ in range(num_samples):
             self.offsets.append(self.read_big_int())
+        if reverse:
+            self.offsets.reverse()

     def read_int(self):
         return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
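
Because an SDB file keeps an index of absolute sample offsets, reversing iteration order is as cheap as reversing that index; each access then seeks to its recorded offset. That also explains the small REVERSE_BUFFER_SIZE: a 1 MB sequential read-ahead is presumably wasted when consecutive reads jump backwards through the file. A toy illustration of the offset-table idea (simplified format, not the real SDB layout):

    import io

    blob = io.BytesIO(b'aaabbcccc')       # three records: 'aaa', 'bb', 'cccc'
    records = [(0, 3), (3, 2), (5, 4)]    # (offset, length) index
    records.reverse()                     # same idea as self.offsets.reverse()
    for offset, length in records:
        blob.seek(offset)
        print(blob.read(length))          # b'cccc', b'bb', b'aaa'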
@@ -371,7 +379,7 @@ class CSVWriter:  # pylint: disable=too-many-instance-attributes
 class SampleList:
     """Sample collection base class with samples loaded from a list of in-memory paths."""
-    def __init__(self, samples, labeled=True):
+    def __init__(self, samples, labeled=True, reverse=False):
         """
         Parameters
         ----------
@@ -380,10 +388,12 @@
         labeled : bool or None
             If True: Reads LabeledSample instances.
             If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
+        reverse : bool
+            If True: Reverses the order of the samples
         """
         self.labeled = labeled
         self.samples = list(samples)
-        self.samples.sort(key=lambda r: r[1])
+        self.samples.sort(key=lambda r: r[1], reverse=reverse)

     def __getitem__(self, i):
         sample_spec = self.samples[i]
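
The rows being sorted here are the (wav_filename, wav_filesize, ...) tuples built by CSV below, keyed on the size column, so reverse=True simply flips the default shortest-first order. A tiny illustration with hypothetical rows:

    rows = [('a.wav', 320), ('b.wav', 160), ('c.wav', 640)]
    rows.sort(key=lambda r: r[1], reverse=True)
    print(rows)  # [('c.wav', 640), ('a.wav', 320), ('b.wav', 160)]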
@@ -396,7 +406,7 @@
 class CSV(SampleList):
     """Sample collection reader for reading a DeepSpeech CSV file
     Automatically orders samples by CSV column wav_filesize (if available)."""
-    def __init__(self, csv_filename, labeled=None):
+    def __init__(self, csv_filename, labeled=None, reverse=False):
         """
         Parameters
         ----------
@@ -407,6 +417,8 @@
             If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
             If None: Automatically determines if CSV file has a transcript column
             (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
+        reverse : bool
+            If True: Reverses the order of the samples
         """
         rows = []
         csv_dir = Path(csv_filename).parent
@@ -427,10 +439,10 @@
                 rows.append((wav_filename, wav_filesize, row['transcript']))
             else:
                 rows.append((wav_filename, wav_filesize))
-        super(CSV, self).__init__(rows, labeled=labeled)
+        super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)


-def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
+def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False):
     """
     Loads samples from a sample source file.
@@ -445,6 +457,8 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
         If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
         If None: Automatically determines if source provides transcripts
         (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
+    reverse : bool
+        If True: Reverses the order of the samples

     Returns
     -------
@@ -452,13 +466,13 @@
     """
     ext = os.path.splitext(sample_source)[1].lower()
     if ext == '.sdb':
-        return SDB(sample_source, buffering=buffering, labeled=labeled)
+        return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
     if ext == '.csv':
-        return CSV(sample_source, labeled=labeled)
+        return CSV(sample_source, labeled=labeled, reverse=reverse)
     raise ValueError('Unknown file type: "{}"'.format(ext))


-def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
+def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False):
     """
     Loads and combines samples from a list of source files. Sources are interleaved to keep the
     default sample order from shortest to longest.
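
A hedged usage sketch of the dispatch above (file paths hypothetical; import path assumed to match the code base): the extension selects the reader, and both readers apply the reversal at the source.

    from util.sample_collections import samples_from_source

    train_sdb = samples_from_source('data/train.sdb', reverse=True)  # -> SDB
    train_csv = samples_from_source('data/train.csv', reverse=True)  # -> CSV
    # Both collections support len(), which Interleaved relies on.
    print(len(train_sdb), len(train_csv))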
@@ -474,6 +488,8 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
         If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
         If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
         util.audio.Sample instances from sources with no transcripts.
+    reverse : bool
+        If True: Reverses the order of the samples

     Returns
     -------
@@ -483,6 +499,7 @@
     if len(sample_sources) == 0:
         raise ValueError('No files')
     if len(sample_sources) == 1:
-        return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
-    cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
-    return Interleaved(*cols, key=lambda s: s.duration)
+        return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
+    cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
+            for source in sample_sources]
+    return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
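
Putting it together, a hedged end-to-end sketch (paths hypothetical, import path assumed, and each sample is assumed to expose a duration, as the Interleaved key above requires): with reverse=True every source is reversed individually, so the merge can emit samples longest-duration-first.

    from util.sample_collections import samples_from_sources

    samples = samples_from_sources(['part1.csv', 'part2.csv'], labeled=True, reverse=True)
    durations = [sample.duration for sample in samples]
    assert durations == sorted(durations, reverse=True)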