diff --git a/bin/data_set_tool.py b/bin/data_set_tool.py index 7b7322c3..4bfe7bd8 100755 --- a/bin/data_set_tool.py +++ b/bin/data_set_tool.py @@ -18,6 +18,7 @@ from mozilla_voice_stt_training.util.downloader import SIMPLE_BAR from mozilla_voice_stt_training.util.sample_collections import ( CSVWriter, DirectSDBWriter, + TarWriter, samples_from_sources, ) from mozilla_voice_stt_training.util.augmentations import ( @@ -41,8 +42,12 @@ def build_data_set(): writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled) elif extension == '.sdb': writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled) + elif extension == '.tar': + writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include) + elif extension == '.tgz' or CLI_ARGS.target.lower().endswith('.tar.gz'): + writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include) else: - print('Unknown extension of target file - has to be either .csv or .sdb') + print('Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz') sys.exit(1) with writer: samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) @@ -71,7 +76,7 @@ def handle_args(): ) parser.add_argument( 'target', - help='SDB or CSV file to create' + help='SDB, CSV or TAR(.gz) file to create' ) parser.add_argument( '--audio-type', @@ -90,7 +95,7 @@ def handle_args(): parser.add_argument( '--unlabeled', action='store_true', - help='If to build an SDB with unlabeled (audio only) samples - ' + help='If to build an data-set with unlabeled (audio only) samples - ' 'typically used for building noise augmentation corpora', ) parser.add_argument( @@ -103,6 +108,11 @@ def handle_args(): action='append', help='Add an augmentation operation', ) + parser.add_argument( + '--include', + action='append', + help='Adds a file to the root directory of .tar(.gz) targets', + ) return parser.parse_args() diff --git 
class TarWriter:  # pylint: disable=too-many-instance-attributes
    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file.

    Archive layout:
        samples/                    directory entry, written first
        samples/sampleNNNNNNNN.wav  one entry per added sample
        <included files>            optional extra files in the archive root
        samples.csv                 index of all samples, written on close()

    The CSV index is buffered in memory and only appended to the archive when the
    writer is closed, so the writer has to be closed (or used as a context manager)
    for the data-set to be complete.
    """
    def __init__(self,
                 tar_filename,
                 gz=False,
                 labeled=True,
                 include=None):
        """
        Parameters
        ----------
        tar_filename : str
            Path to the tar file to write.
        gz : bool
            If to compress tar file with gzip.
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
        include : str[]
            List of files to include into tar root.

        Raises
        ------
        OSError
            If the tar file cannot be opened or one of the include files cannot be read.
        """
        self.labeled = labeled
        self.counter = 0
        self.csv_file = None
        self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
        try:
            samples_dir = tarfile.TarInfo('samples')
            samples_dir.type = tarfile.DIRTYPE
            # TarInfo defaults to mode 0o644; a directory needs the execute bit,
            # otherwise extractors that honour stored modes create a directory
            # that cannot be entered.
            samples_dir.mode = 0o755
            self.tar.addfile(samples_dir)
            for include_path in include or ():
                self.tar.add(include_path, recursive=False, arcname=Path(include_path).name)
        except Exception:
            # Do not leak the open archive handle if set-up fails half-way.
            self.tar.close()
            self.tar = None
            raise
        fieldnames = ['wav_filename', 'wav_filesize']
        if labeled:
            fieldnames.append('transcript')
        self.csv_file = io.StringIO()
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        self.csv_writer.writeheader()

    def __enter__(self):
        return self

    def add(self, sample):
        """Writes one sample as WAV file into the archive and records it in the CSV index.

        Parameters
        ----------
        sample : util.audio.Sample or util.sample_collections.LabeledSample
            Sample to add. It is switched to AUDIO_TYPE_PCM in place before being written.
            Has to provide a `transcript` attribute if the writer is labeled.

        Returns
        -------
        str
            Archive-relative path of the written WAV file ('samples/sampleNNNNNNNN.wav').
        """
        sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter)
        self.counter += 1
        sample.change_audio_type(AUDIO_TYPE_PCM)
        sample_file = io.BytesIO()
        write_wav(sample_file, sample.audio, audio_format=sample.audio_format)
        # Measure the buffer from its end instead of trusting wherever write_wav
        # happened to leave the stream position.
        sample_size = sample_file.seek(0, io.SEEK_END)
        sample_file.seek(0)
        sample_tar = tarfile.TarInfo(sample_filename)
        sample_tar.size = sample_size
        self.tar.addfile(sample_tar, sample_file)
        row = {
            'wav_filename': sample_filename,
            'wav_filesize': sample_size
        }
        if self.labeled:
            row['transcript'] = sample.transcript
        self.csv_writer.writerow(row)
        return sample_filename

    def close(self):
        """Appends the CSV index as 'samples.csv' and closes the archive.

        Safe to call more than once; only the first call has an effect.
        """
        tar, csv_file = self.tar, self.csv_file
        # Drop the handles first so a repeated close() (e.g. explicit close()
        # followed by __exit__) becomes a no-op instead of writing to a closed tar.
        self.tar = None
        self.csv_file = None
        if tar is None:
            return
        try:
            if csv_file is not None:
                # The member size has to be the number of encoded bytes.
                # StringIO.tell() counts characters and would truncate the CSV
                # as soon as a transcript contains non-ASCII text.
                csv_bytes = csv_file.getvalue().encode('utf8')
                csv_tar = tarfile.TarInfo('samples.csv')
                csv_tar.size = len(csv_bytes)
                tar.addfile(csv_tar, io.BytesIO(csv_bytes))
        finally:
            tar.close()

    def __len__(self):
        """Number of samples added so far."""
        return self.counter

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()