Live audio augmentation

This commit is contained in:
Tilman Kamp 2020-04-09 14:23:38 +02:00
parent 927859728f
commit c5ceee26dd
14 changed files with 857 additions and 113 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.pyc
*.swp
*.DS_Store
*.egg-info
.pit*
/.run
/werlog.js

View File

@ -15,7 +15,7 @@ from deepspeech_training.util.audio import (
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.sample_collections import (
DirectSDBWriter,
samples_from_files,
samples_from_sources,
)
AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
@ -26,12 +26,10 @@ def build_sdb():
with DirectSDBWriter(
CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
for sample in bar(
change_audio_types(
samples, audio_type=audio_type, processes=CLI_ARGS.workers
)
change_audio_types(samples, audio_type=audio_type, bitrate=CLI_ARGS.bitrate, processes=CLI_ARGS.workers)
):
sdb_writer.add(sample)
@ -55,6 +53,11 @@ def handle_args():
choices=AUDIO_TYPE_LOOKUP.keys(),
help="Audio representation inside target SDB",
)
parser.add_argument(
"--bitrate",
type=int,
help="Bitrate for lossy compressed SDB samples like in case of --audio-type opus",
)
parser.add_argument(
"--workers", type=int, default=None, help="Number of encoding SDB workers"
)

66
bin/compare_samples.py Executable file
View File

@ -0,0 +1,66 @@
#!/usr/bin/env python
"""
Tool for comparing two wav samples
"""
import sys
import argparse
from deepspeech_training.util.audio import AUDIO_TYPE_NP, mean_dbfs
from deepspeech_training.util.sample_collections import load_sample
def fail(message):
    """Print ``message`` to stderr (flushed immediately) and exit with status code 1."""
    print(message, file=sys.stderr, flush=True)
    raise SystemExit(1)
def compare_samples():
    """
    Compares the two samples named by CLI_ARGS.sample1 and CLI_ARGS.sample2.

    Fails (via `fail`) if audio formats or durations differ. Otherwise both samples
    are converted to AUDIO_TYPE_NP and the mean dBFS of their element-wise difference
    is checked against CLI_ARGS.threshold:

    - default mode: fails if the difference is above the threshold
    - with CLI_ARGS.if_differ: fails if the difference is at or below the threshold

    On success a summary is printed to stderr unless CLI_ARGS.no_success_output is set.
    """
    first = load_sample(CLI_ARGS.sample1)
    second = load_sample(CLI_ARGS.sample2)

    # Basic properties have to match before sample data can be subtracted
    if first.audio_format != second.audio_format:
        fail('Samples differ on: audio-format ({} and {})'.format(first.audio_format, second.audio_format))
    if first.duration != second.duration:
        fail('Samples differ on: duration ({} and {})'.format(first.duration, second.duration))

    for sample in (first, second):
        sample.change_audio_type(AUDIO_TYPE_NP)
    diff_dbfs = mean_dbfs(first.audio - second.audio)

    differ_msg = 'Samples differ on: sample data ({:0.2f} dB difference) '.format(diff_dbfs)
    equal_msg = 'Samples are considered equal ({:0.2f} dB difference)'.format(diff_dbfs)

    # Pick failure condition and messages depending on normal or inverse check
    if CLI_ARGS.if_differ:
        failed = diff_dbfs <= CLI_ARGS.threshold
        failure_msg, success_msg = equal_msg, differ_msg
    else:
        failed = diff_dbfs > CLI_ARGS.threshold
        failure_msg, success_msg = differ_msg, equal_msg

    if failed:
        fail(failure_msg)
    if not CLI_ARGS.no_success_output:
        print(success_msg, file=sys.stderr, flush=True)
def handle_args():
    """
    Builds the command line parser of the sample comparison tool and parses sys.argv.

    Returns
    -------
    argparse.Namespace with the attributes
        sample1, sample2 : str
            Filenames of the two samples to compare
        threshold : float
            dB value of the sample delta above which samples count as different (default -60.0)
        if_differ : bool
            Inverse check: succeed on different signals, fail on equal ones
        no_success_output : bool
            Suppress the message printed on success
    """
    parser = argparse.ArgumentParser(
        description="Tool for checking similarity of two samples"
    )
    parser.add_argument("sample1", help="Filename of sample 1 to compare")
    parser.add_argument("sample2", help="Filename of sample 2 to compare")
    parser.add_argument(
        "--threshold",
        type=float,
        default=-60.0,
        help="dB of sample deltas above which they are considered different",
    )
    parser.add_argument(
        "--if-differ",
        action="store_true",
        # Adjacent string literals are concatenated verbatim - the separating space has to be explicit
        help="If to succeed and return status code 0 on different signals and fail on equal ones (inverse check). "
             "This will still fail on different formats or durations.",
    )
    parser.add_argument(
        "--no-success-output",
        action="store_true",
        help="Stay silent on success (if samples are equal or - with --if-differ - samples are not equal)",
    )
    return parser.parse_args()
# Script entry point: the parsed arguments are stored in the module-level
# CLI_ARGS, which compare_samples() reads as a global.
if __name__ == "__main__":
    CLI_ARGS = handle_args()
    compare_samples()

View File

@ -1,54 +1,72 @@
#!/usr/bin/env python
"""
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
Tool for playing (and augmenting) single samples or samples from Sample Databases (SDB files) and DeepSpeech CSV files
Use "python3 play.py -h" for help
"""
import argparse
import random
import os
import sys
import random
import argparse
from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source, prepare_samples
def play_sample(samples, index):
    """
    Plays one sample of a sample collection through simpleaudio.

    Parameters
    ----------
    samples : indexable sample collection supporting len()
        Collection to pick the sample from
    index : int
        Index of the sample to play. Negative values count from the end of the collection.
        Ignored if CLI_ARGS.random is set - a random sample is picked instead.
        If the index is beyond the end of the collection, an error message is printed
        and the process exits with status code 1.
    """
    if index < 0:
        index = len(samples) + index
    if CLI_ARGS.random:
        # random.randint includes both bounds - the upper bound has to be the last valid index
        index = random.randint(0, len(samples) - 1)
    elif index >= len(samples):
        print("No sample with index {}".format(CLI_ARGS.start))
        sys.exit(1)
    sample = samples[index]
    print('Sample "{}"'.format(sample.sample_id))
    if isinstance(sample, LabeledSample):
        print(' "{}"'.format(sample.transcript))
    # simpleaudio is fed with raw PCM data plus the format parameters
    sample.change_audio_type(AUDIO_TYPE_PCM)
    rate, channels, width = sample.audio_format
    wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate)
    play_obj = wave_obj.play()
    play_obj.wait_done()
