Resolves #3235 - Support for .tar(.gz) targets in bin/data_set_tool.py

This commit is contained in:
Tilman Kamp 2020-08-10 14:24:47 +02:00
parent a6f40a3b2f
commit 96f37a403d
2 changed files with 90 additions and 3 deletions

View File

@ -18,6 +18,7 @@ from mozilla_voice_stt_training.util.downloader import SIMPLE_BAR
from mozilla_voice_stt_training.util.sample_collections import ( from mozilla_voice_stt_training.util.sample_collections import (
CSVWriter, CSVWriter,
DirectSDBWriter, DirectSDBWriter,
TarWriter,
samples_from_sources, samples_from_sources,
) )
from mozilla_voice_stt_training.util.augmentations import ( from mozilla_voice_stt_training.util.augmentations import (
@ -41,8 +42,12 @@ def build_data_set():
writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled) writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled)
elif extension == '.sdb': elif extension == '.sdb':
writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled) writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled)
elif extension == '.tar':
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include)
elif extension == '.tgz' or CLI_ARGS.target.lower().endswith('.tar.gz'):
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include)
else: else:
print('Unknown extension of target file - has to be either .csv or .sdb') print('Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz')
sys.exit(1) sys.exit(1)
with writer: with writer:
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
@ -71,7 +76,7 @@ def handle_args():
) )
parser.add_argument( parser.add_argument(
'target', 'target',
help='SDB or CSV file to create' help='SDB, CSV or TAR(.gz) file to create'
) )
parser.add_argument( parser.add_argument(
'--audio-type', '--audio-type',
@ -90,7 +95,7 @@ def handle_args():
parser.add_argument( parser.add_argument(
'--unlabeled', '--unlabeled',
action='store_true', action='store_true',
help='If to build an SDB with unlabeled (audio only) samples - ' help='If to build an data-set with unlabeled (audio only) samples - '
'typically used for building noise augmentation corpora', 'typically used for building noise augmentation corpora',
) )
parser.add_argument( parser.add_argument(
@ -103,6 +108,11 @@ def handle_args():
action='append', action='append',
help='Add an augmentation operation', help='Add an augmentation operation',
) )
parser.add_argument(
'--include',
action='append',
help='Adds a file to the root directory of .tar(.gz) targets',
)
return parser.parse_args() return parser.parse_args()

View File

@ -1,7 +1,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import io
import csv import csv
import json import json
import tarfile
from pathlib import Path from pathlib import Path
from functools import partial from functools import partial
@ -377,6 +379,81 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
self.close() self.close()
class TarWriter:  # pylint: disable=too-many-instance-attributes
    """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file.

    The archive layout is:
      - ``samples/``          directory holding one WAV file per added sample
      - ``samples.csv``       index of all samples (written when the writer is closed)
      - any ``include`` files, stored flat in the archive root
    """
    def __init__(self,
                 tar_filename,
                 gz=False,
                 labeled=True,
                 include=None):
        """
        Parameters
        ----------
        tar_filename : str
            Path to the tar file to write.
        gz : bool
            If to compress tar file with gzip.
        labeled : bool or None
            If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
            If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
        include : str[]
            List of files to include into tar root.
        """
        self.labeled = labeled
        self.counter = 0
        fieldnames = ['wav_filename', 'wav_filesize']
        if labeled:
            fieldnames.append('transcript')
        # The CSV index is collected in memory and appended to the archive on close()
        self.csv_file = io.StringIO()
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        self.csv_writer.writeheader()
        self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
        try:
            samples_dir = tarfile.TarInfo('samples')
            samples_dir.type = tarfile.DIRTYPE
            # TarInfo defaults to mode 0o644; a directory needs the search (x) bit
            # to be usable after extraction.
            samples_dir.mode = 0o755
            self.tar.addfile(samples_dir)
            for include_path in include or []:
                self.tar.add(include_path, recursive=False, arcname=Path(include_path).name)
        except Exception:
            # Do not leak the open archive handle if set-up fails (e.g. missing include file)
            self.tar.close()
            self.tar = None
            raise

    def __enter__(self):
        return self

    def add(self, sample):
        """Converts the sample to PCM, stores it as WAV file in the archive and registers it in the CSV index.

        Parameters
        ----------
        sample : sample object providing ``change_audio_type``, ``audio``, ``audio_format``
            (and ``transcript`` if the writer is labeled)

        Returns
        -------
        str
            Archive-relative path of the written WAV file.
        """
        sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter)
        self.counter += 1
        sample.change_audio_type(AUDIO_TYPE_PCM)
        sample_file = io.BytesIO()
        write_wav(sample_file, sample.audio, audio_format=sample.audio_format)
        sample_size = sample_file.tell()
        sample_file.seek(0)
        sample_tar = tarfile.TarInfo(sample_filename)
        sample_tar.size = sample_size
        self.tar.addfile(sample_tar, sample_file)
        row = {
            'wav_filename': sample_filename,
            'wav_filesize': sample_size
        }
        if self.labeled:
            row['transcript'] = sample.transcript
        self.csv_writer.writerow(row)
        return sample_filename

    def close(self):
        """Writes the CSV index into the archive and closes it. Calling it more than once is a no-op."""
        if self.tar is None:
            return
        try:
            if self.csv_file is not None:
                # The member size has to be the length of the encoded bytes - the StringIO
                # position counts characters and is too small for non-ASCII transcripts.
                csv_bytes = self.csv_file.getvalue().encode('utf8')
                csv_tar = tarfile.TarInfo('samples.csv')
                csv_tar.size = len(csv_bytes)
                self.tar.addfile(csv_tar, io.BytesIO(csv_bytes))
        finally:
            self.tar.close()
            self.tar = None
            self.csv_file = None

    def __len__(self):
        return self.counter

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
class SampleList: class SampleList:
"""Sample collection base class with samples loaded from a list of in-memory paths.""" """Sample collection base class with samples loaded from a list of in-memory paths."""
def __init__(self, samples, labeled=True, reverse=False): def __init__(self, samples, labeled=True, reverse=False):