Fix #2830 - Support for unlabeled samples
This commit is contained in:
		
							parent
							
								
									5740d64e6e
								
							
						
					
					
						commit
						41da7b2870
					
				@ -26,8 +26,8 @@ AUDIO_TYPE_LOOKUP = {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def build_sdb():
 | 
					def build_sdb():
 | 
				
			||||||
    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
 | 
					    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
 | 
				
			||||||
    with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer:
 | 
					    with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
 | 
				
			||||||
        samples = samples_from_files(CLI_ARGS.sources)
 | 
					        samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
 | 
				
			||||||
        bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
 | 
					        bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
 | 
				
			||||||
        for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
 | 
					        for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
 | 
				
			||||||
            sdb_writer.add(sample)
 | 
					            sdb_writer.add(sample)
 | 
				
			||||||
@ -36,13 +36,18 @@ def build_sdb():
 | 
				
			|||||||
def handle_args():
 | 
					def handle_args():
 | 
				
			||||||
    parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
 | 
					    parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
 | 
				
			||||||
                                                 'from DeepSpeech CSV files and other SDB files')
 | 
					                                                 'from DeepSpeech CSV files and other SDB files')
 | 
				
			||||||
    parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - '
 | 
					    parser.add_argument('sources', nargs='+',
 | 
				
			||||||
                                                   'Note: For getting a correctly ordered target SDB, source SDBs have '
 | 
					                        help='Source CSV and/or SDB files - '
 | 
				
			||||||
                                                   'to have their samples already ordered from shortest to longest.')
 | 
					                             'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
 | 
				
			||||||
 | 
					                             'already ordered from shortest to longest.')
 | 
				
			||||||
    parser.add_argument('target', help='SDB file to create')
 | 
					    parser.add_argument('target', help='SDB file to create')
 | 
				
			||||||
    parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
 | 
					    parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
 | 
				
			||||||
                        help='Audio representation inside target SDB')
 | 
					                        help='Audio representation inside target SDB')
 | 
				
			||||||
    parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers')
 | 
					    parser.add_argument('--workers', type=int, default=None,
 | 
				
			||||||
 | 
					                        help='Number of encoding SDB workers')
 | 
				
			||||||
 | 
					    parser.add_argument('--unlabeled', action='store_true',
 | 
				
			||||||
 | 
					                        help='If to build an SDB with unlabeled (audio only) samples - '
 | 
				
			||||||
 | 
					                             'typically used for building noise augmentation corpora')
 | 
				
			||||||
    return parser.parse_args()
 | 
					    return parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -14,7 +14,7 @@ sys.path.insert(1, os.path.join(sys.path[0], '..'))
 | 
				
			|||||||
import random
 | 
					import random
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from util.sample_collections import samples_from_file
 | 
					from util.sample_collections import samples_from_file, LabeledSample
 | 
				
			||||||
from util.audio import AUDIO_TYPE_PCM
 | 
					from util.audio import AUDIO_TYPE_PCM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -28,6 +28,7 @@ def play_sample(samples, index):
 | 
				
			|||||||
        sys.exit(1)
 | 
					        sys.exit(1)
 | 
				
			||||||
    sample = samples[index]
 | 
					    sample = samples[index]
 | 
				
			||||||
    print('Sample "{}"'.format(sample.sample_id))
 | 
					    print('Sample "{}"'.format(sample.sample_id))
 | 
				
			||||||
 | 
					    if isinstance(sample, LabeledSample):
 | 
				
			||||||
        print('  "{}"'.format(sample.transcript))
 | 
					        print('  "{}"'.format(sample.transcript))
 | 
				
			||||||
    sample.change_audio_type(AUDIO_TYPE_PCM)
 | 
					    sample.change_audio_type(AUDIO_TYPE_PCM)
 | 
				
			||||||
    rate, channels, width = sample.audio_format
 | 
					    rate, channels, width = sample.audio_format
 | 
				
			||||||
 | 
				
			|||||||