def play_collection():
samples = samples_from_file(CLI_ARGS.collection, buffering=0)
def get_samples_in_play_order():
ext = os.path.splitext(CLI_ARGS.source)[1].lower()
if ext in LOADABLE_AUDIO_EXTENSIONS:
samples = SampleList([(CLI_ARGS.source, 0)], labeled=False)
else:
samples = samples_from_source(CLI_ARGS.source, buffering=0)
played = 0
index = CLI_ARGS.start
while True:
if 0 <= CLI_ARGS.number <= played:
return
play_sample(samples, index)
if CLI_ARGS.random:
yield samples[random.randint(0, len(samples) - 1)]
elif index < 0:
yield samples[len(samples) + index]
elif index >= len(samples):
print("No sample with index {}".format(CLI_ARGS.start))
sys.exit(1)
else:
yield samples[index]
played += 1
index = (index + 1) % len(samples)
def play_collection():
    """
    Plays (or pipes) the samples delivered by get_samples_in_play_order().

    Samples are passed through prepare_samples with AUDIO_TYPE_PCM as audio type, the
    augmentation specs of CLI_ARGS.augment and CLI_ARGS.clock as fixed clock value.
    Unless CLI_ARGS.quiet is set, sample ID (and transcript of labeled samples) are
    logged to stderr. With CLI_ARGS.pipe the first sample is written as WAV data to
    stdout and the function returns; otherwise each sample is played through simpleaudio.
    """
    prepared = prepare_samples(get_samples_in_play_order(),
                               audio_type=AUDIO_TYPE_PCM,
                               augmentation_specs=CLI_ARGS.augment,
                               process_ahead=0,
                               fixed_clock=CLI_ARGS.clock)
    for current in prepared:
        if not CLI_ARGS.quiet:
            print('Sample "{}"'.format(current.sample_id), file=sys.stderr)
            if isinstance(current, LabeledSample):
                print(' "{}"'.format(current.transcript), file=sys.stderr)
        if not CLI_ARGS.pipe:
            audio_format = current.audio_format
            playback = simpleaudio.WaveObject(current.audio,
                                              audio_format.channels,
                                              audio_format.width,
                                              audio_format.rate).play()
            playback.wait_done()
            continue
        # Pipe mode: only the first sample gets emitted
        current.change_audio_type(AUDIO_TYPE_WAV)
        sys.stdout.buffer.write(current.audio.getvalue())
        return
def handle_args():
parser = argparse.ArgumentParser(
description="Tool for playing samples from Sample Databases (SDB files) "
description="Tool for playing (and augmenting) single samples or samples from Sample Databases (SDB files) "
"and DeepSpeech CSV files"
)
parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
parser.add_argument("source", help="Sample DB, CSV or WAV file to play samples from")
parser.add_argument(
"--start",
type=int,
@ -66,16 +84,40 @@ def handle_args():
action="store_true",
help="If samples should be played in random order",
)
parser.add_argument(
"--augment",
action='append',
help="Add an augmentation operation",
)
parser.add_argument(
"--clock",
type=float,
default=0.5,
help="Simulates clock value used for augmentations during training."
"Ranges from 0.0 (representing parameter start values) to"
"1.0 (representing parameter end values)",
)
parser.add_argument(
"--pipe",
action="store_true",
help="Pipe first sample as wav file to stdout. Forces --number to 1.",
)
parser.add_argument(
"--quiet",
action="store_true",
help="No info logging to console",
)
return parser.parse_args()
if __name__ == "__main__":
try:
import simpleaudio
except ModuleNotFoundError:
print('play.py requires Python package "simpleaudio"')
sys.exit(1)
CLI_ARGS = handle_args()
if not CLI_ARGS.pipe:
try:
import simpleaudio
except ModuleNotFoundError:
print('Unless using the --pipe flag, play.py requires Python package "simpleaudio" for playing samples')
sys.exit(1)
try:
play_collection()
except KeyboardInterrupt:

View File

@ -0,0 +1,66 @@
#!/bin/sh
# Smoke test for the signal augmentations offered by bin/play.py.
# Every augmentation is applied to the LDC93S1 sample and the piped result is
# compared with the original: the test fails if the augmented signal is
# considered equal to the original or if format/duration changed.

set -xe

# "&&" lets a failing cd abort the script (set -e) instead of silently using the current directory
ldc93s1_dir=$(cd data/smoke_test && pwd)
ldc93s1_csv="${ldc93s1_dir}/LDC93S1.csv"
ldc93s1_wav="${ldc93s1_dir}/LDC93S1.wav"
ldc93s1_overlay_csv="${ldc93s1_dir}/LDC93S1_overlay.csv"
ldc93s1_overlay_wav="${ldc93s1_dir}/LDC93S1_reversed.wav"

# Command prefixes - intentionally expanded unquoted to get word splitting
play="python bin/play.py --number 1 --quiet"
compare="python bin/compare_samples.py --no-success-output"

if [ ! -f "${ldc93s1_csv}" ]; then
    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
    python -u bin/import_ldc93s1.py "${ldc93s1_dir}"
fi

if [ ! -f "${ldc93s1_overlay_csv}" ]; then
    echo "Reversing ${ldc93s1_wav} to ${ldc93s1_overlay_wav}."
    sox "${ldc93s1_wav}" "${ldc93s1_overlay_wav}" reverse

    echo "Creating ${ldc93s1_overlay_csv}."
    # Path is passed as argument, so "%" or "\" in it cannot be interpreted as format directives
    printf 'wav_filename\n%s' "${ldc93s1_overlay_wav}" > "${ldc93s1_overlay_csv}"
fi

# Sanity check of the comparison tool itself: original and reversed sample have to differ
if ! $compare --if-differ "${ldc93s1_wav}" "${ldc93s1_overlay_wav}"; then
    echo "Sample comparison tool not working correctly"
    exit 1
fi

# check_augmentation <label> <name> <augmentation-spec>
# Pipes the augmented sample to /tmp/<name>-test.wav and fails the script
# if it does not differ from the original sample.
check_augmentation() {
    label=$1
    output="/tmp/$2-test.wav"
    # The spec is quoted: its "[...]" would otherwise be subject to pathname expansion
    $play "${ldc93s1_wav}" --augment "$3" --pipe >"${output}"
    if ! $compare --if-differ "${ldc93s1_wav}" "${output}"; then
        echo "${label} augmentation had no effect or changed basic sample properties"
        exit 1
    fi
}

check_augmentation Overlay  overlay  "overlay[source=${ldc93s1_overlay_csv},snr=20]"
check_augmentation Reverb   reverb   "reverb[delay=50.0,decay=2.0]"
check_augmentation Gaps     gaps     "gaps[n=10,size=100.0]"
check_augmentation Resample resample "resample[rate=4000]"
check_augmentation Codec    codec    "codec[bitrate=4000]"
check_augmentation Volume   volume   "volume"

View File

