Merge pull request #3145 from tilmankamp/build_sdb_aug

Resolves #3144 - Add augmentation support to build_sdb.py
This commit is contained in:
Tilman Kamp 2020-07-09 14:35:05 +02:00 committed by GitHub
commit 84f4c15278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 18 additions and 1 deletion

View File

@ -8,6 +8,7 @@ import argparse
import progressbar import progressbar
from deepspeech_training.util.audio import ( from deepspeech_training.util.audio import (
AUDIO_TYPE_PCM,
AUDIO_TYPE_OPUS, AUDIO_TYPE_OPUS,
AUDIO_TYPE_WAV, AUDIO_TYPE_WAV,
change_audio_types, change_audio_types,
@ -17,17 +18,28 @@ from deepspeech_training.util.sample_collections import (
DirectSDBWriter, DirectSDBWriter,
samples_from_sources, samples_from_sources,
) )
from deepspeech_training.util.augmentations import (
parse_augmentations,
apply_sample_augmentations,
SampleAugmentation
)
AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS} AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
def build_sdb(): def build_sdb():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
augmentations = parse_augmentations(CLI_ARGS.augment)
if any(not isinstance(a, SampleAugmentation) for a in augmentations):
print("Warning: Some of the augmentations cannot be applied by this command.")
with DirectSDBWriter( with DirectSDBWriter(
CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
) as sdb_writer: ) as sdb_writer:
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR) num_samples = len(samples)
if augmentations:
samples = apply_sample_augmentations(samples, audio_type=AUDIO_TYPE_PCM, augmentations=augmentations)
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for sample in bar( for sample in bar(
change_audio_types(samples, audio_type=audio_type, bitrate=CLI_ARGS.bitrate, processes=CLI_ARGS.workers) change_audio_types(samples, audio_type=audio_type, bitrate=CLI_ARGS.bitrate, processes=CLI_ARGS.workers)
): ):
@ -67,6 +79,11 @@ def handle_args():
help="If to build an SDB with unlabeled (audio only) samples - " help="If to build an SDB with unlabeled (audio only) samples - "
"typically used for building noise augmentation corpora", "typically used for building noise augmentation corpora",
) )
parser.add_argument(
"--augment",
action='append',
help="Add an augmentation operation",
)
return parser.parse_args() return parser.parse_args()