@ -26,31 +26,45 @@ OPUS_CHUNK_LEN_SIZE = 2
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Sample:
 | 
					class Sample:
 | 
				
			||||||
    """Represents in-memory audio data of a certain (convertible) representation.
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        audio_type (str): See `__init__`.
 | 
					 | 
				
			||||||
        audio_format (tuple:(int, int, int)): See `__init__`.
 | 
					 | 
				
			||||||
        audio (obj): Audio data represented as indicated by `audio_type`
 | 
					 | 
				
			||||||
        duration (float): Audio duration of the sample in seconds
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, audio_type, raw_data, audio_format=None):
 | 
					    Represents in-memory audio data of a certain (convertible) representation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Attributes
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    audio_type : str
 | 
				
			||||||
 | 
					        See `__init__`.
 | 
				
			||||||
 | 
					    audio_format : tuple:(int, int, int)
 | 
				
			||||||
 | 
					        See `__init__`.
 | 
				
			||||||
 | 
					    audio : binary
 | 
				
			||||||
 | 
					        Audio data represented as indicated by `audio_type`
 | 
				
			||||||
 | 
					    duration : float
 | 
				
			||||||
 | 
					        Audio duration of the sample in seconds
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        Creates a Sample from a raw audio representation.
 | 
					    def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None):
 | 
				
			||||||
        :param audio_type: Audio data representation type
 | 
					        """
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        audio_type : str
 | 
				
			||||||
 | 
					            Audio data representation type
 | 
				
			||||||
            Supported types:
 | 
					            Supported types:
 | 
				
			||||||
                - AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
 | 
					                - util.audio.AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
 | 
				
			||||||
                    wrapped by a custom container format (used in SDBs)
 | 
					                    wrapped by a custom container format (used in SDBs)
 | 
				
			||||||
                - AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
 | 
					                - util.audio.AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
 | 
				
			||||||
                - AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
 | 
					                - util.audio.AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
 | 
				
			||||||
                - AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
 | 
					                - util.audio.AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
 | 
				
			||||||
        :param raw_data: Audio data in the form of the provided representation type (see audio_type).
 | 
					        raw_data : binary
 | 
				
			||||||
            For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray.
 | 
					            Audio data in the form of the provided representation type (see audio_type).
 | 
				
			||||||
        :param audio_format: Tuple of sample-rate, number of channels and sample-width.
 | 
					            For types util.audio.AUDIO_TYPE_OPUS or util.audio.AUDIO_TYPE_WAV data can also be passed as a bytearray.
 | 
				
			||||||
            Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP,
 | 
					        audio_format : tuple
 | 
				
			||||||
 | 
					            Tuple of sample-rate, number of channels and sample-width.
 | 
				
			||||||
 | 
					            Required in case of audio_type = util.audio.AUDIO_TYPE_PCM or util.audio.AUDIO_TYPE_NP,
 | 
				
			||||||
            as this information cannot be derived from raw audio data.
 | 
					            as this information cannot be derived from raw audio data.
 | 
				
			||||||
 | 
					        sample_id : str
 | 
				
			||||||
 | 
					            Tracking ID - should indicate sample's origin as precisely as possible
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.audio_type = audio_type
 | 
					        self.audio_type = audio_type
 | 
				
			||||||
        self.audio_format = audio_format
 | 
					        self.audio_format = audio_format
 | 
				
			||||||
 | 
					        self.sample_id = sample_id
 | 
				
			||||||
        if audio_type in SERIALIZABLE_AUDIO_TYPES:
 | 
					        if audio_type in SERIALIZABLE_AUDIO_TYPES:
 | 
				
			||||||
            self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
 | 
					            self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
 | 
				
			||||||
            self.duration = read_duration(audio_type, self.audio)
 | 
					            self.duration = read_duration(audio_type, self.audio)
 | 
				
			||||||
@ -68,7 +82,11 @@ class Sample:
 | 
				
			|||||||
    def change_audio_type(self, new_audio_type):
 | 
					    def change_audio_type(self, new_audio_type):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        In-place conversion of audio data into a different representation.
 | 
					        In-place conversion of audio data into a different representation.
 | 
				
			||||||
        :param new_audio_type: New audio-type - see `__init__`.
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        new_audio_type : str
 | 
				
			||||||
 | 
					            New audio-type - see `__init__`.
 | 
				
			||||||
            Not supported: Converting from AUDIO_TYPE_NP into any other type.
 | 
					            Not supported: Converting from AUDIO_TYPE_NP into any other type.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if self.audio_type == new_audio_type:
 | 
					        if self.audio_type == new_audio_type:
 | 
				
			||||||
 | 
				
			|||||||
