Resolves #3235 - Support for .tar(.gz) targets in bin/data_set_tool.py
This commit is contained in:
parent
a6f40a3b2f
commit
96f37a403d
@ -18,6 +18,7 @@ from mozilla_voice_stt_training.util.downloader import SIMPLE_BAR
|
|||||||
from mozilla_voice_stt_training.util.sample_collections import (
|
from mozilla_voice_stt_training.util.sample_collections import (
|
||||||
CSVWriter,
|
CSVWriter,
|
||||||
DirectSDBWriter,
|
DirectSDBWriter,
|
||||||
|
TarWriter,
|
||||||
samples_from_sources,
|
samples_from_sources,
|
||||||
)
|
)
|
||||||
from mozilla_voice_stt_training.util.augmentations import (
|
from mozilla_voice_stt_training.util.augmentations import (
|
||||||
@ -41,8 +42,12 @@ def build_data_set():
|
|||||||
writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled)
|
writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled)
|
||||||
elif extension == '.sdb':
|
elif extension == '.sdb':
|
||||||
writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled)
|
writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled)
|
||||||
|
elif extension == '.tar':
|
||||||
|
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include)
|
||||||
|
elif extension == '.tgz' or CLI_ARGS.target.lower().endswith('.tar.gz'):
|
||||||
|
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include)
|
||||||
else:
|
else:
|
||||||
print('Unknown extension of target file - has to be either .csv or .sdb')
|
print('Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
with writer:
|
with writer:
|
||||||
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
|
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
|
||||||
@ -71,7 +76,7 @@ def handle_args():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'target',
|
'target',
|
||||||
help='SDB or CSV file to create'
|
help='SDB, CSV or TAR(.gz) file to create'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--audio-type',
|
'--audio-type',
|
||||||
@ -90,7 +95,7 @@ def handle_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--unlabeled',
|
'--unlabeled',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='If to build an SDB with unlabeled (audio only) samples - '
|
help='If to build an data-set with unlabeled (audio only) samples - '
|
||||||
'typically used for building noise augmentation corpora',
|
'typically used for building noise augmentation corpora',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -103,6 +108,11 @@ def handle_args():
|
|||||||
action='append',
|
action='append',
|
||||||
help='Add an augmentation operation',
|
help='Add an augmentation operation',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--include',
|
||||||
|
action='append',
|
||||||
|
help='Adds a file to the root directory of .tar(.gz) targets',
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
import os
|
||||||
|
import io
|
||||||
import csv
|
import csv
|
||||||
import json
|
import json
|
||||||
|
import tarfile
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -377,6 +379,81 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
|
|||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TarWriter: # pylint: disable=too-many-instance-attributes
|
||||||
|
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file"""
|
||||||
|
def __init__(self,
|
||||||
|
tar_filename,
|
||||||
|
gz=False,
|
||||||
|
labeled=True,
|
||||||
|
include=None):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tar_filename : str
|
||||||
|
Path to the tar file to write.
|
||||||
|
gz : bool
|
||||||
|
If to compress tar file with gzip.
|
||||||
|
labeled : bool or None
|
||||||
|
If True: Writes labeled samples (util.sample_collections.LabeledSample) only.
|
||||||
|
If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances.
|
||||||
|
include : str[]
|
||||||
|
List of files to include into tar root.
|
||||||
|
"""
|
||||||
|
self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w')
|
||||||
|
samples_dir = tarfile.TarInfo('samples')
|
||||||
|
samples_dir.type = tarfile.DIRTYPE
|
||||||
|
self.tar.addfile(samples_dir)
|
||||||
|
if include:
|
||||||
|
for include_path in include:
|
||||||
|
self.tar.add(include_path, recursive=False, arcname=Path(include_path).name)
|
||||||
|
fieldnames = ['wav_filename', 'wav_filesize']
|
||||||
|
self.labeled = labeled
|
||||||
|
if labeled:
|
||||||
|
fieldnames.append('transcript')
|
||||||
|
self.csv_file = io.StringIO()
|
||||||
|
self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
|
||||||
|
self.csv_writer.writeheader()
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add(self, sample):
|
||||||
|
sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter)
|
||||||
|
self.counter += 1
|
||||||
|
sample.change_audio_type(AUDIO_TYPE_PCM)
|
||||||
|
sample_file = io.BytesIO()
|
||||||
|
write_wav(sample_file, sample.audio, audio_format=sample.audio_format)
|
||||||
|
sample_size = sample_file.tell()
|
||||||
|
sample_file.seek(0)
|
||||||
|
sample_tar = tarfile.TarInfo(sample_filename)
|
||||||
|
sample_tar.size = sample_size
|
||||||
|
self.tar.addfile(sample_tar, sample_file)
|
||||||
|
row = {
|
||||||
|
'wav_filename': sample_filename,
|
||||||
|
'wav_filesize': sample_size
|
||||||
|
}
|
||||||
|
if self.labeled:
|
||||||
|
row['transcript'] = sample.transcript
|
||||||
|
self.csv_writer.writerow(row)
|
||||||
|
return sample_filename
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.csv_file and self.tar:
|
||||||
|
csv_tar = tarfile.TarInfo('samples.csv')
|
||||||
|
csv_tar.size = self.csv_file.tell()
|
||||||
|
self.csv_file.seek(0)
|
||||||
|
self.tar.addfile(csv_tar, io.BytesIO(self.csv_file.read().encode('utf8')))
|
||||||
|
if self.tar:
|
||||||
|
self.tar.close()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.counter
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
class SampleList:
|
class SampleList:
|
||||||
"""Sample collection base class with samples loaded from a list of in-memory paths."""
|
"""Sample collection base class with samples loaded from a list of in-memory paths."""
|
||||||
def __init__(self, samples, labeled=True, reverse=False):
|
def __init__(self, samples, labeled=True, reverse=False):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user