@ -263,21 +263,139 @@ UTF-8 mode
DeepSpeech includes a UTF-8 operating mode which can be useful to model languages with very large alphabets, such as Chinese Mandarin. For details on how it works and how to use it, see :ref:`decoder-docs`.
Training with augmentation
^^^^^^^^^^^^^^^^^^^^^^^^^^
Augmentation
^^^^^^^^^^^^
Augmentation is a useful technique for better generalization of machine learning models. Thus, a pre-processing pipeline with various augmentation techniques on raw pcm and spectrogram has been implemented and can be used while training the model. Following are the available augmentation techniques that can be enabled at training time by using the corresponding flags in the command line.
Audio Augmentation
~~~~~~~~~~~~~~~~~~
Audio Augmentation before feature caching
-----------------------------------------
Augmentations that are applied before potential feature caching can be specified through the ``--augment`` multi-flag.
Each sample of the training data is passed through every specified augmentation in the given order. However, whether an augmentation is actually applied to a given sample is decided randomly, based on the augmentation's probability value. For example, a value of ``p=0.1`` would apply the corresponding augmentation to just 10% of all samples. This also means that augmentations are not mutually exclusive on a per-sample basis.
The ``--augment`` flag's value follows a common format (given by an overlay example):
.. code-block:: bash
python3 DeepSpeech.py --augment overlay[p=0.1,source=/path/to/audio.sdb,snr=20.0] ...
Parameters described below as ``<float-range>`` or ``<int-range>`` support the following formats:
* ``<value>``: A constant value
* ``<value>~<r>``: A center value with a randomization radius around it. E.g. ``1.2~0.4`` will result in picking of a random value between 0.8 and 1.6 on each sample augmentation.
* ``<start>:<end>``: The value will range from ``<start>`` at the beginning of an epoch to ``<end>`` at the end of an epoch. E.g. ``-0.2:1.2`` (float) or ``2000:4000`` (int)
* ``<start>:<end>~<r>``: Combination of the latter two cases with a ranging center value. E.g. ``4:6~2`` would at the beginning of an epoch pick values between 2 and 6 and at the end of an epoch between 4 and 8.
The flag ``--augmentations_per_epoch`` specifies how often the whole training set should be repeated per epoch for re-augmenting all its samples. Be aware: This will also multiply the required size of the feature cache (if enabled).
**Overlay augmentation** ``--augment overlay[p=<float>,source=<str>,snr=<float-range>,layers=<int-range>]``
Layers another audio source (multiple times) onto augmented samples.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **source**: path to the sample collection to use for augmenting (*.sdb or *.csv file)
* **snr**: signal to noise ratio in dB - positive values for lowering volume of the overlay in relation to the sample
* **layers**: number of layers of the overlay signal (e.g. 10 layers of speech to get "cocktail-party effect")
**Reverb augmentation** ``--augment reverb[p=<float>,delay=<float-range>,decay=<float-range>]``
Adds reverberation to the augmented samples.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **delay**: time delay in ms for the first signal reflection - higher values are widening the perceived "room"
* **decay**: sound decay in dB per reflection - higher values will result in a less reflective perceived "room"
**Gaps augmentation** ``--augment gaps[p=<float>,n=<int-range>,size=<float-range>]``
Zeros time-intervals within the augmented samples.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **n**: number of intervals to zero
* **size**: interval durations in ms
**Resample augmentation** ``--augment resample[p=<float>,rate=<int-range>]``
Re-samples augmented samples to another sample-rate and back.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **rate**: sample-rate to re-sample to
**Codec augmentation** ``--augment codec[p=<float>,bitrate=<int-range>]``
Compresses and re-expands augmented samples using the lossy Opus audio codec.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **bitrate**: bitrate used during compression
**Volume augmentation** ``--augment volume[p=<float>,dbfs=<float-range>]``
Measures and levels augmented samples to a target dBFS value.
* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
* **dbfs** : target volume in dBFS (default value of 3.0103 will normalize min and max amplitudes to -1.0/1.0)
Example training with all augmentations:
.. code-block:: bash
python -u DeepSpeech.py \
--train_files "train.sdb" \
--augmentations_per_epoch 10 \
--augment overlay[p=0.5,source=noise.sdb,layers=1,snr=50:20~10] \
--augment overlay[p=0.2,source=voices.sdb,layers=10:6,snr=50:20~10] \
--augment reverb[p=0.1,delay=50.0~30.0,decay=10.0:2.0~1.0] \
--augment gaps[p=0.05,n=1:3~2,size=10:100] \
--augment resample[p=0.1,rate=12000:8000~4000] \
--augment codec[p=0.1,bitrate=48000:16000] \
--augment volume[p=0.1,dbfs=-10:-40] \
[...]
The ``bin/play.py`` tool also supports ``--augment`` parameters and can be used for experimenting with different configurations.
Example of playing all samples with reverberation and maximized volume:
.. code-block:: bash
bin/play.py --augment reverb[p=0.1,delay=50.0,decay=2.0] --augment volume --random test.sdb
Example simulation of the codec augmentation of a wav-file first at the beginning and then at the end of an epoch:
.. code-block:: bash
bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 0.0 test.wav
bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 1.0 test.wav
Audio Augmentation after feature caching
----------------------------------------
#. **Standard deviation for Gaussian additive noise:** ``--data_aug_features_additive``
#. **Standard deviation for Normal distribution around 1 for multiplicative noise:** ``--data_aug_features_multiplicative``
#. **Standard deviation for speeding-up tempo. If Standard deviation is 0, this augmentation is not performed:** ``--augmentation_speed_up_std``
Spectrogram Augmentation
~~~~~~~~~~~~~~~~~~~~~~~~
Spectrogram Augmentation after feature caching
----------------------------------------------
Inspired by Google Paper on `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition <https://arxiv.org/abs/1904.08779>`_

View File

@ -39,6 +39,8 @@ echo "Moving ${sample_name} to LDC93S1.wav"
mv "${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/${sample_name}" "${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/LDC93S1.wav"
pushd ${HOME}/DeepSpeech/ds/
# Testing signal augmentations
time ./bin/run-tc-signal_augmentations.sh
# Run twice to test preprocessed features
time ./bin/run-tc-ldc93s1_new.sh 249 "${sample_rate}"
time ./bin/run-tc-ldc93s1_new.sh 1 "${sample_rate}"

View File

@ -424,6 +424,8 @@ def train():
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
repetitions=FLAGS.augmentations_per_epoch,
augmentation_specs=FLAGS.augment,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,

View File

@ -1,14 +1,16 @@
import os
import io
import wave
import math
import tempfile
import collections
import numpy as np
from .helpers import LimitingPool
from .helpers import LimitingPool, np_capped_squares
from collections import namedtuple
AudioFormat = namedtuple('AudioFormat', 'rate channels width')
dBFS = namedtuple('dBFS', 'mean max')
DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1
@ -20,6 +22,7 @@ AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm'
AUDIO_TYPE_WAV = 'audio/wav'
AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus'
SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS]
LOADABLE_AUDIO_EXTENSIONS = {'.wav': AUDIO_TYPE_WAV}
OPUS_PCM_LEN_SIZE = 4
OPUS_RATE_SIZE = 4
@ -81,7 +84,7 @@ class Sample:
else:
raise ValueError('Unsupported audio type: {}'.format(self.audio_type))
def change_audio_type(self, new_audio_type):
def change_audio_type(self, new_audio_type, bitrate=None):
"""
In-place conversion of audio data into a different representation.
@ -89,6 +92,8 @@ class Sample:
----------
new_audio_type : str
New audio-type - see `__init__`.
bitrate : int
Bitrate to use in case of converting to a lossy audio-type.
"""
if self.audio_type == new_audio_type:
return
@ -104,7 +109,7 @@ class Sample:
elif new_audio_type in SERIALIZABLE_AUDIO_TYPES:
self.change_audio_type(AUDIO_TYPE_PCM)
audio_bytes = io.BytesIO()
write_audio(new_audio_type, audio_bytes, self.audio, audio_format=self.audio_format)
write_audio(new_audio_type, audio_bytes, self.audio, audio_format=self.audio_format, bitrate=bitrate)
audio_bytes.seek(0)
self.audio = audio_bytes
else:
@ -114,14 +119,20 @@ class Sample:
def _change_audio_type(sample_and_audio_type):
sample, audio_type = sample_and_audio_type
sample.change_audio_type(audio_type)
sample, audio_type, bitrate = sample_and_audio_type
sample.change_audio_type(audio_type, bitrate=bitrate)
return sample
def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, processes=None, process_ahead=None):
def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None):
with LimitingPool(processes=processes, process_ahead=process_ahead) as pool:
yield from pool.imap(_change_audio_type, map(lambda s: (s, audio_type), samples))
yield from pool.imap(_change_audio_type, map(lambda s: (s, audio_type, bitrate), samples))
def get_audio_type_from_extension(ext):
    """Returns the audio-type registered for file extension `ext` in LOADABLE_AUDIO_EXTENSIONS or None if there is none."""
    return LOADABLE_AUDIO_EXTENSIONS.get(ext)
