Resolves #1565 - Limiting and reversing data-sets
This commit is contained in:
		
							parent
							
								
									38f6afdba8
								
							
						
					
					
						commit
						9a5d19d7c5
					
				@ -50,7 +50,11 @@ def evaluate(test_csvs, create_model):
 | 
			
		||||
    else:
 | 
			
		||||
        scorer = None
 | 
			
		||||
 | 
			
		||||
    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
 | 
			
		||||
    test_sets = [create_dataset([csv],
 | 
			
		||||
                                batch_size=FLAGS.test_batch_size,
 | 
			
		||||
                                train_phase=False,
 | 
			
		||||
                                reverse=FLAGS.reverse_test,
 | 
			
		||||
                                limit=FLAGS.limit_test) for csv in test_csvs]
 | 
			
		||||
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
 | 
			
		||||
                                                 tfv1.data.get_output_shapes(test_sets[0]),
 | 
			
		||||
                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))
 | 
			
		||||
 | 
			
		||||
@ -417,6 +417,8 @@ def train():
 | 
			
		||||
                               train_phase=True,
 | 
			
		||||
                               exception_box=exception_box,
 | 
			
		||||
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
 | 
			
		||||
                               reverse=FLAGS.reverse_train,
 | 
			
		||||
                               limit=FLAGS.limit_train,
 | 
			
		||||
                               buffering=FLAGS.read_buffer)
 | 
			
		||||
 | 
			
		||||
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
 | 
			
		||||
@ -433,6 +435,8 @@ def train():
 | 
			
		||||
                                   train_phase=False,
 | 
			
		||||
                                   exception_box=exception_box,
 | 
			
		||||
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
 | 
			
		||||
                                   reverse=FLAGS.reverse_dev,
 | 
			
		||||
                                   limit=FLAGS.limit_dev,
 | 
			
		||||
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
 | 
			
		||||
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 | 
			
		||||
 | 
			
		||||
@ -443,6 +447,8 @@ def train():
 | 
			
		||||
                                       train_phase=False,
 | 
			
		||||
                                       exception_box=exception_box,
 | 
			
		||||
                                       process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
 | 
			
		||||
                                       reverse=FLAGS.reverse_dev,
 | 
			
		||||
                                       limit=FLAGS.limit_dev,
 | 
			
		||||
                                       buffering=FLAGS.read_buffer) for source in metrics_sources]
 | 
			
		||||
        metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -90,6 +90,8 @@ def create_dataset(sources,
 | 
			
		||||
                   augmentations=None,
 | 
			
		||||
                   cache_path=None,
 | 
			
		||||
                   train_phase=False,
 | 
			
		||||
                   reverse=False,
 | 
			
		||||
                   limit=0,
 | 
			
		||||
                   exception_box=None,
 | 
			
		||||
                   process_ahead=None,
 | 
			
		||||
                   buffering=1 * MEGABYTE):
 | 
			
		||||
