SDB support
commit 6b1d6773de (parent 3bd0b20bf7)
DeepSpeech.py

@@ -33,7 +33,7 @@ from util.config import Config, initialize_globals
 from util.checkpoints import load_or_init_graph
 from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
-from util.helpers import check_ctcdecoder_version
+from util.helpers import check_ctcdecoder_version, ExceptionBox
 from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar

 check_ctcdecoder_version()
@@ -418,12 +418,17 @@ def train():
                             FLAGS.augmentation_sparse_warp):
         do_cache_dataset = False

+    exception_box = ExceptionBox()
+
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
                                enable_cache=FLAGS.feature_cache and do_cache_dataset,
                                cache_path=FLAGS.feature_cache,
-                               train_phase=True)
+                               train_phase=True,
+                               exception_box=exception_box,
+                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
+                               buffering=FLAGS.read_buffer)

     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
@@ -433,8 +438,13 @@ def train():
     train_init_op = iterator.make_initializer(train_set)

     if FLAGS.dev_files:
-        dev_csvs = FLAGS.dev_files.split(',')
-        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
+        dev_sources = FLAGS.dev_files.split(',')
+        dev_sets = [create_dataset([source],
+                                   batch_size=FLAGS.dev_batch_size,
+                                   train_phase=False,
+                                   exception_box=exception_box,
+                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                   buffering=FLAGS.read_buffer) for source in dev_sources]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

     # Dropout
@@ -540,6 +550,7 @@ def train():
                 _, current_step, batch_loss, problem_files, step_summary = \
                     session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                 feed_dict=feed_dict)
+                exception_box.raise_if_set()
             except tf.errors.InvalidArgumentError as err:
                 if FLAGS.augmentation_sparse_warp:
                     log_info("Ignoring sparse warp error: {}".format(err))
@@ -547,6 +558,7 @@ def train():
                 else:
                     raise
             except tf.errors.OutOfRangeError:
+                exception_box.raise_if_set()
                 break

             if problem_files.size > 0:
|
|||
# Validation
|
||||
dev_loss = 0.0
|
||||
total_steps = 0
|
||||
for csv, init_op in zip(dev_csvs, dev_init_ops):
|
||||
log_progress('Validating epoch %d on %s...' % (epoch, csv))
|
||||
set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
|
||||
for source, init_op in zip(dev_sources, dev_init_ops):
|
||||
log_progress('Validating epoch %d on %s...' % (epoch, source))
|
||||
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
|
||||
dev_loss += set_loss * steps
|
||||
total_steps += steps
|
||||
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
|
||||
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
|
||||
|
||||
dev_loss = dev_loss / total_steps
|
||||
dev_losses.append(dev_loss)
|
||||
|
|
|
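Note (not part of the diff): the `exception_box` threaded through `train()` exists because an exception raised inside a `tf.data` generator would otherwise just end iteration silently instead of failing the run. A minimal standalone sketch of the pattern, using only the helpers this commit adds:

    from util.helpers import ExceptionBox, remember_exception

    box = ExceptionBox()

    def generate_values():
        yield 1
        raise ValueError('worker failed')  # would otherwise be swallowed by tf.data

    for _ in remember_exception(generate_values, box)():
        pass
    box.raise_if_set()  # re-raises the ValueError on the consumer side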
bin/build_sdb.py (new file)

@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+'''
+Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
+Use "python3 build_sdb.py -h" for help
+'''
+from __future__ import absolute_import, division, print_function
+
+# Make sure we can import stuff from util/
+# This script needs to be run from the root of the DeepSpeech repository
+import os
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+
+import argparse
+import progressbar
+
+from util.downloader import SIMPLE_BAR
+from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
+from util.sample_collections import samples_from_files, DirectSDBWriter
+
+AUDIO_TYPE_LOOKUP = {
+    'wav': AUDIO_TYPE_WAV,
+    'opus': AUDIO_TYPE_OPUS
+}
+
+
+def build_sdb():
+    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
+    with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer:
+        samples = samples_from_files(CLI_ARGS.sources)
+        bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
+        for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
+            sdb_writer.add(sample)
+
+
+def handle_args():
+    parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
+                                                 'from DeepSpeech CSV files and other SDB files')
+    parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - '
+                                                   'Note: For getting a correctly ordered target SDB, source SDBs have '
+                                                   'to have their samples already ordered from shortest to longest.')
+    parser.add_argument('target', help='SDB file to create')
+    parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
+                        help='Audio representation inside target SDB')
+    parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers')
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    CLI_ARGS = handle_args()
+    build_sdb()
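Note (not part of the diff): the script above is a thin CLI wrapper; the same conversion can be done programmatically. A sketch with hypothetical paths:

    from util.audio import change_audio_types, AUDIO_TYPE_OPUS
    from util.sample_collections import samples_from_files, DirectSDBWriter

    with DirectSDBWriter('data/train.sdb', audio_type=AUDIO_TYPE_OPUS) as writer:
        samples = samples_from_files(['data/a.csv', 'data/b.csv'])
        for sample in change_audio_types(samples, audio_type=AUDIO_TYPE_OPUS, processes=4):
            writer.add(sample)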
bin/play.py (new file; the original docstring said "build_sdb.py", a copy-paste slip fixed here)

@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+"""
+Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
+Use "python3 play.py -h" for help
+"""
+from __future__ import absolute_import, division, print_function
+
+# Make sure we can import stuff from util/
+# This script needs to be run from the root of the DeepSpeech repository
+import os
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+
+import random
+import argparse
+
+from util.sample_collections import samples_from_file
+from util.audio import AUDIO_TYPE_PCM
+
+
+def play_sample(samples, index):
+    if index < 0:
+        index = len(samples) + index
+    if CLI_ARGS.random:
+        index = random.randint(0, len(samples) - 1)  # randint is inclusive on both ends
+    elif index >= len(samples):
+        print('No sample with index {}'.format(index))
+        sys.exit(1)
+    sample = samples[index]
+    print('Sample "{}"'.format(sample.sample_id))
+    print('  "{}"'.format(sample.transcript))
+    sample.change_audio_type(AUDIO_TYPE_PCM)
+    rate, channels, width = sample.audio_format
+    wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate)
+    play_obj = wave_obj.play()
+    play_obj.wait_done()
+
+
+def play_collection():
+    samples = samples_from_file(CLI_ARGS.collection, buffering=0)
+    played = 0
+    index = CLI_ARGS.start
+    while True:
+        if 0 <= CLI_ARGS.number <= played:
+            return
+        play_sample(samples, index)
+        played += 1
+        index = (index + 1) % len(samples)
+
+
+def handle_args():
+    parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
+                                                 'and DeepSpeech CSV files')
+    parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
+    parser.add_argument('--start', type=int, default=0,
+                        help='Sample index to start at (negative numbers are relative to the end of the collection)')
+    parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
+    parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    try:
+        import simpleaudio
+    except ModuleNotFoundError:
+        print('play.py requires Python package "simpleaudio"')
+        sys.exit(1)
+    CLI_ARGS = handle_args()
+    try:
+        play_collection()
+    except KeyboardInterrupt:
+        print(' Stopped')
+        sys.exit(0)
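Note (not part of the diff): a sketch (hypothetical path) of the loading and decoding steps play_sample performs before handing the buffer to simpleaudio:

    from util.sample_collections import samples_from_file
    from util.audio import AUDIO_TYPE_PCM

    samples = samples_from_file('data/train.sdb', buffering=0)
    sample = samples[0]  # SDB and CSV readers both support random access
    sample.change_audio_type(AUDIO_TYPE_PCM)
    rate, channels, width = sample.audio_format
    print(sample.sample_id, sample.transcript, sample.duration)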
bin/run-tc-ldc93s1_checkpoint_sdb.sh (new file; name matched to the script list in the taskcluster/tc-train-tests.sh hunk below)

@@ -0,0 +1,37 @@
+#!/bin/sh
+
+set -xe
+
+ldc93s1_dir="./data/smoke_test"
+ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
+ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
+    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
+fi;
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
+    echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
+    python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
+fi;
+
+# Force only one visible device because we have a single-sample dataset
+# and when trying to run on multiple devices (like GPUs), this will break
+export CUDA_VISIBLE_DEVICES=0
+
+python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --train_files ${ldc93s1_sdb} --train_batch_size 1 \
+    --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
+    --test_files ${ldc93s1_sdb} --test_batch_size 1 \
+    --n_hidden 100 --epochs 1 \
+    --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \
+    --learning_rate 0.001 --dropout_rate 0.05 \
+    --scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log
+
+if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then
+    echo "Did not resume training from checkpoint"
+    exit 1
+else
+    exit 0
+fi
bin/run-tc-ldc93s1_new_sdb.sh (new file; name matched to the tc-train-tests.sh hunk below)

@@ -0,0 +1,34 @@
+#!/bin/sh
+
+set -xe
+
+ldc93s1_dir="./data/smoke_test"
+ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
+ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
+
+epoch_count=$1
+audio_sample_rate=$2
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
+    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
+fi;
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
+    echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
+    python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
+fi;
+
+# Force only one visible device because we have a single-sample dataset
+# and when trying to run on multiple devices (like GPUs), this will break
+export CUDA_VISIBLE_DEVICES=0
+
+python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --train_files ${ldc93s1_sdb} --train_batch_size 1 \
+    --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
+    --test_files ${ldc93s1_sdb} --test_batch_size 1 \
+    --n_hidden 100 --epochs $epoch_count \
+    --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \
+    --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb' \
+    --scorer_path 'data/smoke_test/pruned_lm.scorer' \
+    --audio_sample_rate ${audio_sample_rate}
bin/run-tc-ldc93s1_new_sdb_csv.sh (new file; name matched to the tc-train-tests.sh hunk below)

@@ -0,0 +1,35 @@
+#!/bin/sh
+
+set -xe
+
+ldc93s1_dir="./data/smoke_test"
+ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
+ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
+
+epoch_count=$1
+audio_sample_rate=$2
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
+    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
+fi;
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
+    echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
+    python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
+fi;
+
+# Force only one visible device because we have a single-sample dataset
+# and when trying to run on multiple devices (like GPUs), this will break
+export CUDA_VISIBLE_DEVICES=0
+
+python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --train_files ${ldc93s1_sdb},${ldc93s1_csv} --train_batch_size 1 \
+    --feature_cache '/tmp/ldc93s1_cache_sdb_csv' \
+    --dev_files ${ldc93s1_sdb},${ldc93s1_csv} --dev_batch_size 1 \
+    --test_files ${ldc93s1_sdb},${ldc93s1_csv} --test_batch_size 1 \
+    --n_hidden 100 --epochs $epoch_count \
+    --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb_csv' \
+    --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb_csv' \
+    --scorer_path 'data/smoke_test/pruned_lm.scorer' \
+    --audio_sample_rate ${audio_sample_rate}
requirements.txt (filename inferred from content; pandas moves to the importer section below)

@@ -2,12 +2,12 @@
 tensorflow == 1.15.0
 numpy == 1.18.1
 progressbar2
-pandas
 six
 pyxdg
 attrdict
 absl-py
 semver
+opuslib == 2.0.0

 # Requirements for building native_client files
 setuptools
@@ -15,6 +15,7 @@ setuptools
 # Requirements for importers
 sox
 bs4
+pandas
 requests
 librosa
 soundfile
taskcluster/.shared.yml (filename inferred; defines the package set referenced as ${training.packages_trusty.apt} below)

@@ -5,6 +5,9 @@ python:
     apt: 'python3-virtualenv python3-setuptools python3-pip python3-wheel python3-pkg-resources'
   packages_docs_bionic:
     apt: 'python3 python3-pip zip doxygen'
+training:
+  packages_trusty:
+    apt: 'libopus0'
 tensorflow:
   packages_trusty:
     apt: 'make build-essential gfortran git libblas-dev liblapack-dev libsox-dev libmagic-dev libgsm1-dev libltdl-dev libpng-dev python zlib1g-dev'
taskcluster/tc-train-tests.sh

@@ -48,6 +48,11 @@ pushd ${HOME}/DeepSpeech/ds/
 time ./bin/run-tc-ldc93s1_new.sh 249 "${sample_rate}"
 time ./bin/run-tc-ldc93s1_new.sh 1 "${sample_rate}"
 time ./bin/run-tc-ldc93s1_tflite.sh "${sample_rate}"
+# Testing single SDB source
+time ./bin/run-tc-ldc93s1_new_sdb.sh 220 "${sample_rate}"
+# Testing interleaved source (SDB+CSV combination) - run twice to test preprocessed features
+time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 "${sample_rate}"
+time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 1 "${sample_rate}"
 popd

 cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}

@@ -62,6 +67,7 @@ cp /tmp/train/output_graph.pbmm ${TASKCLUSTER_ARTIFACTS}

 pushd ${HOME}/DeepSpeech/ds/
 time ./bin/run-tc-ldc93s1_checkpoint.sh
+time ./bin/run-tc-ldc93s1_checkpoint_sdb.sh
 popd

 virtualenv_deactivate "${pyalias}" "deepspeech"
TaskCluster training-test definition (Python 3.5.8, 16 kHz; filename not shown in this view)

@@ -2,6 +2,9 @@ build:
   template_file: test-linux-opt-base.tyml
   dependencies:
     - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
   args:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.5.8:m 16k"
   metadata:
TaskCluster training-test definition (Python 3.6.10, 16 kHz; filename not shown in this view)

@@ -2,6 +2,9 @@ build:
   template_file: test-linux-opt-base.tyml
   dependencies:
     - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
   args:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 16k"
   metadata:
TaskCluster training-test definition (Python 3.6.10, 8 kHz; filename not shown in this view)

@@ -2,6 +2,9 @@ build:
   template_file: test-linux-opt-base.tyml
   dependencies:
     - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
   args:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 8k"
   metadata:
util/audio.py (237 lines changed)
@@ -1,34 +1,127 @@
 import os
-import sox
+import io
 import wave
 import tempfile
 import collections
 import numpy as np

+from util.helpers import LimitingPool
+
 DEFAULT_RATE = 16000
 DEFAULT_CHANNELS = 1
 DEFAULT_WIDTH = 2
 DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)

+AUDIO_TYPE_NP = 'application/vnd.mozilla.np'
+AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm'
+AUDIO_TYPE_WAV = 'audio/wav'
+AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus'
+SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS]

-def get_audio_format(wav_file):
+OPUS_PCM_LEN_SIZE = 4
+OPUS_RATE_SIZE = 4
+OPUS_CHANNELS_SIZE = 1
+OPUS_WIDTH_SIZE = 1
+OPUS_CHUNK_LEN_SIZE = 2
+
+
+class Sample:
+    """Represents in-memory audio data of a certain (convertible) representation.
+    Attributes:
+        audio_type (str): See `__init__`.
+        audio_format (tuple:(int, int, int)): See `__init__`.
+        audio (obj): Audio data represented as indicated by `audio_type`
+        duration (float): Audio duration of the sample in seconds
+    """
+    def __init__(self, audio_type, raw_data, audio_format=None):
+        """
+        Creates a Sample from a raw audio representation.
+        :param audio_type: Audio data representation type
+            Supported types:
+            - AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
+              wrapped by a custom container format (used in SDBs)
+            - AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
+            - AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
+            - AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
+        :param raw_data: Audio data in the form of the provided representation type (see audio_type).
+            For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray.
+        :param audio_format: Tuple of sample-rate, number of channels and sample-width.
+            Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP,
+            as this information cannot be derived from raw audio data.
+        """
+        self.audio_type = audio_type
+        self.audio_format = audio_format
+        if audio_type in SERIALIZABLE_AUDIO_TYPES:
+            self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
+            self.duration = read_duration(audio_type, self.audio)
+        else:
+            self.audio = raw_data
+            if self.audio_format is None:
+                raise ValueError('For audio type "{}" parameter "audio_format" is mandatory'.format(self.audio_type))
+            if audio_type == AUDIO_TYPE_PCM:
+                self.duration = get_pcm_duration(len(self.audio), self.audio_format)
+            elif audio_type == AUDIO_TYPE_NP:
+                self.duration = get_np_duration(len(self.audio), self.audio_format)
+            else:
+                raise ValueError('Unsupported audio type: {}'.format(self.audio_type))
+
+    def change_audio_type(self, new_audio_type):
+        """
+        In-place conversion of audio data into a different representation.
+        :param new_audio_type: New audio-type - see `__init__`.
+            Not supported: Converting from AUDIO_TYPE_NP into any other type.
+        """
+        if self.audio_type == new_audio_type:
+            return
+        if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in SERIALIZABLE_AUDIO_TYPES:
+            self.audio_format, audio = read_audio(self.audio_type, self.audio)
+            self.audio.close()
+            self.audio = audio
+        elif new_audio_type == AUDIO_TYPE_NP:
+            self.change_audio_type(AUDIO_TYPE_PCM)
+            self.audio = pcm_to_np(self.audio_format, self.audio)
+        elif new_audio_type in SERIALIZABLE_AUDIO_TYPES:
+            self.change_audio_type(AUDIO_TYPE_PCM)
+            audio_bytes = io.BytesIO()
+            write_audio(new_audio_type, audio_bytes, self.audio_format, self.audio)
+            audio_bytes.seek(0)
+            self.audio = audio_bytes
+        else:
+            raise RuntimeError('Changing audio representation type from "{}" to "{}" not supported'
+                               .format(self.audio_type, new_audio_type))
+        self.audio_type = new_audio_type
+
+
+def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, processes=None, process_ahead=None):
+    def change_audio_type(sample):
+        sample.change_audio_type(audio_type)
+        return sample
+    with LimitingPool(processes=processes, process_ahead=process_ahead) as pool:
+        yield from pool.imap(change_audio_type, samples)
+
+
+def read_audio_format_from_wav_file(wav_file):
     return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()


-def get_num_samples(audio_data, audio_format=DEFAULT_FORMAT):
+def get_num_samples(pcm_buffer_size, audio_format=DEFAULT_FORMAT):
     _, channels, width = audio_format
-    return len(audio_data) // (channels * width)
+    return pcm_buffer_size // (channels * width)


-def get_duration(audio_data, audio_format=DEFAULT_FORMAT):
-    return get_num_samples(audio_data, audio_format) / audio_format[0]
+def get_pcm_duration(pcm_buffer_size, audio_format=DEFAULT_FORMAT):
+    """Calculates duration in seconds of a binary PCM buffer (typically read from a WAV file)"""
+    return get_num_samples(pcm_buffer_size, audio_format) / audio_format[0]


-def get_duration_ms(audio_data, audio_format=DEFAULT_FORMAT):
-    return get_duration(audio_data, audio_format) * 1000
+def get_np_duration(np_len, audio_format=DEFAULT_FORMAT):
+    """Calculates duration in seconds of NumPy audio data"""
+    return np_len / audio_format[0]


 def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
     sample_rate, channels, width = audio_format
+    import sox
     transformer = sox.Transformer()
     transformer.set_output_format(file_type=file_type, rate=sample_rate, channels=channels, bits=width*8)
     transformer.build(src_audio_path, dst_audio_path)

@@ -45,7 +138,7 @@ class AudioFile:
     def __enter__(self):
         if self.audio_path.endswith('.wav'):
             self.open_file = wave.open(self.audio_path, 'r')
-            if get_audio_format(self.open_file) == self.audio_format:
+            if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
                 if self.as_path:
                     self.open_file.close()
                     return self.audio_path

@@ -66,12 +159,12 @@ class AudioFile:


 def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
-    audio_format = get_audio_format(wav_file)
+    audio_format = read_audio_format_from_wav_file(wav_file)
     frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0))
     while True:
         try:
             data = wav_file.readframes(frame_size)
-            if not yield_remainder and get_duration_ms(data, audio_format) < frame_duration_ms:
+            if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms:
                 break
             yield data
         except EOFError:

@@ -106,7 +199,7 @@ def vad_split(audio_frames,
     frame_duration_ms = 0
     frame_index = 0
     for frame_index, frame in enumerate(audio_frames):
-        frame_duration_ms = get_duration_ms(frame, audio_format)
+        frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000
         if int(frame_duration_ms) not in [10, 20, 30]:
             raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms')
         is_speech = vad.is_speech(frame, sample_rate)

@@ -133,3 +226,123 @@ def vad_split(audio_frames,
         yield b''.join(voiced_frames), \
               frame_duration_ms * (frame_index - len(voiced_frames)), \
               frame_duration_ms * (frame_index + 1)
+
+
+def pack_number(n, num_bytes):
+    return n.to_bytes(num_bytes, 'big', signed=False)
+
+
+def unpack_number(data):
+    return int.from_bytes(data, 'big', signed=False)
+
+
+def get_opus_frame_size(rate):
+    return 60 * rate // 1000
+
+
+def write_opus(opus_file, audio_format, audio_data):
+    rate, channels, width = audio_format
+    frame_size = get_opus_frame_size(rate)
+    import opuslib  # pylint: disable=import-outside-toplevel
+    encoder = opuslib.Encoder(rate, channels, 'audio')
+    chunk_size = frame_size * channels * width
+    opus_file.write(pack_number(len(audio_data), OPUS_PCM_LEN_SIZE))
+    opus_file.write(pack_number(rate, OPUS_RATE_SIZE))
+    opus_file.write(pack_number(channels, OPUS_CHANNELS_SIZE))
+    opus_file.write(pack_number(width, OPUS_WIDTH_SIZE))
+    for i in range(0, len(audio_data), chunk_size):
+        chunk = audio_data[i:i + chunk_size]
+        # Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer
+        if len(chunk) < chunk_size:
+            chunk = chunk + bytearray(chunk_size - len(chunk))
+        encoded = encoder.encode(chunk, frame_size)
+        opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE))
+        opus_file.write(encoded)
+
+
+def read_opus_header(opus_file):
+    opus_file.seek(0)
+    pcm_buffer_size = unpack_number(opus_file.read(OPUS_PCM_LEN_SIZE))
+    rate = unpack_number(opus_file.read(OPUS_RATE_SIZE))
+    channels = unpack_number(opus_file.read(OPUS_CHANNELS_SIZE))
+    width = unpack_number(opus_file.read(OPUS_WIDTH_SIZE))
+    return pcm_buffer_size, (rate, channels, width)
+
+
+def read_opus(opus_file):
+    pcm_buffer_size, audio_format = read_opus_header(opus_file)
+    rate, channels, _ = audio_format
+    frame_size = get_opus_frame_size(rate)
+    import opuslib  # pylint: disable=import-outside-toplevel
+    decoder = opuslib.Decoder(rate, channels)
+    audio_data = bytearray()
+    while len(audio_data) < pcm_buffer_size:
+        chunk_len = unpack_number(opus_file.read(OPUS_CHUNK_LEN_SIZE))
+        chunk = opus_file.read(chunk_len)
+        decoded = decoder.decode(chunk, frame_size)
+        audio_data.extend(decoded)
+    audio_data = audio_data[:pcm_buffer_size]
+    return audio_format, audio_data
+
+
+def write_wav(wav_file, audio_format, pcm_data):
+    with wave.open(wav_file, 'wb') as wav_file_writer:
+        rate, channels, width = audio_format
+        wav_file_writer.setframerate(rate)
+        wav_file_writer.setnchannels(channels)
+        wav_file_writer.setsampwidth(width)
+        wav_file_writer.writeframes(pcm_data)
+
+
+def read_wav(wav_file):
+    wav_file.seek(0)
+    with wave.open(wav_file, 'rb') as wav_file_reader:
+        audio_format = read_audio_format_from_wav_file(wav_file_reader)
+        pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes())
+        return audio_format, pcm_data
+
+
+def read_audio(audio_type, audio_file):
+    if audio_type == AUDIO_TYPE_WAV:
+        return read_wav(audio_file)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return read_opus(audio_file)
+    raise ValueError('Unsupported audio type: {}'.format(audio_type))
+
+
+def write_audio(audio_type, audio_file, audio_format, pcm_data):
+    if audio_type == AUDIO_TYPE_WAV:
+        return write_wav(audio_file, audio_format, pcm_data)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return write_opus(audio_file, audio_format, pcm_data)
+    raise ValueError('Unsupported audio type: {}'.format(audio_type))
+
+
+def read_wav_duration(wav_file):
+    wav_file.seek(0)
+    with wave.open(wav_file, 'rb') as wav_file_reader:
+        return wav_file_reader.getnframes() / wav_file_reader.getframerate()
+
+
+def read_opus_duration(opus_file):
+    pcm_buffer_size, audio_format = read_opus_header(opus_file)
+    return get_pcm_duration(pcm_buffer_size, audio_format)
+
+
+def read_duration(audio_type, audio_file):
+    if audio_type == AUDIO_TYPE_WAV:
+        return read_wav_duration(audio_file)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return read_opus_duration(audio_file)
+    raise ValueError('Unsupported audio type: {}'.format(audio_type))
+
+
+def pcm_to_np(audio_format, pcm_data):
+    _, channels, width = audio_format
+    if width not in [1, 2, 4]:
+        raise ValueError('Unsupported sample width: {}'.format(width))
+    dtype = [None, np.int8, np.int16, None, np.int32][width]
+    samples = np.frombuffer(pcm_data, dtype=dtype)
+    assert channels == 1  # only mono supported for now
+    samples = samples.astype(np.float32) / np.iinfo(dtype).max
+    return np.expand_dims(samples, axis=1)
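Note (not part of the diff): a small round-trip check of the new low-level helpers, derived only from the code above (values arbitrary):

    import numpy as np
    from util.audio import pack_number, unpack_number, pcm_to_np

    assert unpack_number(pack_number(16000, 4)) == 16000

    pcm = np.array([0, 16384, -16384], dtype=np.int16).tobytes()
    audio = pcm_to_np((16000, 1, 2), pcm)  # mono 16-bit PCM -> float32 column vector in [-1, 1]
    assert audio.shape == (3, 1)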
util/config.py (filename inferred from the `util.config` imports above)

@@ -12,6 +12,7 @@ from util.flags import FLAGS
 from util.gpu import get_available_gpus
 from util.logging import log_error
 from util.text import Alphabet, UTF8Alphabet
+from util.helpers import parse_file_size

 class ConfigSingleton:
     _config = None

@@ -29,6 +30,9 @@ Config = ConfigSingleton()  # pylint: disable=invalid-name
 def initialize_globals():
     c = AttrDict()

+    # Read-buffer
+    FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer)
+
     # Set default dropout rates
     if FLAGS.dropout_rate2 < 0:
         FLAGS.dropout_rate2 = FLAGS.dropout_rate
util/feeding.py (107 lines changed)
@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, division, print_function

-import os
-
 from functools import partial

 import numpy as np
-import pandas
 import tensorflow as tf

 from tensorflow.python.ops import gen_audio_ops as contrib_audio

@@ -15,27 +12,18 @@ from util.config import Config
 from util.text import text_to_char_array
 from util.flags import FLAGS
 from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
-from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT
+from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
+from util.sample_collections import samples_from_files
+from util.helpers import remember_exception, MEGABYTE


-def read_csvs(csv_files):
-    sets = []
-    for csv in csv_files:
-        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
-        #FIXME: not cross-platform
-        csv_dir = os.path.dirname(os.path.abspath(csv))
-        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))  # pylint: disable=cell-var-from-loop
-        sets.append(file)
-    # Concat all sets, drop any extra columns, re-index the final result as 0..N
-    return pandas.concat(sets, join='inner', ignore_index=True)
-
-
-def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None):
+def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
     if train_phase:
         # We need the lambdas to make TensorFlow happy.
         # pylint: disable=unnecessary-lambda
         tf.cond(tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate),
-                lambda: tf.print('WARNING: sample rate of file', wav_filename, '(', sample_rate, ') does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'),
+                lambda: tf.print('WARNING: sample rate of sample', sample_id, '(', sample_rate, ') '
+                                 'does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'),
                 lambda: tf.no_op(),
                 name='matching_sample_rate')

@@ -84,10 +72,8 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
     return mfccs, tf.shape(input=mfccs)[0]


-def audiofile_to_features(wav_filename, train_phase=False):
-    samples = tf.io.read_file(wav_filename)
-    decoded = contrib_audio.decode_wav(samples, desired_channels=1)
-    features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase, wav_filename=wav_filename)
+def audio_to_features(audio, sample_rate, train_phase=False, sample_id=None):
+    features, features_len = samples_to_mfccs(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)

     if train_phase:
         if FLAGS.data_aug_features_multiplicative > 0:

@@ -99,10 +85,17 @@ def audio_to_features(audio, sample_rate, train_phase=False, sample_id=None):
     return features, features_len


-def entry_to_features(wav_filename, transcript, train_phase):
+def audiofile_to_features(wav_filename, train_phase=False):
+    samples = tf.io.read_file(wav_filename)
+    decoded = contrib_audio.decode_wav(samples, desired_channels=1)
+    return audio_to_features(decoded.audio, decoded.sample_rate, train_phase=train_phase, sample_id=wav_filename)
+
+
+def entry_to_features(sample_id, audio, sample_rate, transcript, train_phase=False):
     # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase)
-    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
+    features, features_len = audio_to_features(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)
+    sparse_transcript = tf.SparseTensor(*transcript)
+    return sample_id, features, features_len, sparse_transcript


 def to_sparse_tuple(sequence):

@@ -114,15 +107,22 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape


-def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False):
-    df = read_csvs(csvs)
-    df.sort_values(by='wav_filesize', inplace=True)
-
-    df['transcript'] = df.apply(text_to_char_array, alphabet=Config.alphabet, result_type='reduce', axis=1)
-
+def create_dataset(sources,
+                   batch_size,
+                   enable_cache=False,
+                   cache_path=None,
+                   train_phase=False,
+                   exception_box=None,
+                   process_ahead=None,
+                   buffering=1 * MEGABYTE):
     def generate_values():
-        for _, row in df.iterrows():
-            yield row.wav_filename, to_sparse_tuple(row.transcript)
+        samples = samples_from_files(sources, buffering=buffering)
+        for sample in change_audio_types(samples,
+                                         AUDIO_TYPE_NP,
+                                         process_ahead=2 * batch_size if process_ahead is None else process_ahead):
+            transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
+            transcript = to_sparse_tuple(transcript)
+            yield sample.sample_id, sample.audio, sample.audio_format[0], transcript

     # Batching a dataset of 2D SparseTensors creates 3D batches, which fail
     # when passed to tf.nn.ctc_loss, so we reshape them to remove the extra

@@ -131,27 +131,23 @@ def create_dataset(sources,
         shape = sparse.dense_shape
         return tf.sparse.reshape(sparse, [shape[0], shape[2]])

-    def batch_fn(wav_filenames, features, features_len, transcripts):
+    def batch_fn(sample_ids, features, features_len, transcripts):
         features = tf.data.Dataset.zip((features, features_len))
-        features = features.padded_batch(batch_size,
-                                         padded_shapes=([None, Config.n_input], []))
+        features = features.padded_batch(batch_size, padded_shapes=([None, Config.n_input], []))
         transcripts = transcripts.batch(batch_size).map(sparse_reshape)
-        wav_filenames = wav_filenames.batch(batch_size)
-        return tf.data.Dataset.zip((wav_filenames, features, transcripts))
+        sample_ids = sample_ids.batch(batch_size)
+        return tf.data.Dataset.zip((sample_ids, features, transcripts))

-    num_gpus = len(Config.available_devices)
     process_fn = partial(entry_to_features, train_phase=train_phase)

-    dataset = (tf.data.Dataset.from_generator(generate_values,
-                                              output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
+    dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
+                                              output_types=(tf.string, tf.float32, tf.int32,
+                                                            (tf.int64, tf.int32, tf.int64)))
               .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))

     if enable_cache:
         dataset = dataset.cache(cache_path)

     dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
-               .prefetch(num_gpus))
-
+               .prefetch(len(Config.available_devices)))
     return dataset

@@ -160,27 +156,24 @@ def split_audio_file(audio_path,
                      batch_size=1,
                      aggressiveness=3,
                      outlier_duration_ms=10000,
-                     outlier_batch_size=1):
-    sample_rate, _, sample_width = audio_format
-    multiplier = 1.0 / (1 << (8 * sample_width - 1))
-
+                     outlier_batch_size=1,
+                     exception_box=None):
     def generate_values():
         frames = read_frames_from_file(audio_path)
         segments = vad_split(frames, aggressiveness=aggressiveness)
         for segment in segments:
             segment_buffer, time_start, time_end = segment
-            samples = np.frombuffer(segment_buffer, dtype=np.int16)
-            samples = samples * multiplier
-            samples = np.expand_dims(samples, axis=1)
+            samples = pcm_to_np(audio_format, segment_buffer)
             yield time_start, time_end, samples

     def to_mfccs(time_start, time_end, samples):
-        features, features_len = samples_to_mfccs(samples, sample_rate)
+        features, features_len = samples_to_mfccs(samples, audio_format[0])
         return time_start, time_end, features, features_len

     def create_batch_set(bs, criteria):
         return (tf.data.Dataset
-                .from_generator(generate_values, output_types=(tf.int32, tf.int32, tf.float32))
+                .from_generator(remember_exception(generate_values, exception_box),
+                                output_types=(tf.int32, tf.int32, tf.float32))
                 .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .filter(criteria)
                 .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], [])))

@@ -192,9 +185,3 @@ def split_audio_file(audio_path,
     dataset = nds.concatenate(ods)
     dataset = dataset.prefetch(len(Config.available_devices))
     return dataset
-
-
-def secs_to_hours(secs):
-    hours, remainder = divmod(secs, 3600)
-    minutes, seconds = divmod(remainder, 60)
-    return '%d:%02d:%02d' % (hours, minutes, seconds)
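Note (not part of the diff): how the new create_dataset signature is exercised — compare the DeepSpeech.py hunks above (paths hypothetical):

    from util.feeding import create_dataset
    from util.helpers import ExceptionBox

    exception_box = ExceptionBox()
    train_set = create_dataset(['data/train.sdb', 'data/extra.csv'],
                               batch_size=16,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=32,   # DeepSpeech.py uses devices * batch_size * 2
                               buffering=1048576)  # FLAGS.read_buffer after parse_file_size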
util/flags.py (filename inferred from the `util.flags` imports above)

@@ -15,6 +15,7 @@ def create_flags():
     f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
     f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')

+    f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)')
     f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.')

     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
util/helpers.py (103 lines changed)
@@ -1,10 +1,32 @@
 import os
-import semver
 import sys
 import time
+import heapq
+import semver
+
+from multiprocessing.dummy import Pool as ThreadPool
+
+KILO = 1024
+KILOBYTE = 1 * KILO
+MEGABYTE = KILO * KILOBYTE
+GIGABYTE = KILO * MEGABYTE
+TERABYTE = KILO * GIGABYTE
+SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE}
+
+
+def parse_file_size(file_size):
+    file_size = file_size.lower().strip()
+    if len(file_size) == 0:
+        return 0
+    n = int(keep_only_digits(file_size))
+    if file_size[-1] == 'b':
+        file_size = file_size[:-1]
+    e = file_size[-1]
+    return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
+

 def keep_only_digits(txt):
-    return ''.join(filter(lambda c: c.isdigit(), txt))
+    return ''.join(filter(str.isdigit, txt))


 def secs_to_hours(secs):

@@ -21,7 +43,8 @@ def check_ctcdecoder_version():
         from ds_ctcdecoder import __version__ as decoder_version
     except ImportError as e:
         if e.msg.find('__version__') > 0:
-            print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s))
+            print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. "
+                  "Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s))
             sys.exit(1)
         raise e

@@ -29,7 +52,79 @@ def check_ctcdecoder_version():

     rv = semver.compare(ds_version_s, decoder_version_s)
     if rv != 0:
-        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
+        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
+              "Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
         sys.exit(1)

     return rv
+
+
+class Interleaved:
+    """Collection that lazily combines sorted collections in an interleaving fashion.
+    During iteration the next smallest element from all the sorted collections is always picked.
+    The collections must support iter() and len()."""
+    def __init__(self, *iterables, key=lambda obj: obj):
+        self.iterables = iterables
+        self.key = key
+        self.len = sum(map(len, iterables))
+
+    def __iter__(self):
+        return heapq.merge(*self.iterables, key=self.key)
+
+    def __len__(self):
+        return self.len
+
+
+class LimitingPool:
+    """Limits unbound ahead-processing of multiprocessing.Pool's imap method
+    before items get consumed by the iteration caller.
+    This prevents OOM issues in situations where items represent larger memory allocations."""
+    def __init__(self, processes=None, process_ahead=None, sleeping_for=0.1):
+        self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead
+        self.sleeping_for = sleeping_for
+        self.processed = 0
+        self.pool = ThreadPool(processes=processes)
+
+    def __enter__(self):
+        return self
+
+    def _limit(self, it):
+        for obj in it:
+            while self.processed >= self.process_ahead:
+                time.sleep(self.sleeping_for)
+            self.processed += 1
+            yield obj
+
+    def imap(self, fun, it):
+        for obj in self.pool.imap(fun, self._limit(it)):
+            self.processed -= 1
+            yield obj
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.pool.close()
+
+
+class ExceptionBox:
+    """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
+    Used in conjunction with `remember_exception`."""
+    def __init__(self):
+        self.exception = None
+
+    def raise_if_set(self):
+        if self.exception is not None:
+            exception = self.exception
+            self.exception = None
+            raise exception  # pylint: disable = raising-bad-type
+
+
+def remember_exception(iterable, exception_box=None):
+    """Wraps a TensorFlow dataset generator for catching its actual exceptions
+    that would otherwise just interrupt iteration w/o bubbling up."""
+    def do_iterate():
+        try:
+            yield from iterable()
+        except StopIteration:
+            return
+        except Exception as ex:  # pylint: disable = broad-except
+            exception_box.exception = ex
+    return iterable if exception_box is None else do_iterate
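Note (not part of the diff): concrete values for the new helpers, all derivable from the code above:

    from util.helpers import parse_file_size, Interleaved

    assert parse_file_size('1mb') == 1024 * 1024  # suffixes are case-insensitive; trailing 'b' optional
    assert parse_file_size('512') == 512          # plain numbers pass through unchanged

    merged = Interleaved([1, 3, 5], [2, 4])       # inputs must already be sorted
    assert len(merged) == 5
    assert list(merged) == [1, 2, 3, 4, 5]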
util/sample_collections.py (new file; name per the `util.sample_collections` imports above)

@@ -0,0 +1,262 @@
+# -*- coding: utf-8 -*-
+import os
+import csv
+import json
+
+from pathlib import Path
+from functools import partial
+from util.helpers import MEGABYTE, GIGABYTE, Interleaved
+from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
+
+BIG_ENDIAN = 'big'
+INT_SIZE = 4
+BIGINT_SIZE = 2 * INT_SIZE
+MAGIC = b'SAMPLEDB'
+
+BUFFER_SIZE = 1 * MEGABYTE
+CACHE_SIZE = 1 * GIGABYTE
+
+SCHEMA_KEY = 'schema'
+CONTENT_KEY = 'content'
+MIME_TYPE_KEY = 'mime-type'
+MIME_TYPE_TEXT = 'text/plain'
+CONTENT_TYPE_SPEECH = 'speech'
+CONTENT_TYPE_TRANSCRIPT = 'transcript'
+
+
+class LabeledSample(Sample):
+    """In-memory labeled audio sample representing an utterance.
+    Derived from util.audio.Sample and used by sample collection readers and writers."""
+    def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
+        """
+        Creates an in-memory speech sample together with a transcript of the utterance (label).
+        :param audio_type: See util.audio.Sample.__init__ .
+        :param raw_data: See util.audio.Sample.__init__ .
+        :param transcript: Transcript of the sample's utterance
+        :param audio_format: See util.audio.Sample.__init__ .
+        :param sample_id: Tracking ID - typically assigned by collection readers
+        """
+        super().__init__(audio_type, raw_data, audio_format=audio_format)
+        self.sample_id = sample_id
+        self.transcript = transcript
+
+
+class DirectSDBWriter:
+    """Sample collection writer for creating a Sample DB (SDB) file"""
+    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None):
+        self.sdb_filename = sdb_filename
+        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
+        if audio_type not in SERIALIZABLE_AUDIO_TYPES:
+            raise ValueError('Audio type "{}" not supported'.format(audio_type))
+        self.audio_type = audio_type
+        self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
+        self.offsets = []
+        self.num_samples = 0
+
+        self.sdb_file.write(MAGIC)
+
+        meta_data = {
+            SCHEMA_KEY: [
+                {CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type},
+                {CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
+            ]
+        }
+        meta_data = json.dumps(meta_data).encode()
+        self.write_big_int(len(meta_data))
+        self.sdb_file.write(meta_data)
+
+        self.offset_samples = self.sdb_file.tell()
+        self.sdb_file.seek(2 * BIGINT_SIZE, 1)
+
+    def write_int(self, n):
+        return self.sdb_file.write(n.to_bytes(INT_SIZE, BIG_ENDIAN))
+
+    def write_big_int(self, n):
+        return self.sdb_file.write(n.to_bytes(BIGINT_SIZE, BIG_ENDIAN))
+
+    def __enter__(self):
+        return self
+
+    def add(self, sample):
+        def to_bytes(n):
+            return n.to_bytes(INT_SIZE, BIG_ENDIAN)
+        sample.change_audio_type(self.audio_type)
+        opus = sample.audio.getbuffer()
+        opus_len = to_bytes(len(opus))
+        transcript = sample.transcript.encode()
+        transcript_len = to_bytes(len(transcript))
+        entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
+        buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
+        self.offsets.append(self.sdb_file.tell())
+        self.sdb_file.write(buffer)
+        sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
+        self.num_samples += 1
+        return sample.sample_id
+
+    def close(self):
+        if self.sdb_file is None:
+            return
+        offset_index = self.sdb_file.tell()
+        self.sdb_file.seek(self.offset_samples)
+        self.write_big_int(offset_index - self.offset_samples - BIGINT_SIZE)
+        self.write_big_int(self.num_samples)
+
+        self.sdb_file.seek(offset_index + BIGINT_SIZE)
+        self.write_big_int(self.num_samples)
+        for offset in self.offsets:
+            self.write_big_int(offset)
+        offset_end = self.sdb_file.tell()
+        self.sdb_file.seek(offset_index)
+        self.write_big_int(offset_end - offset_index - BIGINT_SIZE)
+        self.sdb_file.close()
+        self.sdb_file = None
+
+    def __len__(self):
+        return len(self.offsets)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+
+class SDB:  # pylint: disable=too-many-instance-attributes
+    """Sample collection reader for reading a Sample DB (SDB) file"""
+    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None):
+        self.sdb_filename = sdb_filename
+        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
+        self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
+        self.offsets = []
+        if self.sdb_file.read(len(MAGIC)) != MAGIC:
+            raise RuntimeError('No Sample Database')
+        meta_chunk_len = self.read_big_int()
+        self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode())
+        if SCHEMA_KEY not in self.meta:
+            raise RuntimeError('Missing schema')
+        self.schema = self.meta[SCHEMA_KEY]
+
+        speech_columns = self.find_columns(content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES)
+        if not speech_columns:
+            raise RuntimeError('No speech data (missing in schema)')
+        self.speech_index = speech_columns[0]
+        self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
+
+        transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
+        if not transcript_columns:
+            raise RuntimeError('No transcript data (missing in schema)')
+        self.transcript_index = transcript_columns[0]
+
+        sample_chunk_len = self.read_big_int()
+        self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
+        num_samples = self.read_big_int()
+        for _ in range(num_samples):
+            self.offsets.append(self.read_big_int())
+
+    def read_int(self):
+        return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)
+
+    def read_big_int(self):
+        return int.from_bytes(self.sdb_file.read(BIGINT_SIZE), BIG_ENDIAN)
+
+    def find_columns(self, content=None, mime_type=None):
+        criteria = []
+        if content is not None:
+            criteria.append((CONTENT_KEY, content))
+        if mime_type is not None:
+            criteria.append((MIME_TYPE_KEY, mime_type))
+        if len(criteria) == 0:
+            raise ValueError('At least one of "content" or "mime-type" has to be provided')
+        matches = []
+        for index, column in enumerate(self.schema):
+            matched = 0
+            for field, value in criteria:
+                if column[field] == value or (isinstance(value, list) and column[field] in value):
+                    matched += 1
+            if matched == len(criteria):
+                matches.append(index)
+        return matches
+
+    def read_row(self, row_index, *columns):
+        columns = list(columns)
+        column_data = [None] * len(columns)
+        found = 0
+        if not 0 <= row_index < len(self.offsets):
+            raise ValueError('Wrong sample index: {} - has to be between 0 and {}'
+                             .format(row_index, len(self.offsets) - 1))
+        self.sdb_file.seek(self.offsets[row_index] + INT_SIZE)
+        for index in range(len(self.schema)):
+            chunk_len = self.read_int()
+            if index in columns:
+                column_data[columns.index(index)] = self.sdb_file.read(chunk_len)
+                found += 1
+                if found == len(columns):
+                    return tuple(column_data)
+            else:
+                self.sdb_file.seek(chunk_len, 1)
+        return tuple(column_data)
+
+    def __getitem__(self, i):
+        audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
+        transcript = transcript.decode()
+        sample_id = '{}:{}'.format(self.id_prefix, i)
+        return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)
+
+    def __iter__(self):
+        for i in range(len(self.offsets)):
+            yield self[i]
+
+    def __len__(self):
+        return len(self.offsets)
+
+    def close(self):
+        if self.sdb_file is not None:
+            self.sdb_file.close()
+
+    def __del__(self):
+        self.close()
+
+
+class CSV:
+    """Sample collection reader for reading a DeepSpeech CSV file"""
+    def __init__(self, csv_filename):
+        self.csv_filename = csv_filename
+        self.rows = []
+        csv_dir = Path(csv_filename).parent
+        with open(csv_filename, 'r', encoding='utf8') as csv_file:
+            reader = csv.DictReader(csv_file)
+            for row in reader:
+                wav_filename = Path(row['wav_filename'])
+                if not wav_filename.is_absolute():
+                    wav_filename = csv_dir / wav_filename
+                self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript']))
+        self.rows.sort(key=lambda r: r[1])
+
+    def __getitem__(self, i):
+        wav_filename, _, transcript = self.rows[i]
+        with open(wav_filename, 'rb') as wav_file:
+            return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename)
+
+    def __iter__(self):
+        for i in range(len(self.rows)):
+            yield self[i]
+
+    def __len__(self):
+        return len(self.rows)
+
+
+def samples_from_file(filename, buffering=BUFFER_SIZE):
+    """Returns an iterable of LabeledSample objects loaded from a file."""
+    ext = os.path.splitext(filename)[1].lower()
+    if ext == '.sdb':
+        return SDB(filename, buffering=buffering)
+    if ext == '.csv':
+        return CSV(filename)
+    raise ValueError('Unknown file type: "{}"'.format(ext))
+
+
+def samples_from_files(filenames, buffering=BUFFER_SIZE):
+    """Returns an iterable of LabeledSample objects from a list of files."""
+    if len(filenames) == 0:
+        raise ValueError('No files')
+    if len(filenames) == 1:
+        return samples_from_file(filenames[0], buffering=buffering)
+    cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
+    return Interleaved(*cols, key=lambda s: s.duration)
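Note (not part of the diff): piecing together DirectSDBWriter.close() and SDB.__init__, the on-disk layout of an SDB file works out to the following (big-endian throughout; a reading, not a normative spec):

    # magic                 8 bytes   b'SAMPLEDB'
    # meta length           8 bytes   followed by JSON meta (schema: speech + transcript columns)
    # sample-chunk length   8 bytes   covers the following count field plus all entries
    # sample count          8 bytes
    # per sample:           entry length (4) | audio length (4) | audio bytes
    #                       | transcript length (4) | transcript bytes
    # offset-chunk length   8 bytes   covers the following count field plus all offsets
    # sample count          8 bytes
    # per sample:           absolute file offset (8 bytes)

The trailing offset table is what gives SDB its cheap random access, and sorting entries shortest-to-longest is what lets Interleaved merge multiple collections by duration.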
util/text.py (14 lines changed)
@@ -4,7 +4,6 @@ import numpy as np
 import re
 import struct

-from util.flags import FLAGS
 from six.moves import range

 class Alphabet(object):
@@ -120,19 +119,22 @@ class UTF8Alphabet(object):
         return ''


-def text_to_char_array(series, alphabet):
+def text_to_char_array(transcript, alphabet, context=''):
     r"""
-    Given a Pandas Series containing transcript string, map characters to
+    Given a transcript string, map characters to
     integers and return a numpy array representing the processed string.
+    Use a string in `context` for adding text to raised exceptions.
     """
     try:
-        transcript = np.asarray(alphabet.encode(series['transcript']))
+        transcript = alphabet.encode(transcript)
         if len(transcript) == 0:
-            raise ValueError('While processing: {}\nFound an empty transcript! You must include a transcript for all training data.'.format(series['wav_filename']))
+            raise ValueError('While processing {}: Found an empty transcript! '
+                             'You must include a transcript for all training data.'
+                             .format(context))
         return transcript
     except KeyError as e:
         # Provide the row context (especially wav_filename) for alphabet errors
-        raise ValueError('While processing: {}\n{}'.format(series['wav_filename'], e))
+        raise ValueError('While processing: {}\n{}'.format(context, e))


 # The following code is from: http://hetland.org/coding/python/levenshtein.py
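Note (not part of the diff): the new call convention in a sketch, assuming the existing Alphabet(config_file) constructor and a hypothetical alphabet path:

    from util.text import Alphabet, text_to_char_array

    alphabet = Alphabet('data/alphabet.txt')
    encoded = text_to_char_array('hello', alphabet, context='train.sdb:0')
    # on empty transcripts or unknown characters, the raised ValueError now
    # carries the sample id passed via `context` instead of a wav_filename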