SDB support

Tilman Kamp 2020-03-05 14:09:51 +01:00
parent 3bd0b20bf7
commit 6b1d6773de
19 changed files with 916 additions and 91 deletions

DeepSpeech.py

@@ -33,7 +33,7 @@ from util.config import Config, initialize_globals
 from util.checkpoints import load_or_init_graph
 from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
-from util.helpers import check_ctcdecoder_version
+from util.helpers import check_ctcdecoder_version, ExceptionBox
 from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar

 check_ctcdecoder_version()
@@ -418,12 +418,17 @@ def train():
             FLAGS.augmentation_sparse_warp):
         do_cache_dataset = False

+    exception_box = ExceptionBox()
+
     # Create training and validation datasets
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
                                enable_cache=FLAGS.feature_cache and do_cache_dataset,
                                cache_path=FLAGS.feature_cache,
-                               train_phase=True)
+                               train_phase=True,
+                               exception_box=exception_box,
+                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
+                               buffering=FLAGS.read_buffer)

     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
@@ -433,8 +438,13 @@ def train():
     train_init_op = iterator.make_initializer(train_set)

     if FLAGS.dev_files:
-        dev_csvs = FLAGS.dev_files.split(',')
-        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
+        dev_sources = FLAGS.dev_files.split(',')
+        dev_sets = [create_dataset([source],
+                                   batch_size=FLAGS.dev_batch_size,
+                                   train_phase=False,
+                                   exception_box=exception_box,
+                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                   buffering=FLAGS.read_buffer) for source in dev_sources]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

     # Dropout
@@ -540,6 +550,7 @@ def train():
                 _, current_step, batch_loss, problem_files, step_summary = \
                     session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                 feed_dict=feed_dict)
+                exception_box.raise_if_set()
             except tf.errors.InvalidArgumentError as err:
                 if FLAGS.augmentation_sparse_warp:
                     log_info("Ignoring sparse warp error: {}".format(err))
@@ -547,6 +558,7 @@ def train():
                 else:
                     raise
             except tf.errors.OutOfRangeError:
+                exception_box.raise_if_set()
                 break

             if problem_files.size > 0:
@@ -586,12 +598,12 @@ def train():
                 # Validation
                 dev_loss = 0.0
                 total_steps = 0
-                for csv, init_op in zip(dev_csvs, dev_init_ops):
-                    log_progress('Validating epoch %d on %s...' % (epoch, csv))
-                    set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
+                for source, init_op in zip(dev_sources, dev_init_ops):
+                    log_progress('Validating epoch %d on %s...' % (epoch, source))
+                    set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                     dev_loss += set_loss * steps
                     total_steps += steps
-                    log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
+                    log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                 dev_loss = dev_loss / total_steps
                 dev_losses.append(dev_loss)
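
The exception plumbing added here spans three files (DeepSpeech.py, util/feeding.py, util/helpers.py). As a reading aid, here is a minimal sketch, not part of the commit, of the round trip, with a plain loop standing in for tf.data's internal iteration:

# Sketch of the ExceptionBox round trip (illustration only).
from util.helpers import ExceptionBox, remember_exception

exception_box = ExceptionBox()

def generate_values():
    yield 'sample-0'
    raise ValueError('corrupt sample')  # inside tf.data this would just end iteration

for value in remember_exception(generate_values, exception_box)():
    pass                      # stands in for session.run() consuming the pipeline
exception_box.raise_if_set()  # re-raises the ValueError on the training-loop side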

bin/build_sdb.py (new executable file, 51 lines)

@@ -0,0 +1,51 @@
#!/usr/bin/env python
'''
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help
'''
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import argparse
import progressbar

from util.downloader import SIMPLE_BAR
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
from util.sample_collections import samples_from_files, DirectSDBWriter

AUDIO_TYPE_LOOKUP = {
    'wav': AUDIO_TYPE_WAV,
    'opus': AUDIO_TYPE_OPUS
}


def build_sdb():
    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
    with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer:
        samples = samples_from_files(CLI_ARGS.sources)
        bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
        for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
            sdb_writer.add(sample)


def handle_args():
    parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
                                                 'from DeepSpeech CSV files and other SDB files')
    parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - '
                                                   'Note: For getting a correctly ordered target SDB, source SDBs have '
                                                   'to have their samples already ordered from shortest to longest.')
    parser.add_argument('target', help='SDB file to create')
    parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
                        help='Audio representation inside target SDB')
    parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers')
    return parser.parse_args()


if __name__ == "__main__":
    CLI_ARGS = handle_args()
    build_sdb()
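
The same conversion can be driven programmatically; a minimal sketch (not part of the commit, file names made up) of what build_sdb.py automates:

from util.audio import change_audio_types, AUDIO_TYPE_OPUS
from util.sample_collections import samples_from_files, DirectSDBWriter

with DirectSDBWriter('target.sdb', audio_type=AUDIO_TYPE_OPUS) as writer:
    samples = samples_from_files(['source.csv'])  # CSV rows come back sorted by file size
    for sample in change_audio_types(samples, audio_type=AUDIO_TYPE_OPUS, processes=4):
        writer.add(sample)  # re-checks the audio type and assigns the sample its SDB ID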

bin/play.py (new executable file, 73 lines)

@@ -0,0 +1,73 @@
#!/usr/bin/env python
"""
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
Use "python3 play.py -h" for help
"""
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import random
import argparse

from util.sample_collections import samples_from_file
from util.audio import AUDIO_TYPE_PCM


def play_sample(samples, index):
    if index < 0:
        index = len(samples) + index
    if CLI_ARGS.random:
        index = random.randint(0, len(samples) - 1)  # randint's upper bound is inclusive
    elif index >= len(samples):
        print('No sample with index {}'.format(index))
        sys.exit(1)
    sample = samples[index]
    print('Sample "{}"'.format(sample.sample_id))
    print('  "{}"'.format(sample.transcript))
    sample.change_audio_type(AUDIO_TYPE_PCM)
    rate, channels, width = sample.audio_format
    wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate)
    play_obj = wave_obj.play()
    play_obj.wait_done()


def play_collection():
    samples = samples_from_file(CLI_ARGS.collection, buffering=0)
    played = 0
    index = CLI_ARGS.start
    while True:
        if 0 <= CLI_ARGS.number <= played:
            return
        play_sample(samples, index)
        played += 1
        index = (index + 1) % len(samples)


def handle_args():
    parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
                                                 'and DeepSpeech CSV files')
    parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
    parser.add_argument('--start', type=int, default=0,
                        help='Sample index to start at (negative numbers are relative to the end of the collection)')
    parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
    parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
    return parser.parse_args()


if __name__ == "__main__":
    try:
        import simpleaudio
    except ModuleNotFoundError:
        print('play.py requires Python package "simpleaudio"')
        sys.exit(1)
    CLI_ARGS = handle_args()
    try:
        play_collection()
    except KeyboardInterrupt:
        print(' Stopped')
        sys.exit(0)

bin/run-tc-ldc93s1_checkpoint_sdb.sh (new executable file)

@@ -0,0 +1,37 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_sdb} --train_batch_size 1 \
--dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
--test_files ${ldc93s1_sdb} --test_batch_size 1 \
--n_hidden 100 --epochs 1 \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \
--learning_rate 0.001 --dropout_rate 0.05 \
--scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log
if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then
echo "Did not resume training from checkpoint"
exit 1
else
exit 0
fi

bin/run-tc-ldc93s1_new_sdb.sh (new executable file, 34 lines)

@@ -0,0 +1,34 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
epoch_count=$1
audio_sample_rate=$2
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_sdb} --train_batch_size 1 \
--dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
--test_files ${ldc93s1_sdb} --test_batch_size 1 \
--n_hidden 100 --epochs $epoch_count \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \
--learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb' \
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
--audio_sample_rate ${audio_sample_rate}

bin/run-tc-ldc93s1_new_sdb_csv.sh (new executable file)

@@ -0,0 +1,35 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
epoch_count=$1
audio_sample_rate=$2
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_sdb},${ldc93s1_csv} --train_batch_size 1 \
--feature_cache '/tmp/ldc93s1_cache_sdb_csv' \
--dev_files ${ldc93s1_sdb},${ldc93s1_csv} --dev_batch_size 1 \
--test_files ${ldc93s1_sdb},${ldc93s1_csv} --test_batch_size 1 \
--n_hidden 100 --epochs $epoch_count \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb_csv' \
--learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb_csv' \
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
--audio_sample_rate ${audio_sample_rate}

requirements.txt

@@ -2,12 +2,12 @@
 tensorflow == 1.15.0
 numpy == 1.18.1
 progressbar2
-pandas
 six
 pyxdg
 attrdict
 absl-py
 semver
+opuslib == 2.0.0

 # Requirements for building native_client files
 setuptools
@@ -15,6 +15,7 @@ setuptools
 # Requirements for importers
 sox
 bs4
+pandas
 requests
 librosa
 soundfile

taskcluster/.shared.yml

@@ -5,6 +5,9 @@ python:
     apt: 'python3-virtualenv python3-setuptools python3-pip python3-wheel python3-pkg-resources'
   packages_docs_bionic:
     apt: 'python3 python3-pip zip doxygen'
+  training:
+    packages_trusty:
+      apt: 'libopus0'
 tensorflow:
   packages_trusty:
     apt: 'make build-essential gfortran git libblas-dev liblapack-dev libsox-dev libmagic-dev libgsm1-dev libltdl-dev libpng-dev python zlib1g-dev'

taskcluster/tc-train-tests.sh

@@ -48,6 +48,11 @@ pushd ${HOME}/DeepSpeech/ds/
 time ./bin/run-tc-ldc93s1_new.sh 249 "${sample_rate}"
 time ./bin/run-tc-ldc93s1_new.sh 1 "${sample_rate}"
 time ./bin/run-tc-ldc93s1_tflite.sh "${sample_rate}"
+# Testing single SDB source
+time ./bin/run-tc-ldc93s1_new_sdb.sh 220 "${sample_rate}"
+# Testing interleaved source (SDB+CSV combination) - run twice to test preprocessed features
+time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 "${sample_rate}"
+time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 1 "${sample_rate}"
 popd

 cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
@@ -62,6 +67,7 @@ cp /tmp/train/output_graph.pbmm ${TASKCLUSTER_ARTIFACTS}

 pushd ${HOME}/DeepSpeech/ds/
 time ./bin/run-tc-ldc93s1_checkpoint.sh
+time ./bin/run-tc-ldc93s1_checkpoint_sdb.sh
 popd

 virtualenv_deactivate "${pyalias}" "deepspeech"

taskcluster training test definition (Python 3.5.8, 16 kHz)

@@ -2,6 +2,9 @@ build:
   template_file: test-linux-opt-base.tyml
   dependencies:
     - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
   args:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.5.8:m 16k"
   metadata:

taskcluster training test definition (Python 3.6.10, 16 kHz)

@@ -2,6 +2,9 @@ build:
   template_file: test-linux-opt-base.tyml
   dependencies:
     - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
   args:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 16k"
   metadata:

taskcluster training test definition (Python 3.6.10, 8 kHz)

@@ -2,6 +2,9 @@ build:
   template_file: test-linux-opt-base.tyml
   dependencies:
     - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt}
   args:
     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 8k"
   metadata:

util/audio.py

@@ -1,34 +1,127 @@
 import os
-import sox
+import io
 import wave
 import tempfile
 import collections
+import numpy as np
+
+from util.helpers import LimitingPool

 DEFAULT_RATE = 16000
 DEFAULT_CHANNELS = 1
 DEFAULT_WIDTH = 2
 DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)

+AUDIO_TYPE_NP = 'application/vnd.mozilla.np'
+AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm'
+AUDIO_TYPE_WAV = 'audio/wav'
+AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus'
+SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS]
+
+OPUS_PCM_LEN_SIZE = 4
+OPUS_RATE_SIZE = 4
+OPUS_CHANNELS_SIZE = 1
+OPUS_WIDTH_SIZE = 1
+OPUS_CHUNK_LEN_SIZE = 2
+
+
+class Sample:
+    """Represents in-memory audio data of a certain (convertible) representation.
+    Attributes:
+        audio_type (str): See `__init__`.
+        audio_format (tuple:(int, int, int)): See `__init__`.
+        audio (obj): Audio data represented as indicated by `audio_type`
+        duration (float): Audio duration of the sample in seconds
+    """
+    def __init__(self, audio_type, raw_data, audio_format=None):
+        """
+        Creates a Sample from a raw audio representation.
+        :param audio_type: Audio data representation type
+            Supported types:
+            - AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio
+              wrapped by a custom container format (used in SDBs)
+            - AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file
+            - AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header)
+            - AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding
+        :param raw_data: Audio data in the form of the provided representation type (see audio_type).
+            For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray.
+        :param audio_format: Tuple of sample-rate, number of channels and sample-width.
+            Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP,
+            as this information cannot be derived from raw audio data.
+        """
+        self.audio_type = audio_type
+        self.audio_format = audio_format
+        if audio_type in SERIALIZABLE_AUDIO_TYPES:
+            self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
+            self.duration = read_duration(audio_type, self.audio)
+        else:
+            self.audio = raw_data
+            if self.audio_format is None:
+                raise ValueError('For audio type "{}" parameter "audio_format" is mandatory'.format(self.audio_type))
+            if audio_type == AUDIO_TYPE_PCM:
+                self.duration = get_pcm_duration(len(self.audio), self.audio_format)
+            elif audio_type == AUDIO_TYPE_NP:
+                self.duration = get_np_duration(len(self.audio), self.audio_format)
+            else:
+                raise ValueError('Unsupported audio type: {}'.format(self.audio_type))
+
+    def change_audio_type(self, new_audio_type):
+        """
+        In-place conversion of audio data into a different representation.
+        :param new_audio_type: New audio-type - see `__init__`.
+            Not supported: Converting from AUDIO_TYPE_NP into any other type.
+        """
+        if self.audio_type == new_audio_type:
+            return
+        if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in SERIALIZABLE_AUDIO_TYPES:
+            self.audio_format, audio = read_audio(self.audio_type, self.audio)
+            self.audio.close()
+            self.audio = audio
+        elif new_audio_type == AUDIO_TYPE_NP:
+            self.change_audio_type(AUDIO_TYPE_PCM)
+            self.audio = pcm_to_np(self.audio_format, self.audio)
+        elif new_audio_type in SERIALIZABLE_AUDIO_TYPES:
+            self.change_audio_type(AUDIO_TYPE_PCM)
+            audio_bytes = io.BytesIO()
+            write_audio(new_audio_type, audio_bytes, self.audio_format, self.audio)
+            audio_bytes.seek(0)
+            self.audio = audio_bytes
+        else:
+            raise RuntimeError('Changing audio representation type from "{}" to "{}" not supported'
+                               .format(self.audio_type, new_audio_type))
+        self.audio_type = new_audio_type
+
+
+def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, processes=None, process_ahead=None):
+    def change_audio_type(sample):
+        sample.change_audio_type(audio_type)
+        return sample
+    with LimitingPool(processes=processes, process_ahead=process_ahead) as pool:
+        yield from pool.imap(change_audio_type, samples)
+
+
-def get_audio_format(wav_file):
+def read_audio_format_from_wav_file(wav_file):
     return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()


-def get_num_samples(audio_data, audio_format=DEFAULT_FORMAT):
+def get_num_samples(pcm_buffer_size, audio_format=DEFAULT_FORMAT):
     _, channels, width = audio_format
-    return len(audio_data) // (channels * width)
+    return pcm_buffer_size // (channels * width)


-def get_duration(audio_data, audio_format=DEFAULT_FORMAT):
-    return get_num_samples(audio_data, audio_format) / audio_format[0]
+def get_pcm_duration(pcm_buffer_size, audio_format=DEFAULT_FORMAT):
+    """Calculates duration in seconds of a binary PCM buffer (typically read from a WAV file)"""
+    return get_num_samples(pcm_buffer_size, audio_format) / audio_format[0]


-def get_duration_ms(audio_data, audio_format=DEFAULT_FORMAT):
-    return get_duration(audio_data, audio_format) * 1000
+def get_np_duration(np_len, audio_format=DEFAULT_FORMAT):
+    """Calculates duration in seconds of NumPy audio data"""
+    return np_len / audio_format[0]


 def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
     sample_rate, channels, width = audio_format
+    import sox
     transformer = sox.Transformer()
     transformer.set_output_format(file_type=file_type, rate=sample_rate, channels=channels, bits=width*8)
     transformer.build(src_audio_path, dst_audio_path)
@@ -45,7 +138,7 @@ class AudioFile:
     def __enter__(self):
         if self.audio_path.endswith('.wav'):
             self.open_file = wave.open(self.audio_path, 'r')
-            if get_audio_format(self.open_file) == self.audio_format:
+            if read_audio_format_from_wav_file(self.open_file) == self.audio_format:
                 if self.as_path:
                     self.open_file.close()
                     return self.audio_path
@@ -66,12 +159,12 @@ class AudioFile:


 def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
-    audio_format = get_audio_format(wav_file)
+    audio_format = read_audio_format_from_wav_file(wav_file)
     frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0))
     while True:
         try:
             data = wav_file.readframes(frame_size)
-            if not yield_remainder and get_duration_ms(data, audio_format) < frame_duration_ms:
+            if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms:
                 break
             yield data
         except EOFError:
@@ -106,7 +199,7 @@ def vad_split(audio_frames,
     frame_duration_ms = 0
     frame_index = 0
     for frame_index, frame in enumerate(audio_frames):
-        frame_duration_ms = get_duration_ms(frame, audio_format)
+        frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000
         if int(frame_duration_ms) not in [10, 20, 30]:
             raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms')
         is_speech = vad.is_speech(frame, sample_rate)
@@ -133,3 +226,123 @@ def vad_split(audio_frames,
     yield b''.join(voiced_frames), \
           frame_duration_ms * (frame_index - len(voiced_frames)), \
           frame_duration_ms * (frame_index + 1)
+
+
+def pack_number(n, num_bytes):
+    return n.to_bytes(num_bytes, 'big', signed=False)
+
+
+def unpack_number(data):
+    return int.from_bytes(data, 'big', signed=False)
+
+
+def get_opus_frame_size(rate):
+    return 60 * rate // 1000
+
+
+def write_opus(opus_file, audio_format, audio_data):
+    rate, channels, width = audio_format
+    frame_size = get_opus_frame_size(rate)
+    import opuslib  # pylint: disable=import-outside-toplevel
+    encoder = opuslib.Encoder(rate, channels, 'audio')
+    chunk_size = frame_size * channels * width
+    opus_file.write(pack_number(len(audio_data), OPUS_PCM_LEN_SIZE))
+    opus_file.write(pack_number(rate, OPUS_RATE_SIZE))
+    opus_file.write(pack_number(channels, OPUS_CHANNELS_SIZE))
+    opus_file.write(pack_number(width, OPUS_WIDTH_SIZE))
+    for i in range(0, len(audio_data), chunk_size):
+        chunk = audio_data[i:i + chunk_size]
+        # Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer
+        if len(chunk) < chunk_size:
+            chunk = chunk + bytearray(chunk_size - len(chunk))
+        encoded = encoder.encode(chunk, frame_size)
+        opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE))
+        opus_file.write(encoded)
+
+
+def read_opus_header(opus_file):
+    opus_file.seek(0)
+    pcm_buffer_size = unpack_number(opus_file.read(OPUS_PCM_LEN_SIZE))
+    rate = unpack_number(opus_file.read(OPUS_RATE_SIZE))
+    channels = unpack_number(opus_file.read(OPUS_CHANNELS_SIZE))
+    width = unpack_number(opus_file.read(OPUS_WIDTH_SIZE))
+    return pcm_buffer_size, (rate, channels, width)
+
+
+def read_opus(opus_file):
+    pcm_buffer_size, audio_format = read_opus_header(opus_file)
+    rate, channels, _ = audio_format
+    frame_size = get_opus_frame_size(rate)
+    import opuslib  # pylint: disable=import-outside-toplevel
+    decoder = opuslib.Decoder(rate, channels)
+    audio_data = bytearray()
+    while len(audio_data) < pcm_buffer_size:
+        chunk_len = unpack_number(opus_file.read(OPUS_CHUNK_LEN_SIZE))
+        chunk = opus_file.read(chunk_len)
+        decoded = decoder.decode(chunk, frame_size)
+        audio_data.extend(decoded)
+    audio_data = audio_data[:pcm_buffer_size]
+    return audio_format, audio_data
+
+
+def write_wav(wav_file, audio_format, pcm_data):
+    with wave.open(wav_file, 'wb') as wav_file_writer:
+        rate, channels, width = audio_format
+        wav_file_writer.setframerate(rate)
+        wav_file_writer.setnchannels(channels)
+        wav_file_writer.setsampwidth(width)
+        wav_file_writer.writeframes(pcm_data)
+
+
+def read_wav(wav_file):
+    wav_file.seek(0)
+    with wave.open(wav_file, 'rb') as wav_file_reader:
+        audio_format = read_audio_format_from_wav_file(wav_file_reader)
+        pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes())
+        return audio_format, pcm_data
+
+
+def read_audio(audio_type, audio_file):
+    if audio_type == AUDIO_TYPE_WAV:
+        return read_wav(audio_file)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return read_opus(audio_file)
+    raise ValueError('Unsupported audio type: {}'.format(audio_type))
+
+
+def write_audio(audio_type, audio_file, audio_format, pcm_data):
+    if audio_type == AUDIO_TYPE_WAV:
+        return write_wav(audio_file, audio_format, pcm_data)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return write_opus(audio_file, audio_format, pcm_data)
+    raise ValueError('Unsupported audio type: {}'.format(audio_type))
+
+
+def read_wav_duration(wav_file):
+    wav_file.seek(0)
+    with wave.open(wav_file, 'rb') as wav_file_reader:
+        return wav_file_reader.getnframes() / wav_file_reader.getframerate()
+
+
+def read_opus_duration(opus_file):
+    pcm_buffer_size, audio_format = read_opus_header(opus_file)
+    return get_pcm_duration(pcm_buffer_size, audio_format)
+
+
+def read_duration(audio_type, audio_file):
+    if audio_type == AUDIO_TYPE_WAV:
+        return read_wav_duration(audio_file)
+    if audio_type == AUDIO_TYPE_OPUS:
+        return read_opus_duration(audio_file)
+    raise ValueError('Unsupported audio type: {}'.format(audio_type))
+
+
+def pcm_to_np(audio_format, pcm_data):
+    _, channels, width = audio_format
+    if width not in [1, 2, 4]:
+        raise ValueError('Unsupported sample width: {}'.format(width))
+    dtype = [None, np.int8, np.int16, None, np.int32][width]
+    samples = np.frombuffer(pcm_data, dtype=dtype)
+    assert channels == 1  # only mono supported for now
+    samples = samples.astype(np.float32) / np.iinfo(dtype).max
+    return np.expand_dims(samples, axis=1)
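
The Opus container written by write_opus is a custom framing, not an Ogg/Opus file: a big-endian header of [pcm_len: 4 bytes][rate: 4 bytes][channels: 1 byte][width: 1 byte], followed by repeated [chunk_len: 2 bytes][packet] chunks. A short round-trip sketch (not part of the commit; requires the opuslib package) exercising just the header:

import io
from util.audio import write_opus, read_opus_header, DEFAULT_FORMAT

buf = io.BytesIO()
write_opus(buf, DEFAULT_FORMAT, bytearray(2 * 16000))  # one second of 16-bit mono silence
pcm_len, audio_format = read_opus_header(buf)          # read_opus_header seeks back to 0
assert pcm_len == 2 * 16000
assert audio_format == DEFAULT_FORMAT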

util/config.py

@@ -12,6 +12,7 @@ from util.flags import FLAGS
 from util.gpu import get_available_gpus
 from util.logging import log_error
 from util.text import Alphabet, UTF8Alphabet
+from util.helpers import parse_file_size


 class ConfigSingleton:
     _config = None
@@ -29,6 +30,9 @@ Config = ConfigSingleton()  # pylint: disable=invalid-name
 def initialize_globals():
     c = AttrDict()

+    # Read-buffer
+    FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer)
+
     # Set default dropout rates
     if FLAGS.dropout_rate2 < 0:
         FLAGS.dropout_rate2 = FLAGS.dropout_rate

util/feeding.py

@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, division, print_function

-import os
-
 from functools import partial

 import numpy as np
-import pandas
 import tensorflow as tf

 from tensorflow.python.ops import gen_audio_ops as contrib_audio
@@ -15,27 +12,18 @@ from util.config import Config
 from util.text import text_to_char_array
 from util.flags import FLAGS
 from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
-from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT
+from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
+from util.sample_collections import samples_from_files
+from util.helpers import remember_exception, MEGABYTE


-def read_csvs(csv_files):
-    sets = []
-    for csv in csv_files:
-        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
-        #FIXME: not cross-platform
-        csv_dir = os.path.dirname(os.path.abspath(csv))
-        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))  # pylint: disable=cell-var-from-loop
-        sets.append(file)
-    # Concat all sets, drop any extra columns, re-index the final result as 0..N
-    return pandas.concat(sets, join='inner', ignore_index=True)
-
-
-def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None):
+def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
     if train_phase:
         # We need the lambdas to make TensorFlow happy.
         # pylint: disable=unnecessary-lambda
         tf.cond(tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate),
-                lambda: tf.print('WARNING: sample rate of file', wav_filename, '(', sample_rate, ') does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'),
+                lambda: tf.print('WARNING: sample rate of sample', sample_id, '(', sample_rate, ') '
                                  'does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'),
                 lambda: tf.no_op(),
                 name='matching_sample_rate')
@@ -84,10 +72,8 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None)
     return mfccs, tf.shape(input=mfccs)[0]


-def audiofile_to_features(wav_filename, train_phase=False):
-    samples = tf.io.read_file(wav_filename)
-    decoded = contrib_audio.decode_wav(samples, desired_channels=1)
-    features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase, wav_filename=wav_filename)
+def audio_to_features(audio, sample_rate, train_phase=False, sample_id=None):
+    features, features_len = samples_to_mfccs(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)

     if train_phase:
         if FLAGS.data_aug_features_multiplicative > 0:
@@ -99,10 +85,17 @@ def audiofile_to_features(wav_filename, train_phase=False):
     return features, features_len


-def entry_to_features(wav_filename, transcript, train_phase):
+def audiofile_to_features(wav_filename, train_phase=False):
+    samples = tf.io.read_file(wav_filename)
+    decoded = contrib_audio.decode_wav(samples, desired_channels=1)
+    return audio_to_features(decoded.audio, decoded.sample_rate, train_phase=train_phase, sample_id=wav_filename)
+
+
+def entry_to_features(sample_id, audio, sample_rate, transcript, train_phase=False):
     # https://bugs.python.org/issue32117
-    features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase)
-    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
+    features, features_len = audio_to_features(audio, sample_rate, train_phase=train_phase, sample_id=sample_id)
+    sparse_transcript = tf.SparseTensor(*transcript)
+    return sample_id, features, features_len, sparse_transcript


 def to_sparse_tuple(sequence):
@@ -114,15 +107,22 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape


-def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False):
-    df = read_csvs(csvs)
-    df.sort_values(by='wav_filesize', inplace=True)
-
-    df['transcript'] = df.apply(text_to_char_array, alphabet=Config.alphabet, result_type='reduce', axis=1)
-
+def create_dataset(sources,
+                   batch_size,
+                   enable_cache=False,
+                   cache_path=None,
+                   train_phase=False,
+                   exception_box=None,
+                   process_ahead=None,
+                   buffering=1 * MEGABYTE):
     def generate_values():
-        for _, row in df.iterrows():
-            yield row.wav_filename, to_sparse_tuple(row.transcript)
+        samples = samples_from_files(sources, buffering=buffering)
+        for sample in change_audio_types(samples,
+                                         AUDIO_TYPE_NP,
+                                         process_ahead=2 * batch_size if process_ahead is None else process_ahead):
+            transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id)
+            transcript = to_sparse_tuple(transcript)
+            yield sample.sample_id, sample.audio, sample.audio_format[0], transcript

     # Batching a dataset of 2D SparseTensors creates 3D batches, which fail
     # when passed to tf.nn.ctc_loss, so we reshape them to remove the extra
@@ -131,27 +131,23 @@ def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_
         shape = sparse.dense_shape
         return tf.sparse.reshape(sparse, [shape[0], shape[2]])

-    def batch_fn(wav_filenames, features, features_len, transcripts):
+    def batch_fn(sample_ids, features, features_len, transcripts):
         features = tf.data.Dataset.zip((features, features_len))
-        features = features.padded_batch(batch_size,
-                                         padded_shapes=([None, Config.n_input], []))
+        features = features.padded_batch(batch_size, padded_shapes=([None, Config.n_input], []))
         transcripts = transcripts.batch(batch_size).map(sparse_reshape)
-        wav_filenames = wav_filenames.batch(batch_size)
-        return tf.data.Dataset.zip((wav_filenames, features, transcripts))
+        sample_ids = sample_ids.batch(batch_size)
+        return tf.data.Dataset.zip((sample_ids, features, transcripts))

-    num_gpus = len(Config.available_devices)
-
     process_fn = partial(entry_to_features, train_phase=train_phase)

-    dataset = (tf.data.Dataset.from_generator(generate_values,
-                                              output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
+    dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
+                                              output_types=(tf.string, tf.float32, tf.int32,
+                                                            (tf.int64, tf.int32, tf.int64)))
                      .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE))

     if enable_cache:
         dataset = dataset.cache(cache_path)

     dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn)
-               .prefetch(num_gpus))
+               .prefetch(len(Config.available_devices)))

     return dataset
@@ -160,27 +156,24 @@ def split_audio_file(audio_path,
                     batch_size=1,
                     aggressiveness=3,
                     outlier_duration_ms=10000,
-                    outlier_batch_size=1):
-    sample_rate, _, sample_width = audio_format
-    multiplier = 1.0 / (1 << (8 * sample_width - 1))
-
+                    outlier_batch_size=1,
+                    exception_box=None):
     def generate_values():
         frames = read_frames_from_file(audio_path)
         segments = vad_split(frames, aggressiveness=aggressiveness)
         for segment in segments:
             segment_buffer, time_start, time_end = segment
-            samples = np.frombuffer(segment_buffer, dtype=np.int16)
-            samples = samples * multiplier
-            samples = np.expand_dims(samples, axis=1)
+            samples = pcm_to_np(audio_format, segment_buffer)
             yield time_start, time_end, samples

     def to_mfccs(time_start, time_end, samples):
-        features, features_len = samples_to_mfccs(samples, sample_rate)
+        features, features_len = samples_to_mfccs(samples, audio_format[0])
         return time_start, time_end, features, features_len

     def create_batch_set(bs, criteria):
         return (tf.data.Dataset
-                .from_generator(generate_values, output_types=(tf.int32, tf.int32, tf.float32))
+                .from_generator(remember_exception(generate_values, exception_box),
+                                output_types=(tf.int32, tf.int32, tf.float32))
                 .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .filter(criteria)
                 .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], [])))
@@ -192,9 +185,3 @@ def split_audio_file(audio_path,
     dataset = nds.concatenate(ods)
     dataset = dataset.prefetch(len(Config.available_devices))
     return dataset
-
-
-def secs_to_hours(secs):
-    hours, remainder = divmod(secs, 3600)
-    minutes, seconds = divmod(remainder, 60)
-    return '%d:%02d:%02d' % (hours, minutes, seconds)
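
A hypothetical call (source names made up; assumes flags and Config have been initialized as in DeepSpeech.py) showing how the reworked create_dataset is fed by mixed sources:

from util.feeding import create_dataset
from util.helpers import ExceptionBox

exception_box = ExceptionBox()
train_set = create_dataset(['train.sdb', 'extra.csv'],  # interleaved by sample duration
                           batch_size=16,
                           train_phase=True,
                           exception_box=exception_box,  # surfaces generator errors to the caller
                           process_ahead=32,             # bounds decode-ahead memory use
                           buffering=1024 * 1024)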

util/flags.py

@@ -15,6 +15,7 @@ def create_flags():
     f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
     f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
+    f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)')
     f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.')
     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')

util/helpers.py

@@ -1,10 +1,32 @@
 import os
-import semver
 import sys
+import time
+import heapq
+import semver
+
+from multiprocessing.dummy import Pool as ThreadPool
+
+KILO = 1024
+KILOBYTE = 1 * KILO
+MEGABYTE = KILO * KILOBYTE
+GIGABYTE = KILO * MEGABYTE
+TERABYTE = KILO * GIGABYTE
+SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE}
+
+
+def parse_file_size(file_size):
+    file_size = file_size.lower().strip()
+    if len(file_size) == 0:
+        return 0
+    n = int(keep_only_digits(file_size))
+    if file_size[-1] == 'b':
+        file_size = file_size[:-1]
+    e = file_size[-1]
+    return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n


 def keep_only_digits(txt):
-    return ''.join(filter(lambda c: c.isdigit(), txt))
+    return ''.join(filter(str.isdigit, txt))


 def secs_to_hours(secs):
@@ -21,7 +43,8 @@ def check_ctcdecoder_version():
         from ds_ctcdecoder import __version__ as decoder_version
     except ImportError as e:
         if e.msg.find('__version__') > 0:
-            print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s))
+            print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. "
+                  "Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s))
             sys.exit(1)
         raise e
@@ -29,7 +52,79 @@ def check_ctcdecoder_version():
     rv = semver.compare(ds_version_s, decoder_version_s)
     if rv != 0:
-        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
+        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
+              "Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
         sys.exit(1)
     return rv
+
+
+class Interleaved:
+    """Collection that lazily combines sorted collections in an interleaving fashion.
+    During iteration the next smallest element from all the sorted collections is always picked.
+    The collections must support iter() and len()."""
+    def __init__(self, *iterables, key=lambda obj: obj):
+        self.iterables = iterables
+        self.key = key
+        self.len = sum(map(len, iterables))
+
+    def __iter__(self):
+        return heapq.merge(*self.iterables, key=self.key)
+
+    def __len__(self):
+        return self.len
+
+
+class LimitingPool:
+    """Limits unbound ahead-processing of multiprocessing.Pool's imap method
+    before items get consumed by the iteration caller.
+    This prevents OOM issues in situations where items represent larger memory allocations."""
+    def __init__(self, processes=None, process_ahead=None, sleeping_for=0.1):
+        self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead
+        self.sleeping_for = sleeping_for
+        self.processed = 0
+        self.pool = ThreadPool(processes=processes)
+
+    def __enter__(self):
+        return self
+
+    def _limit(self, it):
+        for obj in it:
+            while self.processed >= self.process_ahead:
+                time.sleep(self.sleeping_for)
+            self.processed += 1
+            yield obj
+
+    def imap(self, fun, it):
+        for obj in self.pool.imap(fun, self._limit(it)):
+            self.processed -= 1
+            yield obj
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.pool.close()
+
+
+class ExceptionBox:
+    """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
+    Used in conjunction with `remember_exception`."""
+    def __init__(self):
+        self.exception = None
+
+    def raise_if_set(self):
+        if self.exception is not None:
+            exception = self.exception
+            self.exception = None
+            raise exception  # pylint: disable = raising-bad-type
+
+
+def remember_exception(iterable, exception_box=None):
+    """Wraps a TensorFlow dataset generator for catching its actual exceptions
+    that would otherwise just interrupt iteration w/o bubbling up."""
+    def do_iterate():
+        try:
+            yield from iterable()
+        except StopIteration:
+            return
+        except Exception as ex:  # pylint: disable = broad-except
+            exception_box.exception = ex
+    return iterable if exception_box is None else do_iterate
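
Quick illustrations (not part of the commit) of the two new helpers above:

from util.helpers import Interleaved, parse_file_size

assert parse_file_size('1MB') == 1024 * 1024   # suffixes KB/MB/GB/TB, optional trailing 'B'
assert parse_file_size('512') == 512           # bare numbers are taken as bytes

merged = Interleaved([1, 3, 5], [2, 4, 6])     # inputs must already be sorted
assert len(merged) == 6                        # len() is known without iterating
assert list(merged) == [1, 2, 3, 4, 5, 6]      # heapq.merge always picks the smallest head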

util/sample_collections.py (new file, 262 lines)

@@ -0,0 +1,262 @@
# -*- coding: utf-8 -*-
import os
import csv
import json

from pathlib import Path
from functools import partial

from util.helpers import MEGABYTE, GIGABYTE, Interleaved
from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES

BIG_ENDIAN = 'big'
INT_SIZE = 4
BIGINT_SIZE = 2 * INT_SIZE
MAGIC = b'SAMPLEDB'

BUFFER_SIZE = 1 * MEGABYTE
CACHE_SIZE = 1 * GIGABYTE

SCHEMA_KEY = 'schema'
CONTENT_KEY = 'content'
MIME_TYPE_KEY = 'mime-type'
MIME_TYPE_TEXT = 'text/plain'
CONTENT_TYPE_SPEECH = 'speech'
CONTENT_TYPE_TRANSCRIPT = 'transcript'


class LabeledSample(Sample):
    """In-memory labeled audio sample representing an utterance.
    Derived from util.audio.Sample and used by sample collection readers and writers."""
    def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
        """
        Creates an in-memory speech sample together with a transcript of the utterance (label).
        :param audio_type: See util.audio.Sample.__init__ .
        :param raw_data: See util.audio.Sample.__init__ .
        :param transcript: Transcript of the sample's utterance
        :param audio_format: See util.audio.Sample.__init__ .
        :param sample_id: Tracking ID - typically assigned by collection readers
        """
        super().__init__(audio_type, raw_data, audio_format=audio_format)
        self.sample_id = sample_id
        self.transcript = transcript


class DirectSDBWriter:
    """Sample collection writer for creating a Sample DB (SDB) file"""
    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None):
        self.sdb_filename = sdb_filename
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
        if audio_type not in SERIALIZABLE_AUDIO_TYPES:
            raise ValueError('Audio type "{}" not supported'.format(audio_type))
        self.audio_type = audio_type
        self.sdb_file = open(sdb_filename, 'wb', buffering=buffering)
        self.offsets = []
        self.num_samples = 0
        self.sdb_file.write(MAGIC)
        meta_data = {
            SCHEMA_KEY: [
                {CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type},
                {CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
            ]
        }
        meta_data = json.dumps(meta_data).encode()
        self.write_big_int(len(meta_data))
        self.sdb_file.write(meta_data)
        self.offset_samples = self.sdb_file.tell()
        self.sdb_file.seek(2 * BIGINT_SIZE, 1)

    def write_int(self, n):
        return self.sdb_file.write(n.to_bytes(INT_SIZE, BIG_ENDIAN))

    def write_big_int(self, n):
        return self.sdb_file.write(n.to_bytes(BIGINT_SIZE, BIG_ENDIAN))

    def __enter__(self):
        return self

    def add(self, sample):
        def to_bytes(n):
            return n.to_bytes(INT_SIZE, BIG_ENDIAN)
        sample.change_audio_type(self.audio_type)
        opus = sample.audio.getbuffer()
        opus_len = to_bytes(len(opus))
        transcript = sample.transcript.encode()
        transcript_len = to_bytes(len(transcript))
        entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript))
        buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript])
        self.offsets.append(self.sdb_file.tell())
        self.sdb_file.write(buffer)
        sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples)
        self.num_samples += 1
        return sample.sample_id

    def close(self):
        if self.sdb_file is None:
            return
        offset_index = self.sdb_file.tell()
        self.sdb_file.seek(self.offset_samples)
        self.write_big_int(offset_index - self.offset_samples - BIGINT_SIZE)
        self.write_big_int(self.num_samples)
        self.sdb_file.seek(offset_index + BIGINT_SIZE)
        self.write_big_int(self.num_samples)
        for offset in self.offsets:
            self.write_big_int(offset)
        offset_end = self.sdb_file.tell()
        self.sdb_file.seek(offset_index)
        self.write_big_int(offset_end - offset_index - BIGINT_SIZE)
        self.sdb_file.close()
        self.sdb_file = None

    def __len__(self):
        return len(self.offsets)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class SDB:  # pylint: disable=too-many-instance-attributes
    """Sample collection reader for reading a Sample DB (SDB) file"""
    def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None):
        self.sdb_filename = sdb_filename
        self.id_prefix = sdb_filename if id_prefix is None else id_prefix
        self.sdb_file = open(sdb_filename, 'rb', buffering=buffering)
        self.offsets = []
        if self.sdb_file.read(len(MAGIC)) != MAGIC:
            raise RuntimeError('No Sample Database')
        meta_chunk_len = self.read_big_int()
        self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode())
        if SCHEMA_KEY not in self.meta:
            raise RuntimeError('Missing schema')
        self.schema = self.meta[SCHEMA_KEY]

        speech_columns = self.find_columns(content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES)
        if not speech_columns:
            raise RuntimeError('No speech data (missing in schema)')
        self.speech_index = speech_columns[0]
        self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]

        transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT)
        if not transcript_columns:
            raise RuntimeError('No transcript data (missing in schema)')
        self.transcript_index = transcript_columns[0]

        sample_chunk_len = self.read_big_int()
        self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
        num_samples = self.read_big_int()
        for _ in range(num_samples):
            self.offsets.append(self.read_big_int())

    def read_int(self):
        return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN)

    def read_big_int(self):
        return int.from_bytes(self.sdb_file.read(BIGINT_SIZE), BIG_ENDIAN)

    def find_columns(self, content=None, mime_type=None):
        criteria = []
        if content is not None:
            criteria.append((CONTENT_KEY, content))
        if mime_type is not None:
            criteria.append((MIME_TYPE_KEY, mime_type))
        if len(criteria) == 0:
            raise ValueError('At least one of "content" or "mime-type" has to be provided')
        matches = []
        for index, column in enumerate(self.schema):
            matched = 0
            for field, value in criteria:
                if column[field] == value or (isinstance(value, list) and column[field] in value):
                    matched += 1
            if matched == len(criteria):
                matches.append(index)
        return matches

    def read_row(self, row_index, *columns):
        columns = list(columns)
        column_data = [None] * len(columns)
        found = 0
        if not 0 <= row_index < len(self.offsets):
            raise ValueError('Wrong sample index: {} - has to be between 0 and {}'
                             .format(row_index, len(self.offsets) - 1))
        self.sdb_file.seek(self.offsets[row_index] + INT_SIZE)
        for index in range(len(self.schema)):
            chunk_len = self.read_int()
            if index in columns:
                column_data[columns.index(index)] = self.sdb_file.read(chunk_len)
                found += 1
                if found == len(columns):
                    return tuple(column_data)
            else:
                self.sdb_file.seek(chunk_len, 1)
        return tuple(column_data)

    def __getitem__(self, i):
        audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index)
        transcript = transcript.decode()
        sample_id = '{}:{}'.format(self.id_prefix, i)
        return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id)

    def __iter__(self):
        for i in range(len(self.offsets)):
            yield self[i]

    def __len__(self):
        return len(self.offsets)

    def close(self):
        if self.sdb_file is not None:
            self.sdb_file.close()

    def __del__(self):
        self.close()


class CSV:
    """Sample collection reader for reading a DeepSpeech CSV file"""
    def __init__(self, csv_filename):
        self.csv_filename = csv_filename
        self.rows = []
        csv_dir = Path(csv_filename).parent
        with open(csv_filename, 'r', encoding='utf8') as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                wav_filename = Path(row['wav_filename'])
                if not wav_filename.is_absolute():
                    wav_filename = csv_dir / wav_filename
                self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript']))
        self.rows.sort(key=lambda r: r[1])

    def __getitem__(self, i):
        wav_filename, _, transcript = self.rows[i]
        with open(wav_filename, 'rb') as wav_file:
            return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename)

    def __iter__(self):
        for i in range(len(self.rows)):
            yield self[i]

    def __len__(self):
        return len(self.rows)


def samples_from_file(filename, buffering=BUFFER_SIZE):
    """Returns an iterable of LabeledSample objects loaded from a file."""
    ext = os.path.splitext(filename)[1].lower()
    if ext == '.sdb':
        return SDB(filename, buffering=buffering)
    if ext == '.csv':
        return CSV(filename)
    raise ValueError('Unknown file type: "{}"'.format(ext))


def samples_from_files(filenames, buffering=BUFFER_SIZE):
    """Returns an iterable of LabeledSample objects from a list of files."""
    if len(filenames) == 0:
        raise ValueError('No files')
    if len(filenames) == 1:
        return samples_from_file(filenames[0], buffering=buffering)
    cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
    return Interleaved(*cols, key=lambda s: s.duration)
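
Reading an SDB back, a sketch (file name made up; the default Opus audio type needs opuslib for decoding):

from util.audio import AUDIO_TYPE_PCM
from util.sample_collections import samples_from_file

samples = samples_from_file('target.sdb')  # SDB reader; a .csv path would return a CSV reader
print(len(samples), 'samples')
first = samples[0]                         # random access through the offset index
print(first.sample_id, first.duration, first.transcript)
first.change_audio_type(AUDIO_TYPE_PCM)    # decode the Opus container to raw PCM bytes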

util/text.py

@@ -4,7 +4,6 @@ import numpy as np
 import re
 import struct

-from util.flags import FLAGS
 from six.moves import range


 class Alphabet(object):
@@ -120,19 +119,22 @@ class UTF8Alphabet(object):
         return ''


-def text_to_char_array(series, alphabet):
+def text_to_char_array(transcript, alphabet, context=''):
     r"""
-    Given a Pandas Series containing transcript string, map characters to
+    Given a transcript string, map characters to
     integers and return a numpy array representing the processed string.
+    Use a string in `context` for adding text to raised exceptions.
     """
     try:
-        transcript = np.asarray(alphabet.encode(series['transcript']))
+        transcript = alphabet.encode(transcript)
         if len(transcript) == 0:
-            raise ValueError('While processing: {}\nFound an empty transcript! You must include a transcript for all training data.'.format(series['wav_filename']))
+            raise ValueError('While processing {}: Found an empty transcript! '
+                             'You must include a transcript for all training data.'
+                             .format(context))
         return transcript
     except KeyError as e:
         # Provide the row context (especially wav_filename) for alphabet errors
-        raise ValueError('While processing: {}\n{}'.format(series['wav_filename'], e))
+        raise ValueError('While processing: {}\n{}'.format(context, e))


 # The following code is from: http://hetland.org/coding/python/levenshtein.py
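
With the new signature, callers pass the transcript string directly and supply an identifier for error messages; a brief sketch (assuming the repository's data/alphabet.txt):

from util.text import Alphabet, text_to_char_array

alphabet = Alphabet('data/alphabet.txt')
encoded = text_to_char_array('hello world', alphabet, context='target.sdb:0')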