def read_audio_format_from_wav_file(wav_file):
@ -264,10 +275,12 @@ def get_opus_frame_size(rate):
return 60 * rate // 1000
def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT):
def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT, bitrate=None):
frame_size = get_opus_frame_size(audio_format.rate)
import opuslib # pylint: disable=import-outside-toplevel
encoder = opuslib.Encoder(audio_format.rate, audio_format.channels, 'audio')
if bitrate is not None:
encoder.bitrate = bitrate
chunk_size = frame_size * audio_format.channels * audio_format.width
opus_file.write(pack_number(len(audio_data), OPUS_PCM_LEN_SIZE))
opus_file.write(pack_number(audio_format.rate, OPUS_RATE_SIZE))
@ -277,7 +290,7 @@ def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT):
chunk = audio_data[i:i + chunk_size]
# Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer
if len(chunk) < chunk_size:
chunk = chunk + bytearray(chunk_size - len(chunk))
chunk = chunk + b'\0' * (chunk_size - len(chunk))
encoded = encoder.encode(chunk, frame_size)
opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE))
opus_file.write(encoded)
@ -304,7 +317,7 @@ def read_opus(opus_file):
decoded = decoder.decode(chunk, frame_size)
audio_data.extend(decoded)
audio_data = audio_data[:pcm_buffer_size]
return audio_format, audio_data
return audio_format, bytes(audio_data)
def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
@ -331,11 +344,11 @@ def read_audio(audio_type, audio_file):
raise ValueError('Unsupported audio type: {}'.format(audio_type))
def write_audio(audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT):
def write_audio(audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, bitrate=None):
if audio_type == AUDIO_TYPE_WAV:
return write_wav(audio_file, pcm_data, audio_format=audio_format)
if audio_type == AUDIO_TYPE_OPUS:
return write_opus(audio_file, pcm_data, audio_format=audio_format)
return write_opus(audio_file, pcm_data, audio_format=audio_format, bitrate=bitrate)
raise ValueError('Unsupported audio type: {}'.format(audio_type))
@ -376,6 +389,26 @@ def np_to_pcm(np_data, audio_format=DEFAULT_FORMAT):
assert audio_format.channels == 1 # only mono supported for now
dtype = get_dtype(audio_format)
np_data = np_data.squeeze()
np_data *= dtype.max
np_data = np_data * np.iinfo(dtype).max
np_data = np_data.astype(dtype)
return bytearray(np_data.tobytes())
return np_data.tobytes()
def rms_to_dbfs(rms):
    """
    Converts an RMS amplitude into a dBFS value.

    The amplitude is floored at 1e-16 (keeping log10 defined for zero or negative input)
    and an offset of 3.0103 dB (about 20 * log10(sqrt(2))) is added to the result.
    """
    floored_rms = max(1e-16, rms)
    decibels = 20.0 * math.log10(floored_rms)
    return decibels + 3.0103
def mean_dbfs(sample_data):
    """Returns the dBFS value (see `rms_to_dbfs`) of the root of the mean of the capped squares of `sample_data`."""
    mean_square = np.mean(np_capped_squares(sample_data))
    root_mean_square = math.sqrt(mean_square)
    return rms_to_dbfs(root_mean_square)
def max_dbfs(sample_data):
    """Returns the dBFS value (see `rms_to_dbfs`) of the largest absolute amplitude within `sample_data`."""
    lowest = abs(np.min(sample_data))
    highest = abs(np.max(sample_data))
    peak = max(lowest, highest)
    return rms_to_dbfs(peak)
def gain_db_to_ratio(gain_db):
    """Converts a gain in dB into the corresponding amplitude factor (10 to the power of gain_db / 20)."""
    exponent = gain_db / 20.0
    return math.pow(10.0, exponent)
def normalize_audio(sample_data, dbfs=3.0103):
    """
    Scales `sample_data` so that its peak level (as measured by `max_dbfs`) reaches `dbfs`
    and limits the resulting values to the interval [-1.0, 1.0].
    """
    gain = gain_db_to_ratio(dbfs - max_dbfs(sample_data))
    amplified = sample_data * gain
    upper_limited = np.minimum(amplified, 1.0)
    return np.maximum(upper_limited, -1.0)

View File

@ -12,8 +12,8 @@ from .config import Config
from .text import text_to_char_array
from .flags import FLAGS
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
from .audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
from .sample_collections import samples_from_files
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
from .sample_collections import samples_from_sources, prepare_samples
from .helpers import remember_exception, MEGABYTE
@ -109,6 +109,8 @@ def to_sparse_tuple(sequence):
def create_dataset(sources,
batch_size,
repetitions=1,
augmentation_specs=None,
enable_cache=False,
cache_path=None,
train_phase=False,
@ -116,10 +118,13 @@ def create_dataset(sources,
process_ahead=None,
buffering=1 * MEGABYTE):
def generate_values():
samples = samples_from_files(sources, buffering=buffering, labeled=True)
for sample in change_audio_types(samples,
AUDIO_TYPE_NP,
process_ahead=2 * batch_size if process_ahead is None else process_ahead):
samples = samples_from_sources(sources, buffering=buffering, labeled=True)
samples = prepare_samples(samples,
repetitions=repetitions,
augmentation_specs=augmentation_specs,
buffering=buffering,
process_ahead=2 * batch_size if process_ahead is None else process_ahead)
for sample in samples:
transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
transcript = to_sparse_tuple(transcript)
yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript

View File

@ -26,6 +26,9 @@ def create_flags():
# Data Augmentation
# ================
f.DEFINE_multi_string('augment', None, 'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"')
f.DEFINE_integer('augmentations_per_epoch', 1, 'how often the train set should be repeated and re-augmented per epoch')
f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise')
f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise')

View File

@ -1,10 +1,14 @@
import os
import sys
import time
import math
import heapq
import semver
import random
import numpy as np
from multiprocessing import Pool
from collections import namedtuple
KILO = 1024
KILOBYTE = 1 * KILO
@ -13,6 +17,8 @@ GIGABYTE = KILO * MEGABYTE
TERABYTE = KILO * GIGABYTE
SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE}
ValueRange = namedtuple('ValueRange', 'start end r')
def parse_file_size(file_size):
file_size = file_size.lower().strip()
@ -79,11 +85,11 @@ class LimitingPool:
"""Limits unbound ahead-processing of multiprocessing.Pool's imap method
before items get consumed by the iteration caller.
This prevents OOM issues in situations where items represent larger memory allocations."""
def __init__(self, processes=None, process_ahead=None, sleeping_for=0.1):
def __init__(self, processes=None, initializer=None, initargs=None, process_ahead=None, sleeping_for=0.1):
self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead
self.sleeping_for = sleeping_for
self.processed = 0
self.pool = Pool(processes=processes)
self.pool = Pool(processes=processes, initializer=initializer, initargs=initargs)
def __enter__(self):
return self
@ -100,6 +106,9 @@ class LimitingPool:
self.processed -= 1
yield obj
def terminate(self):
    """Calls terminate() on the wrapped multiprocessing pool (self.pool)."""
    self.pool.terminate()