@ -116,7 +116,7 @@ def create_dataset(sources,
 | 
				
			|||||||
                   process_ahead=None,
 | 
					                   process_ahead=None,
 | 
				
			||||||
                   buffering=1 * MEGABYTE):
 | 
					                   buffering=1 * MEGABYTE):
 | 
				
			||||||
    def generate_values():
 | 
					    def generate_values():
 | 
				
			||||||
        samples = samples_from_files(sources, buffering=buffering)
 | 
					        samples = samples_from_files(sources, buffering=buffering, labeled=True)
 | 
				
			||||||
        for sample in change_audio_types(samples,
 | 
					        for sample in change_audio_types(samples,
 | 
				
			||||||
                                         AUDIO_TYPE_NP,
 | 
					                                         AUDIO_TYPE_NP,
 | 
				
			||||||
                                         process_ahead=2 * batch_size if process_ahead is None else process_ahead):
 | 
					                                         process_ahead=2 * batch_size if process_ahead is None else process_ahead):
 | 
				
			||||||
 | 
				
			|||||||
@ -29,23 +29,45 @@ class LabeledSample(Sample):
 | 
				
			|||||||
    Derived from util.audio.Sample and used by sample collection readers and writers."""
 | 
					    Derived from util.audio.Sample and used by sample collection readers and writers."""
 | 
				
			||||||
    def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
 | 
					    def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Creates an in-memory speech sample together with a transcript of the utterance (label).
 | 
					        Parameters
 | 
				
			||||||
        :param audio_type: See util.audio.Sample.__init__ .
 | 
					        ----------
 | 
				
			||||||
        :param raw_data: See util.audio.Sample.__init__ .
 | 
					        audio_type : str
 | 
				
			||||||
        :param transcript: Transcript of the sample's utterance
 | 
					            See util.audio.Sample.__init__ .
 | 
				
			||||||
        :param audio_format: See util.audio.Sample.__init__ .
 | 
					        raw_data : binary
 | 
				
			||||||
        :param sample_id: Tracking ID - typically assigned by collection readers
 | 
					            See util.audio.Sample.__init__ .
 | 
				
			||||||
 | 
					        transcript : str
 | 
				
			||||||
 | 
					            Transcript of the sample's utterance
 | 
				
			||||||
 | 
					        audio_format : tuple
 | 
				
			||||||
 | 
					            See util.audio.Sample.__init__ .
 | 
				
			||||||
 | 
					        sample_id : str
 | 
				
			||||||
 | 
					            Tracking ID - should indicate sample's origin as precisely as possible.
 | 
				
			||||||
 | 
					            It is typically assigned by collection readers.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        super().__init__(audio_type, raw_data, audio_format=audio_format)
 | 
					        super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id)
 | 
				
			||||||
        self.sample_id = sample_id
 | 
					 | 
				
			||||||
        self.transcript = transcript
 | 
					        self.transcript = transcript
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DirectSDBWriter:
 | 
					class DirectSDBWriter:
 | 
				
			||||||
    """Sample collection writer for creating a Sample DB (SDB) file"""
 | 
					    """Sample collection writer for creating a Sample DB (SDB) file"""
 | 
				
			||||||
    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None):
 | 
					    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None, labeled=True):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        sdb_filename : str
 | 
				
			||||||
 | 
					            Path to the SDB file to write
 | 
				
			||||||
 | 
					        buffering : int
 | 
				
			||||||
 | 
					            Write-buffer size to use while writing the SDB file
 | 
				
			||||||
 | 
					        audio_type : str
 | 
				
			||||||
 | 
					            See util.audio.Sample.__init__ .
 | 
				
			||||||
 | 
					        id_prefix : str
 | 
				
			||||||
 | 
					            Prefix for IDs of written samples - defaults to sdb_filename
 | 
				
			||||||
 | 
					        labeled : bool or None
 | 
				
			||||||
 | 
					            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
 | 
				
			||||||
 | 
					            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        self.sdb_filename = sdb_filename
 | 
					        self.sdb_filename = sdb_filename
 | 
				
			||||||
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
 | 
					        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
 | 
				
			||||||
 | 
					        self.labeled = labeled
 | 
				
			||||||
        if audio_type not in SERIALIZABLE_AUDIO_TYPES:
 | 
					        if audio_type not in SERIALIZABLE_AUDIO_TYPES:
 | 
				
			||||||
            raise ValueError('Audio type "{}" not supported'.format(audio_type))
 | 
					            raise ValueError('Audio type "{}" not supported'.format(audio_type))
 | 
				
			||||||
        self.audio_type = audio_type
 | 
					        self.audio_type = audio_type
 | 
				
			||||||
