Resolves #1565 - Limiting and reversing data-sets
This commit is contained in:
parent
38f6afdba8
commit
9a5d19d7c5
|
@ -50,7 +50,11 @@ def evaluate(test_csvs, create_model):
|
|||
else:
|
||||
scorer = None
|
||||
|
||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||
test_sets = [create_dataset([csv],
|
||||
batch_size=FLAGS.test_batch_size,
|
||||
train_phase=False,
|
||||
reverse=FLAGS.reverse_test,
|
||||
limit=FLAGS.limit_test) for csv in test_csvs]
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||
tfv1.data.get_output_shapes(test_sets[0]),
|
||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||
|
|
|
@ -417,6 +417,8 @@ def train():
|
|||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
||||
reverse=FLAGS.reverse_train,
|
||||
limit=FLAGS.limit_train,
|
||||
buffering=FLAGS.read_buffer)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
|
@ -433,6 +435,8 @@ def train():
|
|||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
reverse=FLAGS.reverse_dev,
|
||||
limit=FLAGS.limit_dev,
|
||||
buffering=FLAGS.read_buffer) for source in dev_sources]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
|
@ -443,6 +447,8 @@ def train():
|
|||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
reverse=FLAGS.reverse_dev,
|
||||
limit=FLAGS.limit_dev,
|
||||
buffering=FLAGS.read_buffer) for source in metrics_sources]
|
||||
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
|
||||
|
||||
|
|
|
@ -90,6 +90,8 @@ def create_dataset(sources,
|
|||
augmentations=None,
|
||||
cache_path=None,
|
||||
train_phase=False,
|
||||
reverse=False,
|
||||
limit=0,
|
||||
exception_box=None,
|
||||
process_ahead=None,
|
||||
buffering=1 * MEGABYTE):
|
||||
|
@ -99,8 +101,10 @@ def create_dataset(sources,
|
|||
epoch = epoch_counter['epoch']
|
||||
if train_phase:
|
||||
epoch_counter['epoch'] += 1
|
||||
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
|
||||
samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse)
|
||||
num_samples = len(samples)
|
||||
if limit > 0:
|
||||
num_samples = min(limit, num_samples)
|
||||
samples = apply_sample_augmentations(samples,
|
||||
augmentations,
|
||||
buffering=buffering,
|
||||
|
@ -108,6 +112,8 @@ def create_dataset(sources,
|
|||
clock=epoch / epochs,
|
||||
final_clock=(epoch + 1) / epochs)
|
||||
for sample_index, sample in enumerate(samples):
|
||||
if sample_index >= num_samples:
|
||||
break
|
||||
clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
|
||||
transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
|
||||
transcript = to_sparse_tuple(transcript)
|
||||
|
|
|
@ -74,6 +74,12 @@ def create_flags():
|
|||
f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
|
||||
f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')
|
||||
|
||||
# Sample order
|
||||
|
||||
f.DEFINE_boolean('reverse_train', False, 'if to reverse sample order of the train set')
|
||||
f.DEFINE_boolean('reverse_dev', False, 'if to reverse sample order of the dev set')
|
||||
f.DEFINE_boolean('reverse_test', False, 'if to reverse sample order of the test set')
|
||||
|
||||
# Checkpointing
|
||||
|
||||
f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
|
||||
|
|
|
@ -65,13 +65,14 @@ class Interleaved:
|
|||
"""Collection that lazily combines sorted collections in an interleaving fashion.
|
||||
During iteration the next smallest element from all the sorted collections is always picked.
|
||||
The collections must support iter() and len()."""
|
||||
def __init__(self, *iterables, key=lambda obj: obj):
|
||||
def __init__(self, *iterables, key=lambda obj: obj, reverse=False):
|
||||
self.iterables = iterables
|
||||
self.key = key
|
||||
self.reverse = reverse
|
||||
self.len = sum(map(len, iterables))
|
||||
|
||||
def __iter__(self):
|
||||
return heapq.merge(*self.iterables, key=self.key)
|
||||
return heapq.merge(*self.iterables, key=self.key, reverse=self.reverse)
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
|
|
@ -6,7 +6,7 @@ import json
|
|||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from .helpers import MEGABYTE, GIGABYTE, Interleaved
|
||||
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
|
||||
from .audio import (
|
||||
Sample,
|
||||
DEFAULT_FORMAT,
|
||||
|
@ -23,6 +23,7 @@ BIGINT_SIZE = 2 * INT_SIZE
|
|||
MAGIC = b'SAMPLEDB'
|
||||
|
||||
BUFFER_SIZE = 1 * MEGABYTE
|
||||
REVERSE_BUFFER_SIZE = 16 * KILOBYTE
|
||||
CACHE_SIZE = 1 * GIGABYTE
|
||||
|
||||
SCHEMA_KEY = 'schema'
|
||||
|
@ -189,14 +190,19 @@ class DirectSDBWriter:
|
|||
|
||||
class SDB: # pylint: disable=too-many-instance-attributes
|
||||
"""Sample collection reader for reading a Sample DB (SDB) file"""
|
||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
|
||||
def __init__(self,
|
||||
sdb_filename,
|
||||
buffering=BUFFER_SIZE,
|
||||
id_prefix=None,
|
||||
labeled=True,
|
||||
reverse=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sdb_filename : str
|
||||
Path to the SDB file to read samples from
|
||||
buffering : int
|
||||
Read-buffer size to use while reading the SDB file
|
||||
Read-ahead buffer size to use while reading the SDB file in normal order. Fixed to 16kB if in reverse-mode.
|
||||
id_prefix : str
|
||||
Prefix for IDs of read samples - defaults to sdb_filename
|
||||
labeled : bool or None
|
||||
|
@ -207,7 +213,7 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
|||
"""
|
||||
self.sdb_filename = sdb_filename
|
||||
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
||||
self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
|
||||
self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
|
||||
self.offsets = []
|
||||
if self.sdb_file.read(len(MAGIC)) != MAGIC:
|
||||
raise RuntimeError('No Sample Database')
|
||||
|
@ -237,6 +243,8 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
|||
num_samples = self.read_big_int()
|
||||
for _ in range(num_samples):
|
||||
self.offsets.append(self.read_big_int())
|
||||
if reverse:
|
||||
self.offsets.reverse()
|
||||
|
||||
def read_int(self):
|
||||
return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
|
||||
|
@ -371,7 +379,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
|
|||
|
||||
class SampleList:
|
||||
"""Sample collection base class with samples loaded from a list of in-memory paths."""
|
||||
def __init__(self, samples, labeled=True):
|
||||
def __init__(self, samples, labeled=True, reverse=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -380,10 +388,12 @@ class SampleList:
|
|||
labeled : bool or None
|
||||
If True: Reads LabeledSample instances.
|
||||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||
reverse : bool
|
||||
If the order of the samples should be reversed
|
||||
"""
|
||||
self.labeled = labeled
|
||||
self.samples = list(samples)
|
||||
self.samples.sort(key=lambda r: r[1])
|
||||
self.samples.sort(key=lambda r: r[1], reverse=reverse)
|
||||
|
||||
def __getitem__(self, i):
|
||||
sample_spec = self.samples[i]
|
||||
|
@ -396,7 +406,7 @@ class SampleList:
|
|||
class CSV(SampleList):
|
||||
"""Sample collection reader for reading a DeepSpeech CSV file
|
||||
Automatically orders samples by CSV column wav_filesize (if available)."""
|
||||
def __init__(self, csv_filename, labeled=None):
|
||||
def __init__(self, csv_filename, labeled=None, reverse=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -407,6 +417,8 @@ class CSV(SampleList):
|
|||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||
If None: Automatically determines if CSV file has a transcript column
|
||||
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||
reverse : bool
|
||||
If the order of the samples should be reversed
|
||||
"""
|
||||
rows = []
|
||||
csv_dir = Path(csv_filename).parent
|
||||
|
@ -427,10 +439,10 @@ class CSV(SampleList):
|
|||
rows.append((wav_filename, wav_filesize, row['transcript']))
|
||||
else:
|
||||
rows.append((wav_filename, wav_filesize))
|
||||
super(CSV, self).__init__(rows, labeled=labeled)
|
||||
super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)
|
||||
|
||||
|
||||
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
|
||||
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False):
|
||||
"""
|
||||
Loads samples from a sample source file.
|
||||
|
||||
|
@ -445,6 +457,8 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
|
|||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||
If None: Automatically determines if source provides transcripts
|
||||
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||
reverse : bool
|
||||
If the order of the samples should be reversed
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -452,13 +466,13 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
|
|||
"""
|
||||
ext = os.path.splitext(sample_source)[1].lower()
|
||||
if ext == '.sdb':
|
||||
return SDB(sample_source, buffering=buffering, labeled=labeled)
|
||||
return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
|
||||
if ext == '.csv':
|
||||
return CSV(sample_source, labeled=labeled)
|
||||
return CSV(sample_source, labeled=labeled, reverse=reverse)
|
||||
raise ValueError('Unknown file type: "{}"'.format(ext))
|
||||
|
||||
|
||||
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
|
||||
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False):
|
||||
"""
|
||||
Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
|
||||
keep default sample order from shortest to longest.
|
||||
|
@ -474,6 +488,8 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
|
|||
If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
|
||||
If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
|
||||
util.audio.Sample instances from sources with no transcripts.
|
||||
reverse : bool
|
||||
If the order of the samples should be reversed
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -483,6 +499,7 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
|
|||
if len(sample_sources) == 0:
|
||||
raise ValueError('No files')
|
||||
if len(sample_sources) == 1:
|
||||
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
|
||||
cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
|
||||
return Interleaved(*cols, key=lambda s: s.duration)
|
||||
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
|
||||
cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
|
||||
for source in sample_sources]
|
||||
return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
|
||||
|
|
Loading…
Reference in New Issue