def __exit__(self, exc_type, exc_value, traceback):
self.pool.close()
@ -128,3 +137,54 @@ def remember_exception(iterable, exception_box=None):
except Exception as ex: # pylint: disable = broad-except
exception_box.exception = ex
return iterable if exception_box is None else do_iterate
def get_value_range(value, target_type):
    """
    Converts a value range specification into a ValueRange(start, end, r) tuple.

    Parameters
    ----------
    value : str, tuple or scalar
        str: "<value>", "<value>~<r>", "<start>:<end>" or "<start>:<end>~<r>"
        tuple: (start, end) or (start, end, r)
        Anything else is used as both start and end value.
    target_type : callable
        Conversion applied to the parsed components (e.g. int or float)

    Returns
    -------
    ValueRange with start and end converted by target_type. If no radius is given,
    r is target_type(0) for string specifications and 0 otherwise.

    Raises
    ------
    ValueError
        If a string has more than one "~" or ":" separator, or a tuple has neither 2 nor 3 elements.
        Conversion errors of target_type are passed through.
    """
    if isinstance(value, str):
        r = target_type(0)
        # An optional "~<r>" suffix carries the randomization radius
        parts = value.split('~')
        if len(parts) == 2:
            value = parts[0]
            r = target_type(parts[1])
        elif len(parts) > 2:
            raise ValueError('Cannot parse value range')
        # Remainder is either "<value>" or "<start>:<end>"
        parts = value.split(':')
        if len(parts) == 1:
            parts.append(parts[0])
        elif len(parts) > 2:
            raise ValueError('Cannot parse value range')
        return ValueRange(target_type(parts[0]), target_type(parts[1]), r)
    if isinstance(value, tuple):
        if len(value) == 2:
            return ValueRange(target_type(value[0]), target_type(value[1]), 0)
        if len(value) == 3:
            # Third element is the radius (not the end value again)
            return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2]))
        raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
    return ValueRange(target_type(value), target_type(value), 0)
def int_range(value):
    """Converts `value` into a ValueRange of int components - see `get_value_range`."""
    return get_value_range(value, target_type=int)
def float_range(value):
    """Converts `value` into a ValueRange of float components - see `get_value_range`."""
    return get_value_range(value, target_type=float)
def pick_value_from_range(value_range, clock=None):
    """
    Picks a value from a value range (object with attributes start, end and r).

    The center value is interpolated between start (clock 0.0) and end (clock 1.0).
    The result is drawn uniformly from the interval of radius r around that center.
    If `clock` is None, a random clock value is used; otherwise it is clamped to [0.0, 1.0].
    Results get rounded if the range's start value is an int.
    """
    if clock is None:
        clock = random.random()
    else:
        clock = max(0.0, min(1.0, float(clock)))
    center = value_range.start + clock * (value_range.end - value_range.start)
    picked = random.uniform(center - value_range.r, center + value_range.r)
    if isinstance(value_range.start, int):
        return round(picked)
    return picked
def call_if_exists(o, name, *args, **kwargs):
    """
    Calls attribute `name` of object `o` with the given arguments - if it exists and is callable.
    The result of the call is discarded; the function always returns None.
    """
    candidate = getattr(o, name, None)
    if not callable(candidate):
        return
    candidate(*args, **kwargs)
def np_capped_squares(data):
    """
    Returns the element-wise squares of `data`.

    Values are first limited to +/- the square root of the largest finite value of
    data's floating point dtype, which prevents overflows during squaring.
    """
    limit = math.sqrt(np.finfo(data.dtype).max)
    lower_capped = np.maximum(data, -limit)
    capped = np.minimum(lower_capped, limit)
    return np.square(capped)

View File

@ -2,11 +2,14 @@
import os
import csv
import json
import random
from pathlib import Path
from functools import partial
from .helpers import MEGABYTE, GIGABYTE, Interleaved
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
from .signal_augmentations import parse_augmentation
from .helpers import MEGABYTE, GIGABYTE, Interleaved, LimitingPool, call_if_exists
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_OPUS, AUDIO_TYPE_NP, SERIALIZABLE_AUDIO_TYPES, get_audio_type_from_extension
BIG_ENDIAN = 'big'
INT_SIZE = 4
@ -47,9 +50,42 @@ class LabeledSample(Sample):
self.transcript = transcript
def load_sample(filename, label=None):
    """
    Reads an audio-file and wraps its raw content into a sample object

    Parameters
    ----------
    filename : str
        Filename of the audio-file to load. Its (lower-cased) extension determines the audio-type.
    label : str
        Label (transcript) of the sample.
        If None: return util.audio.Sample instance
        Otherwise: return util.sample_collections.LabeledSample instance

    Returns
    -------
    util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance

    Raises
    ------
    ValueError
        If there is no known audio-type for the file's extension
    """
    extension = os.path.splitext(filename)[1].lower()
    audio_type = get_audio_type_from_extension(extension)
    if audio_type is None:
        raise ValueError('Unknown audio type extension "{}"'.format(extension))
    with open(filename, 'rb') as audio_file:
        raw_audio = audio_file.read()
    if label is not None:
        return LabeledSample(audio_type, raw_audio, label, sample_id=filename)
    return Sample(audio_type, raw_audio, sample_id=filename)
