Resolves #3235 - Support for .tar(.gz) targets in bin/data_set_tool.py

This commit is contained in:
Tilman Kamp 2020-08-10 14:24:47 +02:00
parent a6f40a3b2f
commit 96f37a403d
2 changed files with 90 additions and 3 deletions

View File

@ -18,6 +18,7 @@ from mozilla_voice_stt_training.util.downloader import SIMPLE_BAR
from mozilla_voice_stt_training.util.sample_collections import (
CSVWriter,
DirectSDBWriter,
TarWriter,
samples_from_sources,
)
from mozilla_voice_stt_training.util.augmentations import (
@ -41,8 +42,12 @@ def build_data_set():
writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled)
elif extension == '.sdb':
writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled)
elif extension == '.tar':
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include)
elif extension == '.tgz' or CLI_ARGS.target.lower().endswith('.tar.gz'):
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include)
else:
print('Unknown extension of target file - has to be either .csv or .sdb')
print('Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz')
sys.exit(1)
with writer:
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
@ -71,7 +76,7 @@ def handle_args():
)
parser.add_argument(
'target',
help='SDB or CSV file to create'
help='SDB, CSV or TAR(.gz) file to create'
)
parser.add_argument(
'--audio-type',
@ -90,7 +95,7 @@ def handle_args():
parser.add_argument(
'--unlabeled',
action='store_true',
help='If to build an SDB with unlabeled (audio only) samples - '
help='If to build an data-set with unlabeled (audio only) samples - '
'typically used for building noise augmentation corpora',
)
parser.add_argument(
@ -103,6 +108,11 @@ def handle_args():
action='append',
help='Add an augmentation operation',
)
parser.add_argument(
'--include',
action='append',
help='Adds a file to the root directory of .tar(.gz) targets',
)
return parser.parse_args()

View File

@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
import os
import io
import csv
import json
import tarfile
from pathlib import Path
from functools import partial
@ -377,6 +379,81 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
self.close()
class TarWriter:  # pylint: disable=too-many-instance-attributes
    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file.

    Layout of the resulting archive:
      - ``samples/``       directory entry holding one WAV file per added sample
      - ``samples.csv``    index of all added samples (written on ``close()``)
      - any ``include`` files, stored in the archive root under their base name

    The CSV is buffered in memory and only appended to the archive when the
    writer is closed, so the writer has to be closed (or used as a context
    manager) for the archive to be complete.
    """
    def __init__(self,
                 tar_filename,
                 gz=False,
                 labeled=True,
                 include=None):
        """
        Parameters
        ----------
        tar_filename : str
            Path to the tar file to write.
        gz : bool
            If to compress tar file with gzip.
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
        include : str[]
            List of files to include into tar root.
        """
        self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
        try:
            samples_dir = tarfile.TarInfo('samples')
            samples_dir.type = tarfile.DIRTYPE
            self.tar.addfile(samples_dir)
            if include:
                for include_path in include:
                    # Only the base name is kept so the file lands in the archive root
                    self.tar.add(include_path, recursive=False, arcname=Path(include_path).name)
        except Exception:
            # Do not leak the open archive handle if e.g. an include file is missing
            self.tar.close()
            raise
        fieldnames = ['wav_filename', 'wav_filesize']
        self.labeled = labeled
        if labeled:
            fieldnames.append('transcript')
        # The CSV is collected in memory and added to the archive in close()
        self.csv_file = io.StringIO()
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        self.csv_writer.writeheader()
        self.counter = 0

    def __enter__(self):
        return self

    def add(self, sample):
        """Write one sample as WAV file into the archive and register it in the CSV.

        Note: calls ``sample.change_audio_type(AUDIO_TYPE_PCM)`` on the passed sample
        before it is serialized through ``write_wav``.

        Returns the archive-relative filename of the written WAV file
        (``samples/sampleNNNNNNNN.wav``, numbered in order of addition).
        """
        sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter)
        self.counter += 1
        sample.change_audio_type(AUDIO_TYPE_PCM)
        sample_file = io.BytesIO()
        write_wav(sample_file, sample.audio, audio_format=sample.audio_format)
        # Position after writing is the WAV size in bytes - tar needs it up-front
        sample_size = sample_file.tell()
        sample_file.seek(0)
        sample_tar = tarfile.TarInfo(sample_filename)
        sample_tar.size = sample_size
        self.tar.addfile(sample_tar, sample_file)
        row = {
            'wav_filename': sample_filename,
            'wav_filesize': sample_size
        }
        if self.labeled:
            row['transcript'] = sample.transcript
        self.csv_writer.writerow(row)
        return sample_filename

    def close(self):
        """Append ``samples.csv`` to the archive and close it.

        Safe to call more than once: after the first call the writer holds no
        archive anymore and further calls are no-ops.
        """
        tar, csv_file = self.tar, self.csv_file
        # Detach first, so a repeated close() (e.g. explicit call followed by __exit__)
        # does not try to write into an already closed archive.
        self.tar = None
        self.csv_file = None
        if tar is None:
            return
        try:
            if csv_file is not None:
                # The member size has to be the number of encoded bytes, not the number
                # of characters - they differ as soon as a transcript is non-ASCII.
                csv_data = csv_file.getvalue().encode('utf8')
                csv_tar = tarfile.TarInfo('samples.csv')
                csv_tar.size = len(csv_data)
                tar.addfile(csv_tar, io.BytesIO(csv_data))
        finally:
            tar.close()

    def __len__(self):
        """Number of samples added so far."""
        return self.counter

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
class SampleList:
"""Sample collection base class with samples loaded from a list of in-memory paths."""
def __init__(self, samples, labeled=True, reverse=False):