@ -99,8 +101,10 @@ def create_dataset(sources,
 | 
			
		||||
        epoch = epoch_counter['epoch']
 | 
			
		||||
        if train_phase:
 | 
			
		||||
            epoch_counter['epoch'] += 1
 | 
			
		||||
        samples = samples_from_sources(sources, buffering=buffering, labeled=True)
 | 
			
		||||
        samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse)
 | 
			
		||||
        num_samples = len(samples)
 | 
			
		||||
        if limit > 0:
 | 
			
		||||
            num_samples = min(limit, num_samples)
 | 
			
		||||
        samples = apply_sample_augmentations(samples,
 | 
			
		||||
                                             augmentations,
 | 
			
		||||
                                             buffering=buffering,
 | 
			
		||||
@ -108,6 +112,8 @@ def create_dataset(sources,
 | 
			
		||||
                                             clock=epoch / epochs,
 | 
			
		||||
                                             final_clock=(epoch + 1) / epochs)
 | 
			
		||||
        for sample_index, sample in enumerate(samples):
 | 
			
		||||
            if sample_index >= num_samples:
 | 
			
		||||
                break
 | 
			
		||||
            clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0
 | 
			
		||||
            transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
 | 
			
		||||
            transcript = to_sparse_tuple(transcript)
 | 
			
		||||
 | 
			
		||||
@ -71,8 +71,14 @@ def create_flags():
 | 
			
		||||
    # Sample limits
 | 
			
		||||
 | 
			
		||||
    f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
 | 
			
		||||
    f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
 | 
			
		||||
    f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')
 | 
			
		||||
    f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
 | 
			
		||||
    f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')
 | 
			
		||||
 | 
			
		||||
    # Sample order
 | 
			
		||||
 | 
			
		||||
    f.DEFINE_boolean('reverse_train', False, 'if to reverse sample order of the train set')
 | 
			
		||||
    f.DEFINE_boolean('reverse_dev', False, 'if to reverse sample order of the dev set')
 | 
			
		||||
    f.DEFINE_boolean('reverse_test', False, 'if to reverse sample order of the test set')
 | 
			
		||||
 | 
			
		||||
    # Checkpointing
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -65,13 +65,14 @@ class Interleaved:
 | 
			
		||||
    """Collection that lazily combines sorted collections in an interleaving fashion.
 | 
			
		||||
    During iteration the next smallest element from all the sorted collections is always picked.
 | 
			
		||||
    The collections must support iter() and len()."""
 | 
			
		||||
    def __init__(self, *iterables, key=lambda obj: obj):
 | 
			
		||||
    def __init__(self, *iterables, key=lambda obj: obj, reverse=False):
 | 
			
		||||
        self.iterables = iterables
 | 
			
		||||
        self.key = key
 | 
			
		||||
        self.reverse = reverse
 | 
			
		||||
        self.len = sum(map(len, iterables))
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        return heapq.merge(*self.iterables, key=self.key)
 | 
			
		||||
        return heapq.merge(*self.iterables, key=self.key, reverse=self.reverse)
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return self.len
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ import json
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
from .helpers import MEGABYTE, GIGABYTE, Interleaved
 | 
			
		||||
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved
 | 
			
		||||
from .audio import (
 | 
			
		||||
    Sample,
 | 
			
		||||
    DEFAULT_FORMAT,
 | 
			
		||||
@ -23,6 +23,7 @@ BIGINT_SIZE = 2 * INT_SIZE
 | 
			
		||||
MAGIC = b'SAMPLEDB'
 | 
			
		||||
 | 
			
		||||
BUFFER_SIZE = 1 * MEGABYTE
 | 
			
		||||
REVERSE_BUFFER_SIZE = 16 * KILOBYTE
 | 
			
		||||
CACHE_SIZE = 1 * GIGABYTE
 | 
			
		||||
 | 
			
		||||
SCHEMA_KEY = 'schema'
 | 
			
		||||
@ -189,14 +190,19 @@ class DirectSDBWriter:
 | 
			
		||||
 | 
			
		||||
class SDB:  # pylint: disable=too-many-instance-attributes
 | 
			
		||||
    """Sample collection reader for reading a Sample DB (SDB) file"""
 | 
			
		||||
    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 sdb_filename,
 | 
			
		||||
                 buffering=BUFFER_SIZE,
 | 
			
		||||
                 id_prefix=None,
 | 
			
		||||
                 labeled=True,
 | 
			
		||||
                 reverse=False):
 | 
			
		||||
        """
 | 
			
		||||
        Parameters
 | 
			
		||||
        ----------
 | 
			
		||||
        sdb_filename : str
 | 
			
		||||
            Path to the SDB file to read samples from
 | 
			
		||||
        buffering : int
 | 
			
		||||
            Read-buffer size to use while reading the SDB file
 | 
			
		||||
            Read-ahead buffer size to use while reading the SDB file in normal order. Fixed to 16kB if in reverse-mode.
 | 
			
		||||
        id_prefix : str
 | 
			
		||||
            Prefix for IDs of read samples - defaults to sdb_filename
 | 
			
		||||
        labeled : bool or None
 | 
			
		||||
@ -207,7 +213,7 @@ class SDB:  # pylint: disable=too-many-instance-attributes
 | 
			
		||||
        """
 | 
			
		||||
        self.sdb_filename = sdb_filename
 | 
			
		||||
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
 | 
			
		||||
        self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
 | 
			
		||||
        self.sdb_file = open(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering)
 | 
			
		||||
        self.offsets = []
 | 
			
		||||
        if self.sdb_file.read(len(MAGIC)) != MAGIC:
 | 
			
		||||
            raise RuntimeError('No Sample Database')
 | 
			
		||||
@ -237,6 +243,8 @@ class SDB:  # pylint: disable=too-many-instance-attributes
 | 
			
		||||
        num_samples = self.read_big_int()
 | 
			
		||||
        for _ in range(num_samples):
 | 
			
		||||
            self.offsets.append(self.read_big_int())
 | 
			
		||||
        if reverse:
 | 
			
		||||
            self.offsets.reverse()
 | 
			
		||||
 | 
			
		||||
    def read_int(self):
 | 
			
		||||
        return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
 | 
			
		||||
@ -371,7 +379,7 @@ class CSVWriter:  # pylint: disable=too-many-instance-attributes
 | 
			
		||||
 | 
			
		||||
class SampleList:
 | 
			
		||||
    """Sample collection base class with samples loaded from a list of in-memory paths."""
 | 
			
		||||
    def __init__(self, samples, labeled=True):
 | 
			
		||||
    def __init__(self, samples, labeled=True, reverse=False):
 | 
			
		||||
        """
 | 
			
		||||
        Parameters
 | 
			
		||||
        ----------
 | 
			
		||||
@ -380,10 +388,12 @@ class SampleList:
 | 
			
		||||
        labeled : bool or None
 | 
			
		||||
            If True: Reads LabeledSample instances.
 | 
			
		||||
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
 | 
			
		||||
        reverse : bool
 | 
			
		||||
            If the order of the samples should be reversed
 | 
			
		||||
        """
 | 
			
		||||
        self.labeled = labeled
 | 
			
		||||
        self.samples = list(samples)
 | 
			
		||||
        self.samples.sort(key=lambda r: r[1])
 | 
			
		||||
        self.samples.sort(key=lambda r: r[1], reverse=reverse)
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, i):
 | 
			
		||||
        sample_spec = self.samples[i]
 | 
			
		||||