class DirectSDBWriter:
"""Sample collection writer for creating a Sample DB (SDB) file"""
def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None, labeled=True):
def __init__(self,
sdb_filename,
buffering=BUFFER_SIZE,
audio_type=AUDIO_TYPE_OPUS,
bitrate=None,
id_prefix=None,
labeled=True):
"""
Parameters
----------
@ -59,6 +95,8 @@ class DirectSDBWriter:
Write-buffer size to use while writing the SDB file
audio_type : str
See util.audio.Sample.__init__ .
bitrate : int
Bitrate for sample-compression in case of lossy audio_type (e.g. AUDIO_TYPE_OPUS)
id_prefix : str
Prefix for IDs of written samples - defaults to sdb_filename
labeled : bool or None
@ -71,6 +109,7 @@ class DirectSDBWriter:
if audio_type not in SERIALIZABLE_AUDIO_TYPES:
raise ValueError('Audio type "{}" not supported'.format(audio_type))
self.audio_type = audio_type
self.bitrate = bitrate
self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
self.offsets = []
self.num_samples = 0
@ -100,7 +139,7 @@ class DirectSDBWriter:
def add(self, sample):
def to_bytes(n):
return n.to_bytes(INT_SIZE, BIG_ENDIAN)
sample.change_audio_type(self.audio_type)
sample.change_audio_type(self.audio_type, bitrate=self.bitrate)
opus = sample.audio.getbuffer()
opus_len = to_bytes(len(opus))
if self.labeled:
@ -260,7 +299,36 @@ class SDB: # pylint: disable=too-many-instance-attributes
self.close()
class CSV:
class SampleList:
    """Sample collection based on a list of audio-file specifications.
    Specifications are ordered by their file-size entry; samples are loaded from file on access."""
    def __init__(self, samples, labeled=True):
        """
        Parameters
        ----------
        samples : iterable of tuples of the form (sample_filename, filesize [, transcript])
            File-size is used for ordering the samples; transcript has to be provided if labeled=True
        labeled : bool or None
            If True: Accessed samples are loaded with the transcript of their specification as label.
            If False (or None): Samples are loaded without label.
        """
        self.labeled = labeled
        # Stable sort by file-size (second tuple entry)
        self.samples = sorted(samples, key=lambda spec: spec[1])

    def __getitem__(self, i):
        spec = self.samples[i]
        transcript = spec[2] if self.labeled else None
        return load_sample(spec[0], label=transcript)

    def __iter__(self):
        for position in range(len(self.samples)):
            yield self[position]

    def __len__(self):
        return len(self.samples)
class CSV(SampleList):
"""Sample collection reader for reading a DeepSpeech CSV file
Automatically orders samples by CSV column wav_filesize (if available)."""
def __init__(self, csv_filename, labeled=None):
@ -275,16 +343,14 @@ class CSV:
If None: Automatically determines if CSV file has a transcript column
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
"""
self.csv_filename = csv_filename
self.labeled = labeled
self.rows = []
rows = []
csv_dir = Path(csv_filename).parent
with open(csv_filename, 'r', encoding='utf8') as csv_file:
reader = csv.DictReader(csv_file)
if 'transcript' in reader.fieldnames:
if self.labeled is None:
self.labeled = True
elif self.labeled:
if labeled is None:
labeled = True
elif labeled:
raise RuntimeError('No transcript data (missing CSV column)')
for row in reader:
wav_filename = Path(row['wav_filename'])
@ -292,36 +358,20 @@ class CSV:
wav_filename = csv_dir / wav_filename
wav_filename = str(wav_filename)
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0
if self.labeled:
self.rows.append((wav_filename, wav_filesize, row['transcript']))
if labeled:
rows.append((wav_filename, wav_filesize, row['transcript']))
else:
self.rows.append((wav_filename, wav_filesize))
self.rows.sort(key=lambda r: r[1])
def __getitem__(self, i):
row = self.rows[i]
wav_filename = row[0]
with open(wav_filename, 'rb') as wav_file:
if self.labeled:
return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), row[2], sample_id=wav_filename)
return Sample(AUDIO_TYPE_WAV, wav_file.read(), sample_id=wav_filename)
def __iter__(self):
for i in range(len(self.rows)):
yield self[i]
def __len__(self):
return len(self.rows)
rows.append((wav_filename, wav_filesize))
super(CSV, self).__init__(rows, labeled=labeled)
def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None):
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None):
"""
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
loaded from a sample source file.
Loads samples from a sample source file.
Parameters
----------
filename : str
sample_source : str
Path to the sample source file (SDB or CSV)
buffering : int
Read-buffer size to use while reading files
@ -330,23 +380,27 @@ def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None):
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
If None: Automatically determines if source provides transcripts
(reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances).
Returns
-------
iterable of util.sample_collections.LabeledSample or util.audio.Sample instances supporting len.
"""
ext = os.path.splitext(filename)[1].lower()
ext = os.path.splitext(sample_source)[1].lower()
if ext == '.sdb':
return SDB(filename, buffering=buffering, labeled=labeled)
return SDB(sample_source, buffering=buffering, labeled=labeled)
if ext == '.csv':
return CSV(filename, labeled=labeled)
return CSV(sample_source, labeled=labeled)
raise ValueError('Unknown file type: "{}"'.format(ext))
def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None):
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None):
"""
Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances
loaded from a collection of sample source files.
Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
keep default sample order from shortest to longest.
Parameters
----------
filenames : list of str
sample_sources : list of str
Paths to sample source files (SDBs or CSVs)
buffering : int
Read-buffer size to use while reading files
@ -355,11 +409,100 @@ def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None):
If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances.
If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and
util.audio.Sample instances from sources with no transcripts.
Returns
-------
iterable of util.sample_collections.LabeledSample (labeled=True) or util.audio.Sample (labeled=False) supporting len
"""
filenames = list(filenames)
if len(filenames) == 0:
sample_sources = list(sample_sources)
if len(sample_sources) == 0:
raise ValueError('No files')
if len(filenames) == 1:
return samples_from_file(filenames[0], buffering=buffering, labeled=labeled)
cols = list(map(partial(samples_from_file, buffering=buffering, labeled=labeled), filenames))
if len(sample_sources) == 1:
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled)
cols = list(map(partial(samples_from_source, buffering=buffering, labeled=labeled), sample_sources))
return Interleaved(*cols, key=lambda s: s.duration)
class PreparationContext:
    """Bundle of the settings used for preparing samples:
    the audio type samples get converted to and the augmentations to apply before that."""
    def __init__(self, target_audio_type, augmentations):
        self.augmentations = augmentations
        self.target_audio_type = target_audio_type
# Module-wide preparation context; stays None until _init_preparation_worker assigns it
PREPARATION_CONTEXT = None


def _init_preparation_worker(preparation_context):
    """Pool initializer: stores the given context in the module-level PREPARATION_CONTEXT,
    where _prepare_sample picks it up if no explicit context is passed."""
    global PREPARATION_CONTEXT  # pylint: disable=global-statement
    PREPARATION_CONTEXT = preparation_context
def _prepare_sample(timed_sample, context=None):
    """Applies the context's augmentations to a sample and converts it to the target audio type.

    Parameters
    ----------
    timed_sample : tuple of the form (sample, clock)
        Sample to prepare together with the clock value that is handed to each applied augmentation.
    context : PreparationContext or None
        Context to use; if None, the module-level PREPARATION_CONTEXT is used instead.

    Returns
    -------
    The (in-place modified) sample.
    """
    if context is None:
        context = PREPARATION_CONTEXT
    sample, clock = timed_sample
    for augmentation in context.augmentations:
        # one random draw per augmentation (in list order) decides if it gets applied
        selected = random.random() < augmentation.probability
        if selected:
            augmentation.apply(sample, clock)
    sample.change_audio_type(new_audio_type=context.target_audio_type)
    return sample
