Resolves #3235 - Support for .tar(.gz) targets in bin/data_set_tool.py
This commit is contained in:
parent
a6f40a3b2f
commit
96f37a403d
@ -18,6 +18,7 @@ from mozilla_voice_stt_training.util.downloader import SIMPLE_BAR
|
||||
from mozilla_voice_stt_training.util.sample_collections import (
|
||||
CSVWriter,
|
||||
DirectSDBWriter,
|
||||
TarWriter,
|
||||
samples_from_sources,
|
||||
)
|
||||
from mozilla_voice_stt_training.util.augmentations import (
|
||||
@ -41,8 +42,12 @@ def build_data_set():
|
||||
writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled)
|
||||
elif extension == '.sdb':
|
||||
writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled)
|
||||
elif extension == '.tar':
|
||||
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include)
|
||||
elif extension == '.tgz' or CLI_ARGS.target.lower().endswith('.tar.gz'):
|
||||
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include)
|
||||
else:
|
||||
print('Unknown extension of target file - has to be either .csv or .sdb')
|
||||
print('Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz')
|
||||
sys.exit(1)
|
||||
with writer:
|
||||
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
|
||||
@ -71,7 +76,7 @@ def handle_args():
|
||||
)
|
||||
parser.add_argument(
|
||||
'target',
|
||||
help='SDB or CSV file to create'
|
||||
help='SDB, CSV or TAR(.gz) file to create'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--audio-type',
|
||||
@ -90,7 +95,7 @@ def handle_args():
|
||||
parser.add_argument(
|
||||
'--unlabeled',
|
||||
action='store_true',
|
||||
help='If to build an SDB with unlabeled (audio only) samples - '
|
||||
help='If to build an data-set with unlabeled (audio only) samples - '
|
||||
'typically used for building noise augmentation corpora',
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -103,6 +108,11 @@ def handle_args():
|
||||
action='append',
|
||||
help='Add an augmentation operation',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--include',
|
||||
action='append',
|
||||
help='Adds a file to the root directory of .tar(.gz) targets',
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import io
|
||||
import csv
|
||||
import json
|
||||
import tarfile
|
||||
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
@ -377,6 +379,81 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
|
||||
self.close()
|
||||
|
||||
|
||||
class TarWriter: # pylint: disable=too-many-instance-attributes
|
||||
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file"""
|
||||
def __init__(self,
|
||||
tar_filename,
|
||||
gz=False,
|
||||
labeled=True,
|
||||
include=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
tar_filename : str
|
||||
Path to the tar file to write.
|
||||
gz : bool
|
||||
If to compress tar file with gzip.
|
||||
labeled : bool or None
|
||||
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
|
||||
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
|
||||
include : str[]
|
||||
List of files to include into tar root.
|
||||
"""
|
||||
self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
|
||||
samples_dir = tarfile.TarInfo('samples')
|
||||
samples_dir.type = tarfile.DIRTYPE
|
||||
self.tar.addfile(samples_dir)
|
||||
if include:
|
||||
for include_path in include:
|
||||
self.tar.add(include_path, recursive=False, arcname=Path(include_path).name)
|
||||
fieldnames = ['wav_filename', 'wav_filesize']
|
||||
self.labeled = labeled
|
||||
if labeled:
|
||||
fieldnames.append('transcript')
|
||||
self.csv_file = io.StringIO()
|
||||
self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
|
||||
self.csv_writer.writeheader()
|
||||
self.counter = 0
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def add(self, sample):
|
||||
sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter)
|
||||
self.counter += 1
|
||||
sample.change_audio_type(AUDIO_TYPE_PCM)
|
||||
sample_file = io.BytesIO()
|
||||
write_wav(sample_file, sample.audio, audio_format=sample.audio_format)
|
||||
sample_size = sample_file.tell()
|
||||
sample_file.seek(0)
|
||||
sample_tar = tarfile.TarInfo(sample_filename)
|
||||
sample_tar.size = sample_size
|
||||
self.tar.addfile(sample_tar, sample_file)
|
||||
row = {
|
||||
'wav_filename': sample_filename,
|
||||
'wav_filesize': sample_size
|
||||
}
|
||||
if self.labeled:
|
||||
row['transcript'] = sample.transcript
|
||||
self.csv_writer.writerow(row)
|
||||
return sample_filename
|
||||
|
||||
def close(self):
|
||||
if self.csv_file and self.tar:
|
||||
csv_tar = tarfile.TarInfo('samples.csv')
|
||||
csv_tar.size = self.csv_file.tell()
|
||||
self.csv_file.seek(0)
|
||||
self.tar.addfile(csv_tar, io.BytesIO(self.csv_file.read().encode('utf8')))
|
||||
if self.tar:
|
||||
self.tar.close()
|
||||
|
||||
def __len__(self):
|
||||
return self.counter
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class SampleList:
|
||||
"""Sample collection base class with samples loaded from a list of in-memory paths."""
|
||||
def __init__(self, samples, labeled=True, reverse=False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user