Merge pull request #3177 from tilmankamp/reverse

Resolves #1565 - Limiting and reversing data-sets
Tilman Kamp, 2020-07-24 11:26:53 +02:00, committed by GitHub
commit 9e023660ef
6 changed files with 61 additions and 21 deletions
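
In short: the existing limit_* flags and the new reverse_* flags are now threaded from the CLI through create_dataset down to the sample readers. Sources are normally iterated shortest-to-longest, and the limit is applied after the reversal, so combining the two selects the N longest samples - handy for making memory problems on long samples surface at the start of a run. A minimal Python sketch of the combined semantics (toy durations, not the real reader API):

durations = [1.2, 3.4, 5.6, 7.8, 9.0]  # ascending, like a size-sorted data-set

def select(samples, reverse=False, limit=0):
    # reverse first (as the readers do), then cap at the limit (as create_dataset does)
    ordered = list(reversed(samples)) if reverse else list(samples)
    return ordered[:limit] if limit > 0 else ordered

print(select(durations, reverse=True, limit=2))  # -> [9.0, 7.8], the two longest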

evaluate.py

@@ -50,7 +50,11 @@ def evaluate(test_csvs, create_model):
     else:
         scorer = None
 
-    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
+    test_sets = [create_dataset([csv],
+                                batch_size=FLAGS.test_batch_size,
+                                train_phase=False,
+                                reverse=FLAGS.reverse_test,
+                                limit=FLAGS.limit_test) for csv in test_csvs]
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                  tfv1.data.get_output_shapes(test_sets[0]),
                                                  output_classes=tfv1.data.get_output_classes(test_sets[0]))

train.py

@@ -417,6 +417,8 @@ def train():
                                train_phase=True,
                                exception_box=exception_box,
                                process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
+                               reverse=FLAGS.reverse_train,
+                               limit=FLAGS.limit_train,
                                buffering=FLAGS.read_buffer)
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
@@ -433,6 +435,8 @@ def train():
                                 train_phase=False,
                                 exception_box=exception_box,
                                 process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                reverse=FLAGS.reverse_dev,
+                                limit=FLAGS.limit_dev,
                                 buffering=FLAGS.read_buffer) for source in dev_sources]
     dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
@@ -443,6 +447,8 @@ def train():
                                     train_phase=False,
                                     exception_box=exception_box,
                                     process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                    reverse=FLAGS.reverse_dev,
+                                    limit=FLAGS.limit_dev,
                                     buffering=FLAGS.read_buffer) for source in metrics_sources]
     metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]

util/feeding.py

@@ -90,6 +90,8 @@ def create_dataset(sources,
                    augmentations=None,
                    cache_path=None,
                    train_phase=False,
+                   reverse=False,
+                   limit=0,
                    exception_box=None,
                    process_ahead=None,
                    buffering=1 * MEGABYTE):
@@ -99,8 +101,10 @@ def create_dataset(sources,
         epoch = epoch_counter['epoch']
         if train_phase:
             epoch_counter['epoch'] += 1
-        samples = samples_from_sources(sources, buffering=buffering, labeled=True)
+        samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse)
         num_samples = len(samples)
+        if limit > 0:
+            num_samples = min(limit, num_samples)
         samples = apply_sample_augmentations(samples,
                                              augmentations,
                                              buffering=buffering,
@@ -108,6 +112,8 @@ def create_dataset(sources,
                                              clock=epoch / epochs,
                                              final_clock=(epoch + 1) / epochs)
         for sample_index, sample in enumerate(samples):
+            if sample_index >= num_samples:
+                break
             clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
             transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
             transcript = to_sparse_tuple(transcript)
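
A condensed sketch of how num_samples now drives both the early break and the augmentation clock (the surrounding generator machinery of create_dataset is omitted and the epoch handling simplified for illustration): the limit caps num_samples, which also rescales the clock so it still sweeps 0.0 to 1.0 over the limited set.

def generate(samples, limit=0, epochs=2, train_phase=True):
    num_samples = len(samples)
    if limit > 0:
        num_samples = min(limit, num_samples)
    for epoch in range(epochs):
        for sample_index, sample in enumerate(samples):
            if sample_index >= num_samples:
                break  # stop after the limit, as in the hunk above
            clock = ((epoch * num_samples + sample_index) / (epochs * num_samples)
                     if train_phase and epochs > 0 else 0.0)
            yield sample, clock

for sample, clock in generate(['a', 'b', 'c', 'd'], limit=2):
    print(sample, clock)  # a 0.0 / b 0.25 / a 0.5 / b 0.75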

util/flags.py

@@ -74,6 +74,12 @@ def create_flags():
     f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
     f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')
 
+    # Sample order
+    f.DEFINE_boolean('reverse_train', False, 'if to reverse sample order of the train set')
+    f.DEFINE_boolean('reverse_dev', False, 'if to reverse sample order of the dev set')
+    f.DEFINE_boolean('reverse_test', False, 'if to reverse sample order of the test set')
+
     # Checkpointing
     f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
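
With these flags in place, a run along the lines of python DeepSpeech.py --train_files train.csv --reverse_train --limit_train 100 (hypothetical paths; DeepSpeech.py is the repository's usual training entry point) trains on just the 100 longest samples, since the limit is applied after the reversal.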

util/helpers.py

@@ -65,13 +65,14 @@ class Interleaved:
     """Collection that lazily combines sorted collections in an interleaving fashion.
     During iteration the next smallest element from all the sorted collections is always picked.
     The collections must support iter() and len()."""
-    def __init__(self, *iterables, key=lambda obj: obj):
+    def __init__(self, *iterables, key=lambda obj: obj, reverse=False):
         self.iterables = iterables
         self.key = key
+        self.reverse = reverse
         self.len = sum(map(len, iterables))
 
     def __iter__(self):
-        return heapq.merge(*self.iterables, key=self.key)
+        return heapq.merge(*self.iterables, key=self.key, reverse=self.reverse)
 
     def __len__(self):
         return self.len
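
Worth noting: heapq.merge(..., reverse=True) only merges correctly if every input iterable is already sorted in descending order, which is exactly the invariant the reversed SDB/SampleList readers below provide. A quick standalone check:

import heapq

# With reverse=True, merge repeatedly picks the LARGEST head element,
# assuming each input is already sorted in descending order.
a = [9, 5, 1]
b = [8, 4, 2]
print(list(heapq.merge(a, b, reverse=True)))  # [9, 8, 5, 4, 2, 1]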

util/sample_collections.py

@@ -6,7 +6,7 @@ import json
 from pathlib import Path
 from functools import partial
 
-from .helpers import MEGABYTE, GIGABYTE, Interleaved
+from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
 from .audio import (
     Sample,
     DEFAULT_FORMAT,
@@ -23,6 +23,7 @@ BIGINT_SIZE = 2 * INT_SIZE
 MAGIC = b'SAMPLEDB'
 
 BUFFER_SIZE = 1 * MEGABYTE
+REVERSE_BUFFER_SIZE = 16 * KILOBYTE
 CACHE_SIZE = 1 * GIGABYTE
 
 SCHEMA_KEY = 'schema'
@@ -189,14 +190,19 @@ class DirectSDBWriter:
 class SDB:  # pylint: disable=too-many-instance-attributes
     """Sample collection reader for reading a Sample DB (SDB) file"""
-    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
+    def __init__(self,
+                 sdb_filename,
+                 buffering=BUFFER_SIZE,
+                 id_prefix=None,
+                 labeled=True,
+                 reverse=False):
         """
         Parameters
         ----------
         sdb_filename : str
             Path to the SDB file to read samples from
         buffering : int
-            Read-buffer size to use while reading the SDB file
+            Read-ahead buffer size to use while reading the SDB file in normal order. Fixed to 16kB if in reverse-mode.
         id_prefix : str
             Prefix for IDs of read samples - defaults to sdb_filename
         labeled : bool or None
@@ -207,7 +213,7 @@ class SDB:  # pylint: disable=too-many-instance-attributes
         """
         self.sdb_filename = sdb_filename
         self.id_prefix = sdb_filename if id_prefix is None else id_prefix
-        self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
+        self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
         self.offsets = []
         if self.sdb_file.read(len(MAGIC)) != MAGIC:
             raise RuntimeError('No Sample Database')
@@ -237,6 +243,8 @@ class SDB:  # pylint: disable=too-many-instance-attributes
         num_samples = self.read_big_int()
         for _ in range(num_samples):
             self.offsets.append(self.read_big_int())
+        if reverse:
+            self.offsets.reverse()
 
     def read_int(self):
         return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
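
The small 16 kB reverse buffer makes sense because reversed iteration seeks backwards through the file: a megabyte of read-ahead past each sample would mostly prefetch bytes that are never used. A rough sketch of the access pattern (hypothetical offsets; the real reader first parses the SAMPLEDB header and offset index):

offsets = [16, 120, 340, 700]  # sample start offsets, ascending file order
reverse = True
if reverse:
    offsets = list(reversed(offsets))  # what SDB.__init__ now does
for offset in offsets:
    # Each sample read seeks to its offset; iterating back-to-front means
    # read-ahead beyond the sample is wasted, hence the small buffer.
    print('seek to', offset)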
@@ -371,7 +379,7 @@ class CSVWriter:  # pylint: disable=too-many-instance-attributes
 class SampleList:
     """Sample collection base class with samples loaded from a list of in-memory paths."""
-    def __init__(self, samples, labeled=True):
+    def __init__(self, samples, labeled=True, reverse=False):
         """
         Parameters
         ----------
@@ -380,10 +388,12 @@ class SampleList:
         labeled : bool or None
             If True: Reads LabeledSample instances.
             If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
+        reverse : bool
+            If the order of the samples should be reversed
         """
         self.labeled = labeled
         self.samples = list(samples)
-        self.samples.sort(key=lambda r: r[1])
+        self.samples.sort(key=lambda r: r[1], reverse=reverse)
 
     def __getitem__(self, i):
         sample_spec = self.samples[i]
@@ -396,7 +406,7 @@ class SampleList:
 class CSV(SampleList):
     """Sample collection reader for reading a DeepSpeech CSV file
     Automatically orders samples by CSV column wav_filesize (if available)."""
-    def __init__(self, csv_filename, labeled=None):
+    def __init__(self, csv_filename, labeled=None, reverse=False):
         """
         Parameters
         ----------
@@ -407,6 +417,8 @@ class CSV(SampleList):
             If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
             If None: Automatically determines if CSV file has a transcript column
                 (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
+        reverse : bool
+            If the order of the samples should be reversed
         """
         rows = []
         csv_dir = Path(csv_filename).parent
@@ -427,10 +439,10 @@ class CSV(SampleList):
                 rows.append((wav_filename, wav_filesize, row['transcript']))
             else:
                 rows.append((wav_filename, wav_filesize))
-        super(CSV, self).__init__(rows, labeled=labeled)
+        super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)
 
 
-def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
+def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False):
     """
     Loads samples from a sample source file.
@@ -445,6 +457,8 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
         If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
         If None: Automatically determines if source provides transcripts
             (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
+    reverse : bool
+        If the order of the samples should be reversed
 
     Returns
     -------
@@ -452,13 +466,13 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
     """
     ext = os.path.splitext(sample_source)[1].lower()
     if ext == '.sdb':
-        return SDB(sample_source, buffering=buffering, labeled=labeled)
+        return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
     if ext == '.csv':
-        return CSV(sample_source, labeled=labeled)
+        return CSV(sample_source, labeled=labeled, reverse=reverse)
     raise ValueError('Unknown file type: "{}"'.format(ext))
 
 
-def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
+def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False):
     """
     Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
     keep default sample order from shortest to longest.
@@ -474,6 +488,8 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
         If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
         If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
             util.audio.Sample instances from sources with no transcripts.
+    reverse : bool
+        If the order of the samples should be reversed
 
     Returns
     -------
@@ -483,6 +499,7 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
     if len(sample_sources) == 0:
         raise ValueError('No files')
     if len(sample_sources) == 1:
-        return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
-    cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
-    return Interleaved(*cols, key=lambda s: s.duration)
+        return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
+    cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
+            for source in sample_sources]
+    return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
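
At the SampleList/CSV level the whole change boils down to a single sort argument; a toy check of the resulting order, with tuples standing in for real sample rows sorted by wav_filesize:

# reverse=True yields largest-first, matching the reversed SDB reader and the
# descending-order invariant that Interleaved/heapq.merge(reverse=True) expects.
rows = [('a.wav', 300), ('b.wav', 100), ('c.wav', 200)]
rows.sort(key=lambda r: r[1], reverse=True)
print(rows)  # [('a.wav', 300), ('c.wav', 200), ('b.wav', 100)]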