def prepare_samples(samples,
                    audio_type=AUDIO_TYPE_NP,
                    augmentation_specs=None,
                    buffering=BUFFER_SIZE,
                    process_ahead=None,
                    repetitions=1,
                    fixed_clock=None):
    """
    Generator that prepares samples for training: augmentations are applied (in parallel and ahead of
    consumption, unless process_ahead is 0) and every sample is converted to the requested audio-type.

    Parameters
    ----------
    samples : Sample enumeration
        Typically produced by samples_from_sources. Has to support len if fixed_clock is None.
    audio_type : str
        Target audio-type to convert samples to. See util.audio.Sample.__init__ .
    augmentation_specs : list of str
        Augmentation specifications like ["reverb[delay=20.0,decay=-20]", "volume"]. See TRAINING.rst.
    buffering : int
        Read-buffer size; handed to the `start` method of augmentations that provide one.
    process_ahead : int
        Number of samples to pre-process ahead of time. If 0, samples get prepared one by one
        within the calling process.
    repetitions : int
        How often the input sample enumeration should get repeated for being re-augmented.
    fixed_clock : float
        Sets the internal clock to a value between 0.0 (beginning of epoch) and 1.0 (end of epoch).
        Setting this to a number is used for simulating augmentations at a certain epoch-time.
        If kept at None (default), the internal clock will run regularly from 0.0 to 1.0,
        hence preparing them for training.

    Returns
    -------
    iterable of util.sample_collections.LabeledSample or util.audio.Sample
    """
    def timed_samples():
        # pairs every sample with its clock value: the fixed one or its relative position in the whole run
        for repetition in range(repetitions):
            for sample_index, sample in enumerate(samples):
                if fixed_clock is not None:
                    clock = fixed_clock
                else:
                    n_samples = len(samples)
                    clock = (repetition * n_samples + sample_index) / (repetitions * n_samples)
                yield sample, clock

    if augmentation_specs is None:
        augmentations = []
    else:
        augmentations = [parse_augmentation(spec) for spec in augmentation_specs]
    try:
        for augmentation in augmentations:
            call_if_exists(augmentation, 'start', buffering=buffering)
        context = PreparationContext(audio_type, augmentations)
        if process_ahead == 0:
            # no pool: prepare within this process, passing the context explicitly
            for timed_sample in timed_samples():
                yield _prepare_sample(timed_sample, context=context)
        else:
            pool_args = dict(process_ahead=process_ahead,
                             initializer=_init_preparation_worker,
                             initargs=(context,))
            with LimitingPool(**pool_args) as pool:
                yield from pool.imap(_prepare_sample, timed_samples())
    finally:
        # augmentations get stopped however the generator ends
        for augmentation in augmentations:
            call_if_exists(augmentation, 'stop')

View File

@ -0,0 +1,200 @@
import os
import re
import math
import random
import numpy as np
from multiprocessing import Queue, Process
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
from .helpers import int_range, float_range, pick_value_from_range, MEGABYTE
# Matches a complete augmentation specification: a lower-case name (group 1), optionally followed by
# a bracketed parameter list whose raw content is captured by group 3, e.g. "reverb[delay=20.0,decay=-20]"
SPEC_PARSER = re.compile(r'^([a-z]+)(\[(.*)\])?$')
# Default value of the `buffering` parameters in this module (one MEGABYTE)
BUFFER_SIZE = 1 * MEGABYTE
def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
    """
    Endlessly feeds the (unlabeled) samples of a sample source into a queue, restarting from the
    beginning of the source each time its end is reached. Used as target of the enqueue process
    started by Overlay.start; it only ends by being terminated or by raising.

    Parameters
    ----------
    sample_source : str
        Path to the sample source file, as accepted by samples_from_source.
    queue : multiprocessing.Queue
        Queue to put the samples into. Putting blocks while the queue is full.
    buffering : int
        Read-buffer size to use while reading the sample source.

    Raises
    ------
    ValueError
        If a pass over the sample source provided no sample at all.
    """
    # preventing cyclic import problems
    from .sample_collections import samples_from_source  # pylint: disable=import-outside-toplevel
    samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
    while True:
        enqueued_any = False
        for sample in samples:
            queue.put(sample)
            enqueued_any = True
        if not enqueued_any:
            # an empty source would otherwise keep this loop spinning forever without enqueuing anything
            raise ValueError('Overlay sample source "{}" provides no samples'.format(sample_source))
class Overlay:
    """See "Overlay augmentation" in TRAINING.rst"""
    def __init__(self, source, p=1.0, snr=3.0, layers=1):
        """
        Parameters
        ----------
        source : str
            Path to the sample source file providing the overlay samples.
        p : float or str
            Probability of applying this augmentation to a sample.
        snr : range specification
            Parsed by float_range; value to pick the signal-to-noise ratio (in dB) from.
        layers : range specification
            Parsed by int_range; value to pick the number of overlay layers from.
        """
        self.source = source
        self.probability = float(p)
        self.snr = float_range(snr)
        self.layers = int_range(layers)
        # os.cpu_count() returns None if the number of CPUs cannot be determined - assuming one CPU then
        n_cpus = os.cpu_count() or 1
        # NOTE(review): self.layers[1] is presumably the upper bound of the layers range - confirm in int_range
        self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * n_cpus)))
        # remainder of the overlay sample in use; consumed across layers and apply calls
        self.current_sample = None
        self.enqueue_process = None

    def start(self, buffering=BUFFER_SIZE):
        """Starts the process that keeps feeding overlay samples from the source into the queue."""
        self.enqueue_process = Process(target=_enqueue_overlay_samples,
                                       args=(self.source, self.queue),
                                       kwargs={'buffering': buffering})
        self.enqueue_process.start()

    def apply(self, sample, clock):
        """Mixes one or more layers of queued overlay samples into the sample's audio (in place)."""
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        n_layers = pick_value_from_range(self.layers, clock=clock)
        audio = sample.audio
        overlay_data = np.zeros_like(audio)
        for _ in range(n_layers):
            overlay_offset = 0
            # fill the whole length of the audio with consecutive overlay samples
            while overlay_offset < len(audio):
                if self.current_sample is None:
                    # blocks until the enqueue process provides the next overlay sample
                    next_overlay_sample = self.queue.get()
                    next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
                    self.current_sample = next_overlay_sample.audio
                n_required = len(audio) - overlay_offset
                n_current = len(self.current_sample)
                if n_required >= n_current:  # take it completely
                    overlay_data[overlay_offset:overlay_offset + n_current] += self.current_sample
                    overlay_offset += n_current
                    self.current_sample = None
                else:  # take required slice from head and keep tail for next layer or sample
                    overlay_data[overlay_offset:overlay_offset + n_required] += self.current_sample[0:n_required]
                    overlay_offset += n_required
                    self.current_sample = self.current_sample[n_required:]
        snr_db = pick_value_from_range(self.snr, clock=clock)
        orig_dbfs = max_dbfs(audio)
        # gain that brings the overlay's max_dbfs level to snr_db below the one of the original audio
        overlay_gain = orig_dbfs - max_dbfs(overlay_data) - snr_db
        audio += overlay_data * gain_db_to_ratio(overlay_gain)
        sample.audio = normalize_audio(audio, dbfs=orig_dbfs)

    def stop(self):
        """Terminates the enqueue process, if one has been started."""
        if self.enqueue_process is not None:
            self.enqueue_process.terminate()
