Resolves #1565 - Limiting and reversing data-sets
This commit is contained in:
parent
38f6afdba8
commit
9a5d19d7c5
|
@ -50,7 +50,11 @@ def evaluate(test_csvs, create_model):
|
||||||
else:
|
else:
|
||||||
scorer = None
|
scorer = None
|
||||||
|
|
||||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
test_sets = [create_dataset([csv],
|
||||||
|
batch_size=FLAGS.test_batch_size,
|
||||||
|
train_phase=False,
|
||||||
|
reverse=FLAGS.reverse_test,
|
||||||
|
limit=FLAGS.limit_test) for csv in test_csvs]
|
||||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||||
tfv1.data.get_output_shapes(test_sets[0]),
|
tfv1.data.get_output_shapes(test_sets[0]),
|
||||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||||
|
|
|
@ -417,6 +417,8 @@ def train():
|
||||||
train_phase=True,
|
train_phase=True,
|
||||||
exception_box=exception_box,
|
exception_box=exception_box,
|
||||||
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
||||||
|
reverse=FLAGS.reverse_train,
|
||||||
|
limit=FLAGS.limit_train,
|
||||||
buffering=FLAGS.read_buffer)
|
buffering=FLAGS.read_buffer)
|
||||||
|
|
||||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||||
|
@ -433,6 +435,8 @@ def train():
|
||||||
train_phase=False,
|
train_phase=False,
|
||||||
exception_box=exception_box,
|
exception_box=exception_box,
|
||||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||||
|
reverse=FLAGS.reverse_dev,
|
||||||
|
limit=FLAGS.limit_dev,
|
||||||
buffering=FLAGS.read_buffer) for source in dev_sources]
|
buffering=FLAGS.read_buffer) for source in dev_sources]
|
||||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||||
|
|
||||||
|
@ -443,6 +447,8 @@ def train():
|
||||||
train_phase=False,
|
train_phase=False,
|
||||||
exception_box=exception_box,
|
exception_box=exception_box,
|
||||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||||
|
reverse=FLAGS.reverse_dev,
|
||||||
|
limit=FLAGS.limit_dev,
|
||||||
buffering=FLAGS.read_buffer) for source in metrics_sources]
|
buffering=FLAGS.read_buffer) for source in metrics_sources]
|
||||||
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
|
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
|
||||||
|
|
||||||
|
|
|
@ -90,6 +90,8 @@ def create_dataset(sources,
|
||||||
augmentations=None,
|
augmentations=None,
|
||||||
cache_path=None,
|
cache_path=None,
|
||||||
train_phase=False,
|
train_phase=False,
|
||||||
|
reverse=False,
|
||||||
|
limit=0,
|
||||||
exception_box=None,
|
exception_box=None,
|
||||||
process_ahead=None,
|
process_ahead=None,
|
||||||
buffering=1 * MEGABYTE):
|
buffering=1 * MEGABYTE):
|
||||||
|
@ -99,8 +101,10 @@ def create_dataset(sources,
|
||||||
epoch = epoch_counter['epoch']
|
epoch = epoch_counter['epoch']
|
||||||
if train_phase:
|
if train_phase:
|
||||||
epoch_counter['epoch'] += 1
|
epoch_counter['epoch'] += 1
|
||||||
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
|
samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse)
|
||||||
num_samples = len(samples)
|
num_samples = len(samples)
|
||||||
|
if limit > 0:
|
||||||
|
num_samples = min(limit, num_samples)
|
||||||
samples = apply_sample_augmentations(samples,
|
samples = apply_sample_augmentations(samples,
|
||||||
augmentations,
|
augmentations,
|
||||||
buffering=buffering,
|
buffering=buffering,
|
||||||
|
@ -108,6 +112,8 @@ def create_dataset(sources,
|
||||||
clock=epoch / epochs,
|
clock=epoch / epochs,
|
||||||
final_clock=(epoch + 1) / epochs)
|
final_clock=(epoch + 1) / epochs)
|
||||||
for sample_index, sample in enumerate(samples):
|
for sample_index, sample in enumerate(samples):
|
||||||
|
if sample_index >= num_samples:
|
||||||
|
break
|
||||||
clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
|
clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
|
||||||
transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
|
transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
|
||||||
transcript = to_sparse_tuple(transcript)
|
transcript = to_sparse_tuple(transcript)
|
||||||
|
|
|
@ -71,8 +71,14 @@ def create_flags():
|
||||||
# Sample limits
|
# Sample limits
|
||||||
|
|
||||||
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
|
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
|
||||||
f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
|
f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
|
||||||
f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')
|
f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')
|
||||||
|
|
||||||
|
# Sample order
|
||||||
|
|
||||||
|
f.DEFINE_boolean('reverse_train', False, 'if to reverse sample order of the train set')
|
||||||
|
f.DEFINE_boolean('reverse_dev', False, 'if to reverse sample order of the dev set')
|
||||||
|
f.DEFINE_boolean('reverse_test', False, 'if to reverse sample order of the test set')
|
||||||
|
|
||||||
# Checkpointing
|
# Checkpointing
|
||||||
|
|
||||||
|
|
|
@ -65,13 +65,14 @@ class Interleaved:
|
||||||
"""Collection that lazily combines sorted collections in an interleaving fashion.
|
"""Collection that lazily combines sorted collections in an interleaving fashion.
|
||||||
During iteration the next smallest element from all the sorted collections is always picked.
|
During iteration the next smallest element from all the sorted collections is always picked.
|
||||||
The collections must support iter() and len()."""
|
The collections must support iter() and len()."""
|
||||||
def __init__(self, *iterables, key=lambda obj: obj):
|
def __init__(self, *iterables, key=lambda obj: obj, reverse=False):
|
||||||
self.iterables = iterables
|
self.iterables = iterables
|
||||||
self.key = key
|
self.key = key
|
||||||
|
self.reverse = reverse
|
||||||
self.len = sum(map(len, iterables))
|
self.len = sum(map(len, iterables))
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return heapq.merge(*self.iterables, key=self.key)
|
return heapq.merge(*self.iterables, key=self.key, reverse=self.reverse)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.len
|
return self.len
|
||||||
|
|
|
@ -6,7 +6,7 @@ import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from .helpers import MEGABYTE, GIGABYTE, Interleaved
|
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
|
||||||
from .audio import (
|
from .audio import (
|
||||||
Sample,
|
Sample,
|
||||||
DEFAULT_FORMAT,
|
DEFAULT_FORMAT,
|
||||||
|
@ -23,6 +23,7 @@ BIGINT_SIZE = 2 * INT_SIZE
|
||||||
MAGIC = b'SAMPLEDB'
|
MAGIC = b'SAMPLEDB'
|
||||||
|
|
||||||
BUFFER_SIZE = 1 * MEGABYTE
|
BUFFER_SIZE = 1 * MEGABYTE
|
||||||
|
REVERSE_BUFFER_SIZE = 16 * KILOBYTE
|
||||||
CACHE_SIZE = 1 * GIGABYTE
|
CACHE_SIZE = 1 * GIGABYTE
|
||||||
|
|
||||||
SCHEMA_KEY = 'schema'
|
SCHEMA_KEY = 'schema'
|
||||||
|
@ -189,14 +190,19 @@ class DirectSDBWriter:
|
||||||
|
|
||||||
class SDB: # pylint: disable=too-many-instance-attributes
|
class SDB: # pylint: disable=too-many-instance-attributes
|
||||||
"""Sample collection reader for reading a Sample DB (SDB) file"""
|
"""Sample collection reader for reading a Sample DB (SDB) file"""
|
||||||
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
|
def __init__(self,
|
||||||
|
sdb_filename,
|
||||||
|
buffering=BUFFER_SIZE,
|
||||||
|
id_prefix=None,
|
||||||
|
labeled=True,
|
||||||
|
reverse=False):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
sdb_filename : str
|
sdb_filename : str
|
||||||
Path to the SDB file to read samples from
|
Path to the SDB file to read samples from
|
||||||
buffering : int
|
buffering : int
|
||||||
Read-buffer size to use while reading the SDB file
|
Read-ahead buffer size to use while reading the SDB file in normal order. Fixed to 16kB if in reverse-mode.
|
||||||
id_prefix : str
|
id_prefix : str
|
||||||
Prefix for IDs of read samples - defaults to sdb_filename
|
Prefix for IDs of read samples - defaults to sdb_filename
|
||||||
labeled : bool or None
|
labeled : bool or None
|
||||||
|
@ -207,7 +213,7 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
||||||
"""
|
"""
|
||||||
self.sdb_filename = sdb_filename
|
self.sdb_filename = sdb_filename
|
||||||
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
self.id_prefix = sdb_filename if id_prefix is None else id_prefix
|
||||||
self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
|
self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
|
||||||
self.offsets = []
|
self.offsets = []
|
||||||
if self.sdb_file.read(len(MAGIC)) != MAGIC:
|
if self.sdb_file.read(len(MAGIC)) != MAGIC:
|
||||||
raise RuntimeError('No Sample Database')
|
raise RuntimeError('No Sample Database')
|
||||||
|
@ -237,6 +243,8 @@ class SDB: # pylint: disable=too-many-instance-attributes
|
||||||
num_samples = self.read_big_int()
|
num_samples = self.read_big_int()
|
||||||
for _ in range(num_samples):
|
for _ in range(num_samples):
|
||||||
self.offsets.append(self.read_big_int())
|
self.offsets.append(self.read_big_int())
|
||||||
|
if reverse:
|
||||||
|
self.offsets.reverse()
|
||||||
|
|
||||||
def read_int(self):
|
def read_int(self):
|
||||||
return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
|
return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
|
||||||
|
@ -371,7 +379,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
|
||||||
|
|
||||||
class SampleList:
|
class SampleList:
|
||||||
"""Sample collection base class with samples loaded from a list of in-memory paths."""
|
"""Sample collection base class with samples loaded from a list of in-memory paths."""
|
||||||
def __init__(self, samples, labeled=True):
|
def __init__(self, samples, labeled=True, reverse=False):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -380,10 +388,12 @@ class SampleList:
|
||||||
labeled : bool or None
|
labeled : bool or None
|
||||||
If True: Reads LabeledSample instances.
|
If True: Reads LabeledSample instances.
|
||||||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||||
|
reverse : bool
|
||||||
|
If the order of the samples should be reversed
|
||||||
"""
|
"""
|
||||||
self.labeled = labeled
|
self.labeled = labeled
|
||||||
self.samples = list(samples)
|
self.samples = list(samples)
|
||||||
self.samples.sort(key=lambda r: r[1])
|
self.samples.sort(key=lambda r: r[1], reverse=reverse)
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
sample_spec = self.samples[i]
|
sample_spec = self.samples[i]
|
||||||
|
@ -396,7 +406,7 @@ class SampleList:
|
||||||
class CSV(SampleList):
|
class CSV(SampleList):
|
||||||
"""Sample collection reader for reading a DeepSpeech CSV file
|
"""Sample collection reader for reading a DeepSpeech CSV file
|
||||||
Automatically orders samples by CSV column wav_filesize (if available)."""
|
Automatically orders samples by CSV column wav_filesize (if available)."""
|
||||||
def __init__(self, csv_filename, labeled=None):
|
def __init__(self, csv_filename, labeled=None, reverse=False):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -407,6 +417,8 @@ class CSV(SampleList):
|
||||||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||||
If None: Automatically determines if CSV file has a transcript column
|
If None: Automatically determines if CSV file has a transcript column
|
||||||
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||||
|
reverse : bool
|
||||||
|
If the order of the samples should be reversed
|
||||||
"""
|
"""
|
||||||
rows = []
|
rows = []
|
||||||
csv_dir = Path(csv_filename).parent
|
csv_dir = Path(csv_filename).parent
|
||||||
|
@ -427,10 +439,10 @@ class CSV(SampleList):
|
||||||
rows.append((wav_filename, wav_filesize, row['transcript']))
|
rows.append((wav_filename, wav_filesize, row['transcript']))
|
||||||
else:
|
else:
|
||||||
rows.append((wav_filename, wav_filesize))
|
rows.append((wav_filename, wav_filesize))
|
||||||
super(CSV, self).__init__(rows, labeled=labeled)
|
super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)
|
||||||
|
|
||||||
|
|
||||||
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
|
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False):
|
||||||
"""
|
"""
|
||||||
Loads samples from a sample source file.
|
Loads samples from a sample source file.
|
||||||
|
|
||||||
|
@ -445,6 +457,8 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
|
||||||
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
|
||||||
If None: Automatically determines if source provides transcripts
|
If None: Automatically determines if source provides transcripts
|
||||||
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
|
||||||
|
reverse : bool
|
||||||
|
If the order of the samples should be reversed
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -452,13 +466,13 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
|
||||||
"""
|
"""
|
||||||
ext = os.path.splitext(sample_source)[1].lower()
|
ext = os.path.splitext(sample_source)[1].lower()
|
||||||
if ext == '.sdb':
|
if ext == '.sdb':
|
||||||
return SDB(sample_source, buffering=buffering, labeled=labeled)
|
return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
|
||||||
if ext == '.csv':
|
if ext == '.csv':
|
||||||
return CSV(sample_source, labeled=labeled)
|
return CSV(sample_source, labeled=labeled, reverse=reverse)
|
||||||
raise ValueError('Unknown file type: "{}"'.format(ext))
|
raise ValueError('Unknown file type: "{}"'.format(ext))
|
||||||
|
|
||||||
|
|
||||||
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
|
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False):
|
||||||
"""
|
"""
|
||||||
Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
|
Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
|
||||||
keep default sample order from shortest to longest.
|
keep default sample order from shortest to longest.
|
||||||
|
@ -474,6 +488,8 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
|
||||||
If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
|
If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
|
||||||
If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
|
If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
|
||||||
util.audio.Sample instances from sources with no transcripts.
|
util.audio.Sample instances from sources with no transcripts.
|
||||||
|
reverse : bool
|
||||||
|
If the order of the samples should be reversed
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -483,6 +499,7 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
|
||||||
if len(sample_sources) == 0:
|
if len(sample_sources) == 0:
|
||||||
raise ValueError('No files')
|
raise ValueError('No files')
|
||||||
if len(sample_sources) == 1:
|
if len(sample_sources) == 1:
|
||||||
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
|
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
|
||||||
cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
|
cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
|
||||||
return Interleaved(*cols, key=lambda s: s.duration)
|
for source in sample_sources]
|
||||||
|
return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
|
||||||
|
|
Loading…
Reference in New Issue