@ -396,7 +406,7 @@ class SampleList:
 | 
			
		||||
class CSV(SampleList):
 | 
			
		||||
    """Sample collection reader for reading a DeepSpeech CSV file
 | 
			
		||||
    Automatically orders samples by CSV column wav_filesize (if available)."""
 | 
			
		||||
    def __init__(self, csv_filename, labeled=None):
 | 
			
		||||
    def __init__(self, csv_filename, labeled=None, reverse=False):
 | 
			
		||||
        """
 | 
			
		||||
        Parameters
 | 
			
		||||
        ----------
 | 
			
		||||
@ -407,6 +417,8 @@ class CSV(SampleList):
 | 
			
		||||
            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
 | 
			
		||||
            If None: Automatically determines if CSV file has a transcript column
 | 
			
		||||
            (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
 | 
			
		||||
        reverse : bool
 | 
			
		||||
            If the order of the samples should be reversed
 | 
			
		||||
        """
 | 
			
		||||
        rows = []
 | 
			
		||||
        csv_dir = Path(csv_filename).parent
 | 
			
		||||
@ -427,10 +439,10 @@ class CSV(SampleList):
 | 
			
		||||
                    rows.append((wav_filename, wav_filesize, row['transcript']))
 | 
			
		||||
                else:
 | 
			
		||||
                    rows.append((wav_filename, wav_filesize))
 | 
			
		||||
        super(CSV, self).__init__(rows, labeled=labeled)
 | 
			
		||||
        super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
 | 
			
		||||
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False):
 | 
			
		||||
    """
 | 
			
		||||
    Loads samples from a sample source file.
 | 
			
		||||
 | 
			
		||||
@ -445,6 +457,8 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
 | 
			
		||||
        If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
 | 
			
		||||
        If None: Automatically determines if source provides transcripts
 | 
			
		||||
        (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
 | 
			
		||||
    reverse : bool
 | 
			
		||||
        If the order of the samples should be reversed
 | 
			
		||||
 | 
			
		||||
    Returns
 | 
			
		||||
    -------
 | 
			
		||||
@ -452,13 +466,13 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
 | 
			
		||||
    """
 | 
			
		||||
    ext = os.path.splitext(sample_source)[1].lower()
 | 
			
		||||
    if ext == '.sdb':
 | 
			
		||||
        return SDB(sample_source, buffering=buffering, labeled=labeled)
 | 
			
		||||
        return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
 | 
			
		||||
    if ext == '.csv':
 | 
			
		||||
        return CSV(sample_source, labeled=labeled)
 | 
			
		||||
        return CSV(sample_source, labeled=labeled, reverse=reverse)
 | 
			
		||||
    raise ValueError('Unknown file type: "{}"'.format(ext))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
 | 
			
		||||
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False):
 | 
			
		||||
    """
 | 
			
		||||
    Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
 | 
			
		||||
    keep default sample order from shortest to longest.
 | 
			
		||||
@ -474,6 +488,8 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
 | 
			
		||||
        If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
 | 
			
		||||
        If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
 | 
			
		||||
        util.audio.Sample instances from sources with no transcripts.
 | 
			
		||||
    reverse : bool
 | 
			
		||||
        If the order of the samples should be reversed
 | 
			
		||||
 | 
			
		||||
    Returns
 | 
			
		||||
    -------
 | 
			
		||||
@ -483,6 +499,7 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
 | 
			
		||||
    if len(sample_sources) == 0:
 | 
			
		||||
        raise ValueError('No files')
 | 
			
		||||
    if len(sample_sources) == 1:
 | 
			
		||||
        return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
 | 
			
		||||
    cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
 | 
			
		||||
    return Interleaved(*cols, key=lambda s: s.duration)
 | 
			
		||||
        return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse)
 | 
			
		||||
    cols = [samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)
 | 
			
		||||
            for source in sample_sources]
 | 
			
		||||
    return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user