class Reverb:
    """See "Reverb augmentation" in TRAINING.rst"""
    def __init__(self, p=1.0, delay=20.0, decay=10.0):
        """Takes the application probability p and range specifications for delay and decay."""
        self.probability = float(p)
        self.delay = float_range(delay)
        self.decay = float_range(decay)

    def apply(self, sample, clock):
        """Adds five delayed and decaying echo layers to the sample's audio and restores its original level."""
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        dry = np.array(sample.audio, dtype=np.float64)
        orig_dbfs = max_dbfs(dry)
        base_delay = pick_value_from_range(self.delay, clock=clock)
        decay_db = pick_value_from_range(self.decay, clock=clock)
        feedback = gain_db_to_ratio(-decay_db)
        n_total = len(dry)
        wet = np.copy(dry)
        primes = [17, 19, 23, 29, 31]
        for delay_prime in primes:  # primes to minimize comb filter interference
            layer = np.copy(dry)
            # delay of this layer in samples, scaled relative to the first prime
            n_delay = math.floor(base_delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0)
            n_delay = max(16, n_delay)  # 16 samples minimum to avoid performance trap and risk of division by zero
            n_windows = math.floor(n_total / n_delay)
            for w1 in range(0, n_windows * n_delay, n_delay):
                w2 = w1 + n_delay
                width = min(n_total - w2, n_delay)  # last window could be smaller
                # each window feeds its (already echoed) content into the following one
                layer[w2:w2 + width] += feedback * layer[w1:w1 + width]
            wet += layer
        normalized = normalize_audio(wet, dbfs=orig_dbfs)
        sample.audio = np.array(normalized, dtype=np.float32)
class Resample:
    """See "Resample augmentation" in TRAINING.rst"""
    def __init__(self, p=1.0, rate=8000):
        """Takes the application probability p and a range specification for the intermediate sample rate."""
        self.probability = float(p)
        self.rate = int_range(rate)

    def apply(self, sample, clock):
        """Resamples the audio to a picked intermediate rate and back to its original rate."""
        # late binding librosa and its dependencies
        from librosa.core import resample  # pylint: disable=import-outside-toplevel
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        rate = pick_value_from_range(self.rate, clock=clock)
        audio = sample.audio
        orig_len = len(audio)
        # NOTE(review): presumably audio is (frames, channels) and resample expects time as last axis - confirm
        audio = np.swapaxes(audio, 0, 1)
        # sample rates are passed by keyword: newer librosa releases no longer accept them positionally
        audio = resample(audio, orig_sr=sample.audio_format.rate, target_sr=rate)
        audio = resample(audio, orig_sr=rate, target_sr=sample.audio_format.rate)
        # cut off frames exceeding the original length
        audio = np.swapaxes(audio, 0, 1)[0:orig_len]
        sample.audio = audio
class Codec:
    """See "Codec augmentation" in TRAINING.rst"""
    def __init__(self, p=1.0, bitrate=3200):
        """Takes the application probability p and a range specification for the encoding bitrate."""
        self.probability = float(p)
        self.bitrate = int_range(bitrate)

    def apply(self, sample, clock):
        """Re-encodes the sample as Opus at a bitrate picked for the given clock value."""
        target_bitrate = pick_value_from_range(self.bitrate, clock=clock)
        # decoding to ensure it has to get encoded again
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_PCM)
        # will get decoded again downstream
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=target_bitrate)
class Gaps:
    """See "Gaps augmentation" in TRAINING.rst"""
    def __init__(self, p=1.0, n=1, size=50.0):
        """Takes the application probability p and range specifications for gap count n and gap size (ms)."""
        self.probability = float(p)
        self.n_gaps = int_range(n)
        self.size = float_range(size)

    def apply(self, sample, clock):
        """Zeroes a picked number of randomly placed stretches of the sample's audio."""
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        audio = sample.audio
        gap_count = pick_value_from_range(self.n_gaps, clock=clock)
        for _ in range(gap_count):
            size_ms = pick_value_from_range(self.size, clock=clock)
            # milliseconds to number of samples
            gap_len = int(size_ms * sample.audio_format.rate / 1000.0)
            gap_len = min(gap_len, len(audio) // 10)  # a gap should never exceed 10 percent of the audio
            last_start = max(0, len(audio) - gap_len - 1)
            gap_start = random.randint(0, last_start)
            audio[gap_start:gap_start + gap_len] = 0
        sample.audio = audio
class Volume:
    """See "Volume augmentation" in TRAINING.rst"""
    def __init__(self, p=1.0, dbfs=3.0103):
        """Takes the application probability p and a range specification for the target dBFS value."""
        self.probability = float(p)
        self.target_dbfs = float_range(dbfs)

    def apply(self, sample, clock):
        """Normalizes the sample's audio to a dBFS value picked for the given clock value."""
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        picked_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
        normalized = normalize_audio(sample.audio, dbfs=picked_dbfs)
        sample.audio = normalized
def parse_augmentation(augmentation_spec):
    """
    Parses an augmentation specification.

    Parameters
    ----------
    augmentation_spec : str
        Augmentation specification like "reverb[delay=20.0,decay=-20]".
        Parameters without "=" are passed as positional arguments, "name=value" pairs as keyword
        arguments. All values are passed as (whitespace-stripped) strings.

    Returns
    -------
    Instance of an augmentation class from util.signal_augmentations.*.

    Raises
    ------
    ValueError
        If the specification has the wrong format, names no augmentation class of this module
        or contains a parameter with more than one "=".
    """
    match = SPEC_PARSER.match(augmentation_spec)
    if not match:
        raise ValueError('Augmentation specification has wrong format')
    cls_name = match.group(1)[0].upper() + match.group(1)[1:]
    augmentation_cls = globals().get(cls_name)
    # Only classes providing an apply method are augmentations; other capitalized module globals
    # (e.g. imported classes like Queue or Process) must not be instantiable through a specification.
    if not isinstance(augmentation_cls, type) or not hasattr(augmentation_cls, 'apply'):
        raise ValueError('Unknown augmentation: {}'.format(cls_name))
    raw_parameters = match.group(3)
    # missing or blank brackets ("volume" or "volume[]") both mean: no parameters
    parameters = [] if raw_parameters is None or not raw_parameters.strip() else raw_parameters.split(',')
    args = []
    kwargs = {}
    for parameter in parameters:
        pair = [part.strip() for part in parameter.split('=')]
        if len(pair) == 1:
            # positional value: pass the string itself, not the one-element sequence
            args.append(pair[0])
        elif len(pair) == 2:
            kwargs[pair[0]] = pair[1]
        else:
            raise ValueError('Unable to parse augmentation value assignment')
    return augmentation_cls(*args, **kwargs)