@ -55,12 +77,10 @@ class DirectSDBWriter:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.sdb_file.write(MAGIC)
 | 
					        self.sdb_file.write(MAGIC)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        meta_data = {
 | 
					        schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}]
 | 
				
			||||||
            SCHEMA_KEY: [
 | 
					        if self.labeled:
 | 
				
			||||||
                {CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type},
 | 
					            schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT})
 | 
				
			||||||
                {CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
 | 
					        meta_data = {SCHEMA_KEY: schema_entries}
 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        meta_data = json.dumps(meta_data).encode()
 | 
					        meta_data = json.dumps(meta_data).encode()
 | 
				
			||||||
        self.write_big_int(len(meta_data))
 | 
					        self.write_big_int(len(meta_data))
 | 
				
			||||||
        self.sdb_file.write(meta_data)
 | 
					        self.sdb_file.write(meta_data)
 | 
				
			||||||
@ -83,10 +103,14 @@ class DirectSDBWriter:
 | 
				
			|||||||
        sample.change_audio_type(self.audio_type)
 | 
					        sample.change_audio_type(self.audio_type)
 | 
				
			||||||
        opus = sample.audio.getbuffer()
 | 
					        opus = sample.audio.getbuffer()
 | 
				
			||||||
        opus_len = to_bytes(len(opus))
 | 
					        opus_len = to_bytes(len(opus))
 | 
				
			||||||
 | 
					        if self.labeled:
 | 
				
			||||||
            transcript = sample.transcript.encode()
 | 
					            transcript = sample.transcript.encode()
 | 
				
			||||||
            transcript_len = to_bytes(len(transcript))
 | 
					            transcript_len = to_bytes(len(transcript))
 | 
				
			||||||
            entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
 | 
					            entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
 | 
				
			||||||
            buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
 | 
					            buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            entry_len = to_bytes(len(opus_len) + len(opus))
 | 
				
			||||||
 | 
					            buffer = b''.join([entry_len, opus_len, opus])
 | 
				
			||||||
        self.offsets.append(self.sdb_file.tell())
 | 
					        self.offsets.append(self.sdb_file.tell())
 | 
				
			||||||
        self.sdb_file.write(buffer)
 | 
					        self.sdb_file.write(buffer)
 | 
				
			||||||
        sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
 | 
					        sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
 | 
				
			||||||
@ -120,7 +144,22 @@ class DirectSDBWriter:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class SDB:  # pylint: disable=too-many-instance-attributes
 | 
					class SDB:  # pylint: disable=too-many-instance-attributes
 | 
				
			||||||
    """Sample collection reader for reading a Sample DB (SDB) file"""
 | 
					    """Sample collection reader for reading a Sample DB (SDB) file"""
 | 
				
			||||||
    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None):
 | 
					    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        sdb_filename : str
 | 
				
			||||||
 | 
					            Path to the SDB file to read samples from
 | 
				
			||||||
 | 
					        buffering : int
 | 
				
			||||||
 | 
					            Read-buffer size to use while reading the SDB file
 | 
				
			||||||
 | 
					        id_prefix : str
 | 
				
			||||||
 | 
					            Prefix for IDs of read samples - defaults to sdb_filename
 | 
				
			||||||
 | 
					        labeled : bool or None
 | 
				
			||||||
 | 
					            If True: Reads util.sample_collections.LabeledSample instances. Fails, if SDB file provides no transcripts.
 | 
				
			||||||
 | 
					            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
 | 
				
			||||||
 | 
					            If None: Automatically determines if SDB schema has transcripts
 | 
				
			||||||
 | 
					            (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        self.sdb_filename = sdb_filename
 | 
					        self.sdb_filename = sdb_filename
 | 
				
			||||||
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
 | 
					        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
 | 
				
			||||||
        self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
 | 
					        self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
 | 
				
			||||||
@ -139,10 +178,14 @@ class SDB:  # pylint: disable=too-many-instance-attributes
 | 
				
			|||||||
        self.speech_index = speech_columns[0]
 | 
					        self.speech_index = speech_columns[0]
 | 
				
			||||||
        self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
 | 
					        self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.transcript_index = None
 | 
				
			||||||
 | 
					        if labeled is not False:
 | 
				
			||||||
            transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
 | 
					            transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
 | 
				
			||||||
        if not transcript_columns:
 | 
					            if transcript_columns:
 | 
				
			||||||
            raise RuntimeError('No transcript data (missing in schema)')
 | 
					 | 
				
			||||||
                self.transcript_index = transcript_columns[0]
 | 
					                self.transcript_index = transcript_columns[0]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                if labeled is True:
 | 
				
			||||||
 | 
					                    raise RuntimeError('No transcript data (missing in schema)')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sample_chunk_len = self.read_big_int()
 | 
					        sample_chunk_len = self.read_big_int()
 | 
				
			||||||
        self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
 | 
					        self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
 | 
				
			||||||
@ -194,9 +237,12 @@ class SDB:  # pylint: disable=too-many-instance-attributes
 | 
				
			|||||||
        return tuple(column_data)
 | 
					        return tuple(column_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __getitem__(self, i):
 | 
					    def __getitem__(self, i):
 | 
				
			||||||
 | 
					        sample_id = '{}:{}'.format(self.id_prefix, i)
 | 
				
			||||||
 | 
					        if self.transcript_index is None:
 | 
				
			||||||
 | 
					            [audio_data] = self.read_row(i, self.speech_index)
 | 
				
			||||||
 | 
					            return Sample(self.audio_type, audio_data, sample_id=sample_id)
 | 
				
			||||||
        audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
 | 
					        audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
 | 
				
			||||||
        transcript = transcript.decode()
 | 
					        transcript = transcript.decode()
 | 
				
			||||||
        sample_id = '{}:{}'.format(self.id_prefix, i)
 | 
					 | 
				
			||||||
        return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)
 | 
					        return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __iter__(self):
 | 
					    def __iter__(self):
 | 
				
			||||||
@ -215,24 +261,50 @@ class SDB:  # pylint: disable=too-many-instance-attributes
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CSV:
 | 
					class CSV:
 | 
				
			||||||
    """Sample collection reader for reading a DeepSpeech CSV file"""
 | 
					    """Sample collection reader for reading a DeepSpeech CSV file
 | 
				
			||||||
    def __init__(self, csv_filename):
 | 
					    Automatically orders samples by CSV column wav_filesize (if available)."""
 | 
				
			||||||
 | 
					    def __init__(self, csv_filename, labeled=None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        csv_filename : str
 | 
				
			||||||
 | 
					            Path to the CSV file containing sample audio paths and transcripts
 | 
				
			||||||
 | 
					        labeled : bool or None
 | 
				
			||||||
 | 
					            If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column.
 | 
				
			||||||
 | 
					            If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
 | 
				
			||||||
 | 
					            If None: Automatically determines if CSV file has a transcript column
 | 
				
			||||||
 | 
					            (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        self.csv_filename = csv_filename
 | 
					        self.csv_filename = csv_filename
 | 
				
			||||||
 | 
					        self.labeled = labeled
 | 
				
			||||||
        self.rows = []
 | 
					        self.rows = []
 | 
				
			||||||
        csv_dir = Path(csv_filename).parent
 | 
					        csv_dir = Path(csv_filename).parent
 | 
				
			||||||
        with open(csv_filename, 'r', encoding='utf8') as csv_file:
 | 
					        with open(csv_filename, 'r', encoding='utf8') as csv_file:
 | 
				
			||||||
            reader = csv.DictReader(csv_file)
 | 
					            reader = csv.DictReader(csv_file)
 | 
				
			||||||
 | 
					            if 'transcript' in reader.fieldnames:
 | 
				
			||||||
 | 
					                if self.labeled is None:
 | 
				
			||||||
 | 
					                    self.labeled = True
 | 
				
			||||||
 | 
					            elif self.labeled:
 | 
				
			||||||
 | 
					                raise RuntimeError('No transcript data (missing CSV column)')
 | 
				
			||||||
            for row in reader:
 | 
					            for row in reader:
 | 
				
			||||||
                wav_filename = Path(row['wav_filename'])
 | 
					                wav_filename = Path(row['wav_filename'])
 | 
				
			||||||
                if not wav_filename.is_absolute():
 | 
					                if not wav_filename.is_absolute():
 | 
				
			||||||
                    wav_filename = csv_dir / wav_filename
 | 
					                    wav_filename = csv_dir / wav_filename
 | 
				
			||||||
                self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript']))
 | 
					                wav_filename = str(wav_filename)
 | 
				
			||||||
 | 
					                wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
 | 
				
			||||||
 | 
					                if self.labeled:
 | 
				
			||||||
 | 
					                    self.rows.append((wav_filename, wav_filesize, row['transcript']))
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    self.rows.append((wav_filename, wav_filesize))
 | 
				
			||||||
        self.rows.sort(key=lambda r: r[1])
 | 
					        self.rows.sort(key=lambda r: r[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __getitem__(self, i):
 | 
					    def __getitem__(self, i):
 | 
				
			||||||
        wav_filename, _, transcript = self.rows[i]
 | 
					        row = self.rows[i]
 | 
				
			||||||
 | 
					        wav_filename = row[0]
 | 
				
			||||||
        with open(wav_filename, 'rb') as wav_file:
 | 
					        with open(wav_filename, 'rb') as wav_file:
 | 
				
			||||||
            return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename)
 | 
					            if self.labeled:
 | 
				
			||||||
 | 
					                return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), row[2], sample_id=wav_filename)
 | 
				
			||||||
 | 
					            return Sample(AUDIO_TYPE_WAV, wav_file.read(), sample_id=wav_filename)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __iter__(self):
 | 
					    def __iter__(self):
 | 
				
			||||||
        for i in range(len(self.rows)):
 | 
					        for i in range(len(self.rows)):
 | 
				
			||||||
@ -242,21 +314,52 @@ class CSV:
 | 
				
			|||||||
        return len(self.rows)
 | 
					        return len(self.rows)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def samples_from_file(filename, buffering=BUFFER_SIZE):
 | 
					def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None):
 | 
				
			||||||
    """Returns an iterable of LabeledSample objects loaded from a file."""
 | 
					    """
 | 
				
			||||||
 | 
					    Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
 | 
				
			||||||
 | 
					    loaded from a sample source file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    filename : str
 | 
				
			||||||
 | 
					        Path to the sample source file (SDB or CSV)
 | 
				
			||||||
 | 
					    buffering : int
 | 
				
			||||||
 | 
					        Read-buffer size to use while reading files
 | 
				
			||||||
 | 
					    labeled : bool or None
 | 
				
			||||||
 | 
					        If True: Reads LabeledSample instances. Fails, if source provides no transcripts.
 | 
				
			||||||
 | 
					        If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
 | 
				
			||||||
 | 
					        If None: Automatically determines if source provides transcripts
 | 
				
			||||||
 | 
					        (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
    ext = os.path.splitext(filename)[1].lower()
 | 
					    ext = os.path.splitext(filename)[1].lower()
 | 
				
			||||||
    if ext == '.sdb':
 | 
					    if ext == '.sdb':
 | 
				
			||||||
        return SDB(filename, buffering=buffering)
 | 
					        return SDB(filename, buffering=buffering, labeled=labeled)
 | 
				
			||||||
    if ext == '.csv':
 | 
					    if ext == '.csv':
 | 
				
			||||||
        return CSV(filename)
 | 
					        return CSV(filename, labeled=labeled)
 | 
				
			||||||
    raise ValueError('Unknown file type: "{}"'.format(ext))
 | 
					    raise ValueError('Unknown file type: "{}"'.format(ext))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def samples_from_files(filenames, buffering=BUFFER_SIZE):
 | 
					def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None):
 | 
				
			||||||
    """Returns an iterable of LabeledSample objects from a list of files."""
 | 
					    """
 | 
				
			||||||
 | 
					    Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
 | 
				
			||||||
 | 
					    loaded from a collection of sample source files.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    filenames : list of str
 | 
				
			||||||
 | 
					        Paths to sample source files (SDBs or CSVs)
 | 
				
			||||||
 | 
					    buffering : int
 | 
				
			||||||
 | 
					        Read-buffer size to use while reading files
 | 
				
			||||||
 | 
					    labeled : bool or None
 | 
				
			||||||
 | 
					        If True: Reads LabeledSample instances. Fails, if not all sources provide transcripts.
 | 
				
			||||||
 | 
					        If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
 | 
				
			||||||
 | 
					        If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
 | 
				
			||||||
 | 
					        util.audio.Sample instances from sources with no transcripts.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    filenames = list(filenames)
 | 
				
			||||||
    if len(filenames) == 0:
 | 
					    if len(filenames) == 0:
 | 
				
			||||||
        raise ValueError('No files')
 | 
					        raise ValueError('No files')
 | 
				
			||||||
    if len(filenames) == 1:
 | 
					    if len(filenames) == 1:
 | 
				
			||||||
        return samples_from_file(filenames[0], buffering=buffering)
 | 
					        return samples_from_file(filenames[0], buffering=buffering, labeled=labeled)
 | 
				
			||||||
    cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
 | 
					    cols = list(map(partial(samples_from_file, buffering=buffering, labeled=labeled), filenames))
 | 
				
			||||||
    return Interleaved(*cols, key=lambda s: s.duration)
 | 
					    return Interleaved(*cols, key=lambda s: s.duration)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user