diff --git a/training/deepspeech_training/evaluate.py b/training/deepspeech_training/evaluate.py index 00eac8c7..965b3370 100755 --- a/training/deepspeech_training/evaluate.py +++ b/training/deepspeech_training/evaluate.py @@ -50,7 +50,11 @@ def evaluate(test_csvs, create_model): else: scorer = None - test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs] + test_sets = [create_dataset([csv], + batch_size=FLAGS.test_batch_size, + train_phase=False, + reverse=FLAGS.reverse_test, + limit=FLAGS.limit_test) for csv in test_csvs] iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]), tfv1.data.get_output_shapes(test_sets[0]), output_classes=tfv1.data.get_output_classes(test_sets[0])) diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index 93d0c727..47052a07 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -417,6 +417,8 @@ def train(): train_phase=True, exception_box=exception_box, process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, + reverse=FLAGS.reverse_train, + limit=FLAGS.limit_train, buffering=FLAGS.read_buffer) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), @@ -433,6 +435,8 @@ def train(): train_phase=False, exception_box=exception_box, process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + reverse=FLAGS.reverse_dev, + limit=FLAGS.limit_dev, buffering=FLAGS.read_buffer) for source in dev_sources] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] @@ -443,6 +447,8 @@ def train(): train_phase=False, exception_box=exception_box, process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + reverse=FLAGS.reverse_dev, + limit=FLAGS.limit_dev, buffering=FLAGS.read_buffer) for source in metrics_sources] metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets] diff --git a/training/deepspeech_training/util/feeding.py b/training/deepspeech_training/util/feeding.py index 4c9b681d..9a26215c 100644 --- a/training/deepspeech_training/util/feeding.py +++ b/training/deepspeech_training/util/feeding.py @@ -90,6 +90,8 @@ def create_dataset(sources, augmentations=None, cache_path=None, train_phase=False, + reverse=False, + limit=0, exception_box=None, process_ahead=None, buffering=1 * MEGABYTE): @@ -99,8 +101,10 @@ def create_dataset(sources, epoch = epoch_counter['epoch'] if train_phase: epoch_counter['epoch'] += 1 - samples = samples_from_sources(sources, buffering=buffering, labeled=True) + samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse) num_samples = len(samples) + if limit > 0: + num_samples = min(limit, num_samples) samples = apply_sample_augmentations(samples, augmentations, buffering=buffering, @@ -108,6 +112,8 @@ def create_dataset(sources, clock=epoch / epochs, final_clock=(epoch + 1) / epochs) for sample_index, sample in enumerate(samples): + if sample_index >= num_samples: + break clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0 transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id) transcript = to_sparse_tuple(transcript) diff --git a/training/deepspeech_training/util/flags.py b/training/deepspeech_training/util/flags.py index 6bf64251..128441fd 100644 --- a/training/deepspeech_training/util/flags.py +++ b/training/deepspeech_training/util/flags.py @@ -71,8 +71,14 @@ def create_flags(): # Sample limits f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit') - f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit') - f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit') + f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit') + f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit') + + # Sample order + + f.DEFINE_boolean('reverse_train', False, 'if to reverse sample order of the train set') + f.DEFINE_boolean('reverse_dev', False, 'if to reverse sample order of the dev set') + f.DEFINE_boolean('reverse_test', False, 'if to reverse sample order of the test set') # Checkpointing diff --git a/training/deepspeech_training/util/helpers.py b/training/deepspeech_training/util/helpers.py index 32116f3f..195c117e 100644 --- a/training/deepspeech_training/util/helpers.py +++ b/training/deepspeech_training/util/helpers.py @@ -65,13 +65,14 @@ class Interleaved: """Collection that lazily combines sorted collections in an interleaving fashion. During iteration the next smallest element from all the sorted collections is always picked. The collections must support iter() and len().""" - def __init__(self, *iterables, key=lambda obj: obj): + def __init__(self, *iterables, key=lambda obj: obj, reverse=False): self.iterables = iterables self.key = key + self.reverse = reverse self.len = sum(map(len, iterables)) def __iter__(self): - return heapq.merge(*self.iterables, key=self.key) + return heapq.merge(*self.iterables, key=self.key, reverse=self.reverse) def __len__(self): return self.len diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index b220e1b3..15c97f97 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -6,7 +6,7 @@ import json from pathlib import Path from functools import partial -from .helpers import MEGABYTE, GIGABYTE, Interleaved +from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved from .audio import ( Sample, DEFAULT_FORMAT, @@ -23,6 +23,7 @@ BIGINT_SIZE = 2 * INT_SIZE MAGIC = b'SAMPLEDB' BUFFER_SIZE = 1 * MEGABYTE +REVERSE_BUFFER_SIZE = 16 * KILOBYTE CACHE_SIZE = 1 * GIGABYTE SCHEMA_KEY = 'schema' @@ -189,14 +190,19 @@ class DirectSDBWriter: class SDB: # pylint: disable=too-many-instance-attributes """Sample collection reader for reading a Sample DB (SDB) file""" - def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True): + def __init__(self, + sdb_filename, + buffering=BUFFER_SIZE, + id_prefix=None, + labeled=True, + reverse=False): """ Parameters ---------- sdb_filename : str Path to the SDB file to read samples from buffering : int - Read-buffer size to use while reading the SDB file + Read-ahead buffer size to use while reading the SDB file in normal order. Fixed to 16kB if in reverse-mode. id_prefix : str Prefix for IDs of read samples - defaults to sdb_filename labeled : bool or None @@ -207,7 +213,7 @@ class SDB: # pylint: disable=too-many-instance-attributes """ self.sdb_filename = sdb_filename self.id_prefix = sdb_filename if id_prefix is None else id_prefix - self.sdb_file = open(sdb_filename, 'rb', buffering=buffering) + self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering) self.offsets = [] if self.sdb_file.read(len(MAGIC)) != MAGIC: raise RuntimeError('No Sample Database') @@ -237,6 +243,8 @@ class SDB: # pylint: disable=too-many-instance-attributes num_samples = self.read_big_int() for _ in range(num_samples): self.offsets.append(self.read_big_int()) + if reverse: + self.offsets.reverse() def read_int(self): return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN) @@ -371,7 +379,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes class SampleList: """Sample collection base class with samples loaded from a list of in-memory paths.""" - def __init__(self, samples, labeled=True): + def __init__(self, samples, labeled=True, reverse=False): """ Parameters ---------- @@ -380,10 +388,12 @@ class SampleList: labeled : bool or None If True: Reads LabeledSample instances. If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. + reverse : bool + If the order of the samples should be reversed """ self.labeled = labeled self.samples = list(samples) - self.samples.sort(key=lambda r: r[1]) + self.samples.sort(key=lambda r: r[1], reverse=reverse) def __getitem__(self, i): sample_spec = self.samples[i] @@ -396,7 +406,7 @@ class SampleList: class CSV(SampleList): """Sample collection reader for reading a DeepSpeech CSV file Automatically orders samples by CSV column wav_filesize (if available).""" - def __init__(self, csv_filename, labeled=None): + def __init__(self, csv_filename, labeled=None, reverse=False): """ Parameters ---------- @@ -407,6 +417,8 @@ class CSV(SampleList): If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. If None: Automatically determines if CSV file has a transcript column (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + reverse : bool + If the order of the samples should be reversed """ rows = [] csv_dir = Path(csv_filename).parent @@ -427,10 +439,10 @@ class CSV(SampleList): rows.append((wav_filename, wav_filesize, row['transcript'])) else: rows.append((wav_filename, wav_filesize)) - super(CSV, self).__init__(rows, labeled=labeled) + super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse) -def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None): +def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False): """ Loads samples from a sample source file. @@ -445,6 +457,8 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None): If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. If None: Automatically determines if source provides transcripts (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + reverse : bool + If the order of the samples should be reversed Returns ------- @@ -452,13 +466,13 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None): """ ext = os.path.splitext(sample_source)[1].lower() if ext == '.sdb': - return SDB(sample_source, buffering=buffering, labeled=labeled) + return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse) if ext == '.csv': - return CSV(sample_source, labeled=labeled) + return CSV(sample_source, labeled=labeled, reverse=reverse) raise ValueError('Unknown file type: "{}"'.format(ext)) -def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None): +def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False): """ Loads and combines samples from a list of source files. Sources are combined in an interleaving way to keep default sample order from shortest to longest. @@ -474,6 +488,8 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None): If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances. If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and util.audio.Sample instances from sources with no transcripts. + reverse : bool + If the order of the samples should be reversed Returns ------- @@ -483,6 +499,7 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None): if len(sample_sources) == 0: raise ValueError('No files') if len(sample_sources) == 1: - return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled) - cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources)) - return Interleaved(*cols, key=lambda s: s.duration) + return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse) + cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse) + for source in sample_sources] + return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)