diff --git a/bin/build_sdb.py b/bin/build_sdb.py deleted file mode 100755 index ac7be8af..00000000 --- a/bin/build_sdb.py +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env python -""" -Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files -Use "python3 build_sdb.py -h" for help -""" -import argparse - -import progressbar - -from deepspeech_training.util.audio import ( - AUDIO_TYPE_PCM, - AUDIO_TYPE_OPUS, - AUDIO_TYPE_WAV, - change_audio_types, -) -from deepspeech_training.util.downloader import SIMPLE_BAR -from deepspeech_training.util.sample_collections import ( - DirectSDBWriter, - samples_from_sources, -) -from deepspeech_training.util.augmentations import ( - parse_augmentations, - apply_sample_augmentations, - SampleAugmentation -) - -AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS} - - -def build_sdb(): - audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] - augmentations = parse_augmentations(CLI_ARGS.augment) - if any(not isinstance(a, SampleAugmentation) for a in augmentations): - print("Warning: Some of the augmentations cannot be applied by this command.") - with DirectSDBWriter( - CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled - ) as sdb_writer: - samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) - num_samples = len(samples) - if augmentations: - samples = apply_sample_augmentations(samples, audio_type=AUDIO_TYPE_PCM, augmentations=augmentations) - bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for sample in bar( - change_audio_types(samples, audio_type=audio_type, bitrate=CLI_ARGS.bitrate, processes=CLI_ARGS.workers) - ): - sdb_writer.add(sample) - - -def handle_args(): - parser = argparse.ArgumentParser( - description="Tool for building Sample Databases (SDB files) " - "from DeepSpeech CSV files and other SDB files" - ) - parser.add_argument( - "sources", - nargs="+", - help="Source CSV and/or SDB files - " - "Note: For getting a correctly ordered target SDB, source SDBs have to have their samples " - "already ordered from shortest to longest.", - ) - parser.add_argument("target", help="SDB file to create") - parser.add_argument( - "--audio-type", - default="opus", - choices=AUDIO_TYPE_LOOKUP.keys(), - help="Audio representation inside target SDB", - ) - parser.add_argument( - "--bitrate", - type=int, - help="Bitrate for lossy compressed SDB samples like in case of --audio-type opus", - ) - parser.add_argument( - "--workers", type=int, default=None, help="Number of encoding SDB workers" - ) - parser.add_argument( - "--unlabeled", - action="store_true", - help="If to build an SDB with unlabeled (audio only) samples - " - "typically used for building noise augmentation corpora", - ) - parser.add_argument( - "--augment", - action='append', - help="Add an augmentation operation", - ) - return parser.parse_args() - - -if __name__ == "__main__": - CLI_ARGS = handle_args() - build_sdb() diff --git a/bin/data_set_tool.py b/bin/data_set_tool.py new file mode 100755 index 00000000..589b4585 --- /dev/null +++ b/bin/data_set_tool.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +''' +Tool for building a combined SDB or CSV sample-set from other sets +Use 'python3 data_set_tool.py -h' for help +''' +import sys +import argparse +import progressbar +from pathlib import Path + +from deepspeech_training.util.audio import ( + AUDIO_TYPE_PCM, + AUDIO_TYPE_OPUS, + AUDIO_TYPE_WAV, + change_audio_types, +) +from deepspeech_training.util.downloader import SIMPLE_BAR +from deepspeech_training.util.sample_collections import ( + CSVWriter, + DirectSDBWriter, + samples_from_sources, +) +from deepspeech_training.util.augmentations import ( + parse_augmentations, + apply_sample_augmentations, + SampleAugmentation +) + +AUDIO_TYPE_LOOKUP = {'wav': AUDIO_TYPE_WAV, 'opus': AUDIO_TYPE_OPUS} + + +def build_data_set(): + audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] + augmentations = parse_augmentations(CLI_ARGS.augment) + if any(not isinstance(a, SampleAugmentation) for a in augmentations): + print('Warning: Some of the specified augmentations will not get applied, as this tool only supports ' + 'overlay, codec, reverb, resample and volume.') + extension = Path(CLI_ARGS.target).suffix.lower() + labeled = not CLI_ARGS.unlabeled + if extension == '.csv': + writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled) + elif extension == '.sdb': + writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled) + else: + print('Unknown extension of target file - has to be either .csv or .sdb') + sys.exit(1) + with writer: + samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) + num_samples = len(samples) + if augmentations: + samples = apply_sample_augmentations(samples, audio_type=AUDIO_TYPE_PCM, augmentations=augmentations) + bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) + for sample in bar(change_audio_types( + samples, + audio_type=audio_type, + bitrate=CLI_ARGS.bitrate, + processes=CLI_ARGS.workers)): + writer.add(sample) + + +def handle_args(): + parser = argparse.ArgumentParser( + description='Tool for building a combined SDB or CSV sample-set from other sets' + ) + parser.add_argument( + 'sources', + nargs='+', + help='Source CSV and/or SDB files - ' + 'Note: For getting a correctly ordered target set, source SDBs have to have their samples ' + 'already ordered from shortest to longest.', + ) + parser.add_argument( + 'target', + help='SDB or CSV file to create' + ) + parser.add_argument( + '--audio-type', + default='opus', + choices=AUDIO_TYPE_LOOKUP.keys(), + help='Audio representation inside target SDB', + ) + parser.add_argument( + '--bitrate', + type=int, + help='Bitrate for lossy compressed SDB samples like in case of --audio-type opus', + ) + parser.add_argument( + '--workers', type=int, default=None, help='Number of encoding SDB workers' + ) + parser.add_argument( + '--unlabeled', + action='store_true', + help='If to build an SDB with unlabeled (audio only) samples - ' + 'typically used for building noise augmentation corpora', + ) + parser.add_argument( + '--absolute-paths', + action='store_true', + help='If to reference samples by their absolute paths when writing CSV files', + ) + parser.add_argument( + '--augment', + action='append', + help='Add an augmentation operation', + ) + return parser.parse_args() + + +if __name__ == '__main__': + CLI_ARGS = handle_args() + build_data_set() diff --git a/bin/play.py b/bin/play.py index e9348c8e..1e8c59ca 100755 --- a/bin/play.py +++ b/bin/play.py @@ -1,7 +1,7 @@ #!/usr/bin/env python """ Tool for playing (and augmenting) single samples or samples from Sample Databases (SDB files) and DeepSpeech CSV files -Use "python3 build_sdb.py -h" for help +Use "python3 play.py -h" for help """ import os diff --git a/bin/run-tc-ldc93s1_checkpoint_sdb.sh b/bin/run-tc-ldc93s1_checkpoint_sdb.sh index 6f5c307f..c811f984 100755 --- a/bin/run-tc-ldc93s1_checkpoint_sdb.sh +++ b/bin/run-tc-ldc93s1_checkpoint_sdb.sh @@ -13,7 +13,7 @@ fi; if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." - python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} + python -u bin/data_set_tool.py ${ldc93s1_csv} ${ldc93s1_sdb} fi; # Force only one visible device because we have a single-sample dataset diff --git a/bin/run-tc-ldc93s1_new_sdb.sh b/bin/run-tc-ldc93s1_new_sdb.sh index 76032aa2..6cd4a450 100755 --- a/bin/run-tc-ldc93s1_new_sdb.sh +++ b/bin/run-tc-ldc93s1_new_sdb.sh @@ -16,7 +16,7 @@ fi; if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." - python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} + python -u bin/data_set_tool.py ${ldc93s1_csv} ${ldc93s1_sdb} fi; # Force only one visible device because we have a single-sample dataset diff --git a/bin/run-tc-ldc93s1_new_sdb_csv.sh b/bin/run-tc-ldc93s1_new_sdb_csv.sh index 1b0f6d3d..ec3e7774 100755 --- a/bin/run-tc-ldc93s1_new_sdb_csv.sh +++ b/bin/run-tc-ldc93s1_new_sdb_csv.sh @@ -16,7 +16,7 @@ fi; if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." - python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} + python -u bin/data_set_tool.py ${ldc93s1_csv} ${ldc93s1_sdb} fi; # Force only one visible device because we have a single-sample dataset diff --git a/doc/TRAINING.rst b/doc/TRAINING.rst index 764088b5..68007457 100644 --- a/doc/TRAINING.rst +++ b/doc/TRAINING.rst @@ -496,7 +496,7 @@ Example training with all augmentations: [...] -The ``bin/play.py`` tool also supports ``--augment`` parameters (for sample domain augmentations) and can be used for experimenting with different configurations. +The ``bin/play.py`` and ``bin/data_set_tool.py`` tools also support ``--augment`` parameters (for sample domain augmentations) and can be used for experimenting with different configurations or creating augmented data sets. Example of playing all samples with reverberation and maximized volume: @@ -510,3 +510,12 @@ Example simulation of the codec augmentation of a wav-file first at the beginnin bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 0.0 test.wav bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 1.0 test.wav + +Example of creating a pre-augmented test set: + +.. code-block:: bash + + bin/data_set_tool.py \ + --augment overlay[source=noise.sdb,layers=1,snr=20~10] \ + --augment resample[rate=12000:8000~4000] \ + test.sdb test-augmented.sdb diff --git a/training/deepspeech_training/util/sample_collections.py b/training/deepspeech_training/util/sample_collections.py index 37210659..b220e1b3 100644 --- a/training/deepspeech_training/util/sample_collections.py +++ b/training/deepspeech_training/util/sample_collections.py @@ -7,7 +7,15 @@ from pathlib import Path from functools import partial from .helpers import MEGABYTE, GIGABYTE, Interleaved -from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES, get_audio_type_from_extension +from .audio import ( + Sample, + DEFAULT_FORMAT, + AUDIO_TYPE_PCM, + AUDIO_TYPE_OPUS, + SERIALIZABLE_AUDIO_TYPES, + get_audio_type_from_extension, + write_wav +) BIG_ENDIAN = 'big' INT_SIZE = 4 @@ -297,6 +305,70 @@ class SDB: # pylint: disable=too-many-instance-attributes self.close() +class CSVWriter: # pylint: disable=too-many-instance-attributes + """Sample collection writer for writing a CSV data-set and all its referenced WAV samples""" + def __init__(self, + csv_filename, + absolute_paths=False, + labeled=True): + """ + Parameters + ---------- + csv_filename : str + Path to the CSV file to write. + Will create a directory (CSV-filename without extension) next to it and fail if it already exists. + absolute_paths : bool + If paths in CSV file should be absolute instead of relative to the CSV file's parent directory. + labeled : bool or None + If True: Writes labeled samples (util.sample_collections.LabeledSample) only. + If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. + """ + self.csv_filename = Path(csv_filename) + self.csv_base_dir = self.csv_filename.parent.resolve().absolute() + self.set_name = self.csv_filename.stem + self.csv_dir = self.csv_base_dir / self.set_name + if self.csv_dir.exists(): + raise RuntimeError('"{}" already existing'.format(self.csv_dir)) + os.mkdir(str(self.csv_dir)) + self.absolute_paths = absolute_paths + fieldnames = ['wav_filename', 'wav_filesize'] + self.labeled = labeled + if labeled: + fieldnames.append('transcript') + self.csv_file = open(csv_filename, 'w', encoding='utf-8', newline='') + self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) + self.csv_writer.writeheader() + self.counter = 0 + + def __enter__(self): + return self + + def add(self, sample): + sample_filename = self.csv_dir / 'sample{0:08d}.wav'.format(self.counter) + self.counter += 1 + sample.change_audio_type(AUDIO_TYPE_PCM) + write_wav(str(sample_filename), sample.audio, audio_format=sample.audio_format) + sample.sample_id = str(sample_filename.relative_to(self.csv_base_dir)) + row = { + 'wav_filename': str(sample_filename.absolute()) if self.absolute_paths else sample.sample_id, + 'wav_filesize': sample_filename.stat().st_size + } + if self.labeled: + row['transcript'] = sample.transcript + self.csv_writer.writerow(row) + return sample.sample_id + + def close(self): + if self.csv_file: + self.csv_file.close() + + def __len__(self): + return self.counter + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + class SampleList: """Sample collection base class with samples loaded from a list of in-memory paths.""" def __init__(self, samples, labeled=True):