diff --git a/DeepSpeech.py b/DeepSpeech.py index 8ebd1e25..304d3dc2 100644 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -33,7 +33,7 @@ from util.config import Config, initialize_globals from util.checkpoints import load_or_init_graph from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from util.flags import create_flags, FLAGS -from util.helpers import check_ctcdecoder_version +from util.helpers import check_ctcdecoder_version, ExceptionBox from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar check_ctcdecoder_version() @@ -418,12 +418,17 @@ def train(): FLAGS.augmentation_sparse_warp): do_cache_dataset = False + exception_box = ExceptionBox() + # Create training and validation datasets train_set = create_dataset(FLAGS.train_files.split(','), batch_size=FLAGS.train_batch_size, enable_cache=FLAGS.feature_cache and do_cache_dataset, cache_path=FLAGS.feature_cache, - train_phase=True) + train_phase=True, + exception_box=exception_box, + process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, + buffering=FLAGS.read_buffer) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), tfv1.data.get_output_shapes(train_set), @@ -433,8 +438,13 @@ def train(): train_init_op = iterator.make_initializer(train_set) if FLAGS.dev_files: - dev_csvs = FLAGS.dev_files.split(',') - dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs] + dev_sources = FLAGS.dev_files.split(',') + dev_sets = [create_dataset([source], + batch_size=FLAGS.dev_batch_size, + train_phase=False, + exception_box=exception_box, + process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + buffering=FLAGS.read_buffer) for source in dev_sources] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] # Dropout @@ -540,6 +550,7 @@ def train(): _, current_step, batch_loss, problem_files, step_summary = \ session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], feed_dict=feed_dict) + exception_box.raise_if_set() except tf.errors.InvalidArgumentError as err: if FLAGS.augmentation_sparse_warp: log_info("Ignoring sparse warp error: {}".format(err)) @@ -547,6 +558,7 @@ def train(): else: raise except tf.errors.OutOfRangeError: + exception_box.raise_if_set() break if problem_files.size > 0: @@ -586,12 +598,12 @@ def train(): # Validation dev_loss = 0.0 total_steps = 0 - for csv, init_op in zip(dev_csvs, dev_init_ops): - log_progress('Validating epoch %d on %s...' % (epoch, csv)) - set_loss, steps = run_set('dev', epoch, init_op, dataset=csv) + for source, init_op in zip(dev_sources, dev_init_ops): + log_progress('Validating epoch %d on %s...' 
% (epoch, source))
+                set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                 dev_loss += set_loss * steps
                 total_steps += steps
-                log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
+                log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
 
             dev_loss = dev_loss / total_steps
             dev_losses.append(dev_loss)
diff --git a/bin/build_sdb.py b/bin/build_sdb.py
new file mode 100755
index 00000000..b4912972
--- /dev/null
+++ b/bin/build_sdb.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+'''
+Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
+Use "python3 build_sdb.py -h" for help
+'''
+from __future__ import absolute_import, division, print_function
+
+# Make sure we can import stuff from util/
+# This script needs to be run from the root of the DeepSpeech repository
+import os
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+
+import argparse
+import progressbar
+
+from util.downloader import SIMPLE_BAR
+from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
+from util.sample_collections import samples_from_files, DirectSDBWriter
+
+AUDIO_TYPE_LOOKUP = {
+    'wav': AUDIO_TYPE_WAV,
+    'opus': AUDIO_TYPE_OPUS
+}
+
+
+def build_sdb():
+    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
+    with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type) as sdb_writer:
+        samples = samples_from_files(CLI_ARGS.sources)
+        bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
+        for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
+            sdb_writer.add(sample)
+
+
+def handle_args():
+    parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
+                                                 'from DeepSpeech CSV files and other SDB files')
+    parser.add_argument('sources', nargs='+', help='Source CSV and/or SDB files - '
+                                                   'Note: For getting a correctly ordered target SDB, source SDBs have '
+                                                   'to have their samples already ordered from shortest to longest.')
+    parser.add_argument('target', help='SDB file to create')
+    parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
+                        help='Audio representation inside target SDB')
+    parser.add_argument('--workers', type=int, default=None, help='Number of encoding SDB workers')
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    CLI_ARGS = handle_args()
+    build_sdb()
diff --git a/bin/play.py b/bin/play.py
new file mode 100755
index 00000000..55da4bc5
--- /dev/null
+++ b/bin/play.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+"""
+Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
+Use "python3 play.py -h" for help
+"""
+from __future__ import absolute_import, division, print_function
+
+# Make sure we can import stuff from util/
+# This script needs to be run from the root of the DeepSpeech repository
+import os
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+
+import random
+import argparse
+
+from util.sample_collections import samples_from_file
+from util.audio import AUDIO_TYPE_PCM
+
+
+def play_sample(samples, index):
+    if index < 0:
+        index = len(samples) + index
+    if CLI_ARGS.random:
+        index = random.randint(0, len(samples) - 1)
+    elif index >= len(samples):
+        print('No sample with index {}'.format(index))
+        sys.exit(1)
+    sample = samples[index]
+    print('Sample "{}"'.format(sample.sample_id))
+    print(' "{}"'.format(sample.transcript))
+
sample.change_audio_type(AUDIO_TYPE_PCM) + rate, channels, width = sample.audio_format + wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate) + play_obj = wave_obj.play() + play_obj.wait_done() + + +def play_collection(): + samples = samples_from_file(CLI_ARGS.collection, buffering=0) + played = 0 + index = CLI_ARGS.start + while True: + if 0 <= CLI_ARGS.number <= played: + return + play_sample(samples, index) + played += 1 + index = (index + 1) % len(samples) + + +def handle_args(): + parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) ' + 'and DeepSpeech CSV files') + parser.add_argument('collection', help='Sample DB or CSV file to play samples from') + parser.add_argument('--start', type=int, default=0, + help='Sample index to start at (negative numbers are relative to the end of the collection)') + parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)') + parser.add_argument('--random', action='store_true', help='If samples should be played in random order') + return parser.parse_args() + + +if __name__ == "__main__": + try: + import simpleaudio + except ModuleNotFoundError: + print('play.py requires Python package "simpleaudio"') + sys.exit(1) + CLI_ARGS = handle_args() + try: + play_collection() + except KeyboardInterrupt: + print(' Stopped') + sys.exit(0) diff --git a/bin/run-tc-ldc93s1_checkpoint_sdb.sh b/bin/run-tc-ldc93s1_checkpoint_sdb.sh new file mode 100755 index 00000000..6f5c307f --- /dev/null +++ b/bin/run-tc-ldc93s1_checkpoint_sdb.sh @@ -0,0 +1,37 @@ +#!/bin/sh + +set -xe + +ldc93s1_dir="./data/smoke_test" +ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv" +ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb" + +if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then + echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}." + python -u bin/import_ldc93s1.py ${ldc93s1_dir} +fi; + +if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then + echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." + python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} +fi; + +# Force only one visible device because we have a single-sample dataset +# and when trying to run on multiple devices (like GPUs), this will break +export CUDA_VISIBLE_DEVICES=0 + +python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ + --train_files ${ldc93s1_sdb} --train_batch_size 1 \ + --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \ + --test_files ${ldc93s1_sdb} --test_batch_size 1 \ + --n_hidden 100 --epochs 1 \ + --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \ + --learning_rate 0.001 --dropout_rate 0.05 \ + --scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log + +if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then + echo "Did not resume training from checkpoint" + exit 1 +else + exit 0 +fi diff --git a/bin/run-tc-ldc93s1_new_sdb.sh b/bin/run-tc-ldc93s1_new_sdb.sh new file mode 100755 index 00000000..76032aa2 --- /dev/null +++ b/bin/run-tc-ldc93s1_new_sdb.sh @@ -0,0 +1,34 @@ +#!/bin/sh + +set -xe + +ldc93s1_dir="./data/smoke_test" +ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv" +ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb" + +epoch_count=$1 +audio_sample_rate=$2 + +if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then + echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}." + python -u bin/import_ldc93s1.py ${ldc93s1_dir} +fi; + +if [ ! 
-f "${ldc93s1_dir}/ldc93s1.sdb" ]; then + echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." + python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} +fi; + +# Force only one visible device because we have a single-sample dataset +# and when trying to run on multiple devices (like GPUs), this will break +export CUDA_VISIBLE_DEVICES=0 + +python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ + --train_files ${ldc93s1_sdb} --train_batch_size 1 \ + --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \ + --test_files ${ldc93s1_sdb} --test_batch_size 1 \ + --n_hidden 100 --epochs $epoch_count \ + --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \ + --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb' \ + --scorer_path 'data/smoke_test/pruned_lm.scorer' \ + --audio_sample_rate ${audio_sample_rate} diff --git a/bin/run-tc-ldc93s1_new_sdb_csv.sh b/bin/run-tc-ldc93s1_new_sdb_csv.sh new file mode 100755 index 00000000..1b0f6d3d --- /dev/null +++ b/bin/run-tc-ldc93s1_new_sdb_csv.sh @@ -0,0 +1,35 @@ +#!/bin/sh + +set -xe + +ldc93s1_dir="./data/smoke_test" +ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv" +ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb" + +epoch_count=$1 +audio_sample_rate=$2 + +if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then + echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}." + python -u bin/import_ldc93s1.py ${ldc93s1_dir} +fi; + +if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then + echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." + python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} +fi; + +# Force only one visible device because we have a single-sample dataset +# and when trying to run on multiple devices (like GPUs), this will break +export CUDA_VISIBLE_DEVICES=0 + +python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ + --train_files ${ldc93s1_sdb},${ldc93s1_csv} --train_batch_size 1 \ + --feature_cache '/tmp/ldc93s1_cache_sdb_csv' \ + --dev_files ${ldc93s1_sdb},${ldc93s1_csv} --dev_batch_size 1 \ + --test_files ${ldc93s1_sdb},${ldc93s1_csv} --test_batch_size 1 \ + --n_hidden 100 --epochs $epoch_count \ + --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb_csv' \ + --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb_csv' \ + --scorer_path 'data/smoke_test/pruned_lm.scorer' \ + --audio_sample_rate ${audio_sample_rate} diff --git a/requirements.txt b/requirements.txt index 742b8244..a249e323 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,12 +2,12 @@ tensorflow == 1.15.0 numpy == 1.18.1 progressbar2 -pandas six pyxdg attrdict absl-py semver +opuslib == 2.0.0 # Requirements for building native_client files setuptools @@ -15,6 +15,7 @@ setuptools # Requirements for importers sox bs4 +pandas requests librosa soundfile diff --git a/taskcluster/.shared.yml b/taskcluster/.shared.yml index e762ebd8..bae8b779 100644 --- a/taskcluster/.shared.yml +++ b/taskcluster/.shared.yml @@ -5,6 +5,9 @@ python: apt: 'python3-virtualenv python3-setuptools python3-pip python3-wheel python3-pkg-resources' packages_docs_bionic: apt: 'python3 python3-pip zip doxygen' +training: + packages_trusty: + apt: 'libopus0' tensorflow: packages_trusty: apt: 'make build-essential gfortran git libblas-dev liblapack-dev libsox-dev libmagic-dev libgsm1-dev libltdl-dev libpng-dev python zlib1g-dev' diff --git a/taskcluster/tc-train-tests.sh b/taskcluster/tc-train-tests.sh index 1be6533b..2273405a 100644 --- a/taskcluster/tc-train-tests.sh +++ b/taskcluster/tc-train-tests.sh @@ -48,6 +48,11 @@ pushd 
${HOME}/DeepSpeech/ds/ time ./bin/run-tc-ldc93s1_new.sh 249 "${sample_rate}" time ./bin/run-tc-ldc93s1_new.sh 1 "${sample_rate}" time ./bin/run-tc-ldc93s1_tflite.sh "${sample_rate}" + # Testing single SDB source + time ./bin/run-tc-ldc93s1_new_sdb.sh 220 "${sample_rate}" + # Testing interleaved source (SDB+CSV combination) - run twice to test preprocessed features + time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 "${sample_rate}" + time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 1 "${sample_rate}" popd cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS} @@ -62,6 +67,7 @@ cp /tmp/train/output_graph.pbmm ${TASKCLUSTER_ARTIFACTS} pushd ${HOME}/DeepSpeech/ds/ time ./bin/run-tc-ldc93s1_checkpoint.sh + time ./bin/run-tc-ldc93s1_checkpoint_sdb.sh popd virtualenv_deactivate "${pyalias}" "deepspeech" diff --git a/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml b/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml index e950969f..3f68fea3 100644 --- a/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml +++ b/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml @@ -2,6 +2,9 @@ build: template_file: test-linux-opt-base.tyml dependencies: - "linux-amd64-ctc-opt" + system_setup: + > + apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.5.8:m 16k" metadata: diff --git a/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml b/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml index 0bb84191..9fa9791b 100644 --- a/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml +++ b/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml @@ -2,6 +2,9 @@ build: template_file: test-linux-opt-base.tyml dependencies: - "linux-amd64-ctc-opt" + system_setup: + > + apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 16k" metadata: diff --git a/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml b/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml index e4164a9b..dc2b486f 100644 --- a/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml +++ b/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml @@ -2,6 +2,9 @@ build: template_file: test-linux-opt-base.tyml dependencies: - "linux-amd64-ctc-opt" + system_setup: + > + apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 8k" metadata: diff --git a/util/audio.py b/util/audio.py index e713ca7c..cad596d7 100644 --- a/util/audio.py +++ b/util/audio.py @@ -1,34 +1,127 @@ import os -import sox +import io import wave import tempfile import collections +import numpy as np + +from util.helpers import LimitingPool DEFAULT_RATE = 16000 DEFAULT_CHANNELS = 1 DEFAULT_WIDTH = 2 DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH) +AUDIO_TYPE_NP = 'application/vnd.mozilla.np' +AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm' +AUDIO_TYPE_WAV = 'audio/wav' +AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus' +SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS] -def get_audio_format(wav_file): +OPUS_PCM_LEN_SIZE = 4 +OPUS_RATE_SIZE = 4 +OPUS_CHANNELS_SIZE = 1 +OPUS_WIDTH_SIZE = 1 +OPUS_CHUNK_LEN_SIZE = 2 + + +class Sample: + """Represents in-memory audio data of a certain (convertible) representation. + Attributes: + audio_type (str): See `__init__`. 
+ audio_format (tuple:(int, int, int)): See `__init__`. + audio (obj): Audio data represented as indicated by `audio_type` + duration (float): Audio duration of the sample in seconds + """ + def __init__(self, audio_type, raw_data, audio_format=None): + """ + Creates a Sample from a raw audio representation. + :param audio_type: Audio data representation type + Supported types: + - AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio + wrapped by a custom container format (used in SDBs) + - AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file + - AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header) + - AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding + :param raw_data: Audio data in the form of the provided representation type (see audio_type). + For types AUDIO_TYPE_OPUS or AUDIO_TYPE_WAV data can also be passed as a bytearray. + :param audio_format: Tuple of sample-rate, number of channels and sample-width. + Required in case of audio_type = AUDIO_TYPE_PCM or AUDIO_TYPE_NP, + as this information cannot be derived from raw audio data. + """ + self.audio_type = audio_type + self.audio_format = audio_format + if audio_type in SERIALIZABLE_AUDIO_TYPES: + self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) + self.duration = read_duration(audio_type, self.audio) + else: + self.audio = raw_data + if self.audio_format is None: + raise ValueError('For audio type "{}" parameter "audio_format" is mandatory'.format(self.audio_type)) + if audio_type == AUDIO_TYPE_PCM: + self.duration = get_pcm_duration(len(self.audio), self.audio_format) + elif audio_type == AUDIO_TYPE_NP: + self.duration = get_np_duration(len(self.audio), self.audio_format) + else: + raise ValueError('Unsupported audio type: {}'.format(self.audio_type)) + + def change_audio_type(self, new_audio_type): + """ + In-place conversion of audio data into a different representation. + :param new_audio_type: New audio-type - see `__init__`. + Not supported: Converting from AUDIO_TYPE_NP into any other type. 
+ """ + if self.audio_type == new_audio_type: + return + if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in SERIALIZABLE_AUDIO_TYPES: + self.audio_format, audio = read_audio(self.audio_type, self.audio) + self.audio.close() + self.audio = audio + elif new_audio_type == AUDIO_TYPE_NP: + self.change_audio_type(AUDIO_TYPE_PCM) + self.audio = pcm_to_np(self.audio_format, self.audio) + elif new_audio_type in SERIALIZABLE_AUDIO_TYPES: + self.change_audio_type(AUDIO_TYPE_PCM) + audio_bytes = io.BytesIO() + write_audio(new_audio_type, audio_bytes, self.audio_format, self.audio) + audio_bytes.seek(0) + self.audio = audio_bytes + else: + raise RuntimeError('Changing audio representation type from "{}" to "{}" not supported' + .format(self.audio_type, new_audio_type)) + self.audio_type = new_audio_type + + +def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, processes=None, process_ahead=None): + def change_audio_type(sample): + sample.change_audio_type(audio_type) + return sample + with LimitingPool(processes=processes, process_ahead=process_ahead) as pool: + yield from pool.imap(change_audio_type, samples) + + +def read_audio_format_from_wav_file(wav_file): return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth() -def get_num_samples(audio_data, audio_format=DEFAULT_FORMAT): +def get_num_samples(pcm_buffer_size, audio_format=DEFAULT_FORMAT): _, channels, width = audio_format - return len(audio_data) // (channels * width) + return pcm_buffer_size // (channels * width) -def get_duration(audio_data, audio_format=DEFAULT_FORMAT): - return get_num_samples(audio_data, audio_format) / audio_format[0] +def get_pcm_duration(pcm_buffer_size, audio_format=DEFAULT_FORMAT): + """Calculates duration in seconds of a binary PCM buffer (typically read from a WAV file)""" + return get_num_samples(pcm_buffer_size, audio_format) / audio_format[0] -def get_duration_ms(audio_data, audio_format=DEFAULT_FORMAT): - return get_duration(audio_data, audio_format) * 1000 +def get_np_duration(np_len, audio_format=DEFAULT_FORMAT): + """Calculates duration in seconds of NumPy audio data""" + return np_len / audio_format[0] def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT): sample_rate, channels, width = audio_format + import sox transformer = sox.Transformer() transformer.set_output_format(file_type=file_type, rate=sample_rate, channels=channels, bits=width*8) transformer.build(src_audio_path, dst_audio_path) @@ -45,7 +138,7 @@ class AudioFile: def __enter__(self): if self.audio_path.endswith('.wav'): self.open_file = wave.open(self.audio_path, 'r') - if get_audio_format(self.open_file) == self.audio_format: + if read_audio_format_from_wav_file(self.open_file) == self.audio_format: if self.as_path: self.open_file.close() return self.audio_path @@ -66,12 +159,12 @@ class AudioFile: def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False): - audio_format = get_audio_format(wav_file) + audio_format = read_audio_format_from_wav_file(wav_file) frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0)) while True: try: data = wav_file.readframes(frame_size) - if not yield_remainder and get_duration_ms(data, audio_format) < frame_duration_ms: + if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms: break yield data except EOFError: @@ -106,7 +199,7 @@ def vad_split(audio_frames, frame_duration_ms = 0 frame_index = 0 for frame_index, frame in enumerate(audio_frames): - frame_duration_ms = 
get_duration_ms(frame, audio_format) + frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000 if int(frame_duration_ms) not in [10, 20, 30]: raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms') is_speech = vad.is_speech(frame, sample_rate) @@ -133,3 +226,123 @@ def vad_split(audio_frames, yield b''.join(voiced_frames), \ frame_duration_ms * (frame_index - len(voiced_frames)), \ frame_duration_ms * (frame_index + 1) + + +def pack_number(n, num_bytes): + return n.to_bytes(num_bytes, 'big', signed=False) + + +def unpack_number(data): + return int.from_bytes(data, 'big', signed=False) + + +def get_opus_frame_size(rate): + return 60 * rate // 1000 + + +def write_opus(opus_file, audio_format, audio_data): + rate, channels, width = audio_format + frame_size = get_opus_frame_size(rate) + import opuslib # pylint: disable=import-outside-toplevel + encoder = opuslib.Encoder(rate, channels, 'audio') + chunk_size = frame_size * channels * width + opus_file.write(pack_number(len(audio_data), OPUS_PCM_LEN_SIZE)) + opus_file.write(pack_number(rate, OPUS_RATE_SIZE)) + opus_file.write(pack_number(channels, OPUS_CHANNELS_SIZE)) + opus_file.write(pack_number(width, OPUS_WIDTH_SIZE)) + for i in range(0, len(audio_data), chunk_size): + chunk = audio_data[i:i + chunk_size] + # Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer + if len(chunk) < chunk_size: + chunk = chunk + bytearray(chunk_size - len(chunk)) + encoded = encoder.encode(chunk, frame_size) + opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE)) + opus_file.write(encoded) + + +def read_opus_header(opus_file): + opus_file.seek(0) + pcm_buffer_size = unpack_number(opus_file.read(OPUS_PCM_LEN_SIZE)) + rate = unpack_number(opus_file.read(OPUS_RATE_SIZE)) + channels = unpack_number(opus_file.read(OPUS_CHANNELS_SIZE)) + width = unpack_number(opus_file.read(OPUS_WIDTH_SIZE)) + return pcm_buffer_size, (rate, channels, width) + + +def read_opus(opus_file): + pcm_buffer_size, audio_format = read_opus_header(opus_file) + rate, channels, _ = audio_format + frame_size = get_opus_frame_size(rate) + import opuslib # pylint: disable=import-outside-toplevel + decoder = opuslib.Decoder(rate, channels) + audio_data = bytearray() + while len(audio_data) < pcm_buffer_size: + chunk_len = unpack_number(opus_file.read(OPUS_CHUNK_LEN_SIZE)) + chunk = opus_file.read(chunk_len) + decoded = decoder.decode(chunk, frame_size) + audio_data.extend(decoded) + audio_data = audio_data[:pcm_buffer_size] + return audio_format, audio_data + + +def write_wav(wav_file, audio_format, pcm_data): + with wave.open(wav_file, 'wb') as wav_file_writer: + rate, channels, width = audio_format + wav_file_writer.setframerate(rate) + wav_file_writer.setnchannels(channels) + wav_file_writer.setsampwidth(width) + wav_file_writer.writeframes(pcm_data) + + +def read_wav(wav_file): + wav_file.seek(0) + with wave.open(wav_file, 'rb') as wav_file_reader: + audio_format = read_audio_format_from_wav_file(wav_file_reader) + pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes()) + return audio_format, pcm_data + + +def read_audio(audio_type, audio_file): + if audio_type == AUDIO_TYPE_WAV: + return read_wav(audio_file) + if audio_type == AUDIO_TYPE_OPUS: + return read_opus(audio_file) + raise ValueError('Unsupported audio type: {}'.format(audio_type)) + + +def write_audio(audio_type, audio_file, audio_format, pcm_data): + if audio_type == AUDIO_TYPE_WAV: + return write_wav(audio_file, 
audio_format, pcm_data) + if audio_type == AUDIO_TYPE_OPUS: + return write_opus(audio_file, audio_format, pcm_data) + raise ValueError('Unsupported audio type: {}'.format(audio_type)) + + +def read_wav_duration(wav_file): + wav_file.seek(0) + with wave.open(wav_file, 'rb') as wav_file_reader: + return wav_file_reader.getnframes() / wav_file_reader.getframerate() + + +def read_opus_duration(opus_file): + pcm_buffer_size, audio_format = read_opus_header(opus_file) + return get_pcm_duration(pcm_buffer_size, audio_format) + + +def read_duration(audio_type, audio_file): + if audio_type == AUDIO_TYPE_WAV: + return read_wav_duration(audio_file) + if audio_type == AUDIO_TYPE_OPUS: + return read_opus_duration(audio_file) + raise ValueError('Unsupported audio type: {}'.format(audio_type)) + + +def pcm_to_np(audio_format, pcm_data): + _, channels, width = audio_format + if width not in [1, 2, 4]: + raise ValueError('Unsupported sample width: {}'.format(width)) + dtype = [None, np.int8, np.int16, None, np.int32][width] + samples = np.frombuffer(pcm_data, dtype=dtype) + assert channels == 1 # only mono supported for now + samples = samples.astype(np.float32) / np.iinfo(dtype).max + return np.expand_dims(samples, axis=1) diff --git a/util/config.py b/util/config.py index 0e3a719b..bc9255dc 100755 --- a/util/config.py +++ b/util/config.py @@ -12,6 +12,7 @@ from util.flags import FLAGS from util.gpu import get_available_gpus from util.logging import log_error from util.text import Alphabet, UTF8Alphabet +from util.helpers import parse_file_size class ConfigSingleton: _config = None @@ -29,6 +30,9 @@ Config = ConfigSingleton() # pylint: disable=invalid-name def initialize_globals(): c = AttrDict() + # Read-buffer + FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer) + # Set default dropout rates if FLAGS.dropout_rate2 < 0: FLAGS.dropout_rate2 = FLAGS.dropout_rate diff --git a/util/feeding.py b/util/feeding.py index 3e21427f..93c5699b 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -1,12 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function -import os - from functools import partial import numpy as np -import pandas import tensorflow as tf from tensorflow.python.ops import gen_audio_ops as contrib_audio @@ -15,27 +12,18 @@ from util.config import Config from util.text import text_to_char_array from util.flags import FLAGS from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp -from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT +from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP +from util.sample_collections import samples_from_files +from util.helpers import remember_exception, MEGABYTE -def read_csvs(csv_files): - sets = [] - for csv in csv_files: - file = pandas.read_csv(csv, encoding='utf-8', na_filter=False) - #FIXME: not cross-platform - csv_dir = os.path.dirname(os.path.abspath(csv)) - file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) # pylint: disable=cell-var-from-loop - sets.append(file) - # Concat all sets, drop any extra columns, re-index the final result as 0..N - return pandas.concat(sets, join='inner', ignore_index=True) - - -def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None): +def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None): if train_phase: # 
We need the lambdas to make TensorFlow happy. # pylint: disable=unnecessary-lambda tf.cond(tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate), - lambda: tf.print('WARNING: sample rate of file', wav_filename, '(', sample_rate, ') does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'), + lambda: tf.print('WARNING: sample rate of sample', sample_id, '(', sample_rate, ') ' + 'does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'), lambda: tf.no_op(), name='matching_sample_rate') @@ -84,10 +72,8 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None) return mfccs, tf.shape(input=mfccs)[0] -def audiofile_to_features(wav_filename, train_phase=False): - samples = tf.io.read_file(wav_filename) - decoded = contrib_audio.decode_wav(samples, desired_channels=1) - features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase, wav_filename=wav_filename) +def audio_to_features(audio, sample_rate, train_phase=False, sample_id=None): + features, features_len = samples_to_mfccs(audio, sample_rate, train_phase=train_phase, sample_id=sample_id) if train_phase: if FLAGS.data_aug_features_multiplicative > 0: @@ -99,10 +85,17 @@ def audiofile_to_features(wav_filename, train_phase=False): return features, features_len -def entry_to_features(wav_filename, transcript, train_phase): +def audiofile_to_features(wav_filename, train_phase=False): + samples = tf.io.read_file(wav_filename) + decoded = contrib_audio.decode_wav(samples, desired_channels=1) + return audio_to_features(decoded.audio, decoded.sample_rate, train_phase=train_phase, sample_id=wav_filename) + + +def entry_to_features(sample_id, audio, sample_rate, transcript, train_phase=False): # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase) - return wav_filename, features, features_len, tf.SparseTensor(*transcript) + features, features_len = audio_to_features(audio, sample_rate, train_phase=train_phase, sample_id=sample_id) + sparse_transcript = tf.SparseTensor(*transcript) + return sample_id, features, features_len, sparse_transcript def to_sparse_tuple(sequence): @@ -114,15 +107,22 @@ def to_sparse_tuple(sequence): return indices, sequence, shape -def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False): - df = read_csvs(csvs) - df.sort_values(by='wav_filesize', inplace=True) - - df['transcript'] = df.apply(text_to_char_array, alphabet=Config.alphabet, result_type='reduce', axis=1) - +def create_dataset(sources, + batch_size, + enable_cache=False, + cache_path=None, + train_phase=False, + exception_box=None, + process_ahead=None, + buffering=1 * MEGABYTE): def generate_values(): - for _, row in df.iterrows(): - yield row.wav_filename, to_sparse_tuple(row.transcript) + samples = samples_from_files(sources, buffering=buffering) + for sample in change_audio_types(samples, + AUDIO_TYPE_NP, + process_ahead=2 * batch_size if process_ahead is None else process_ahead): + transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id) + transcript = to_sparse_tuple(transcript) + yield sample.sample_id, sample.audio, sample.audio_format[0], transcript # Batching a dataset of 2D SparseTensors creates 3D batches, which fail # when passed to tf.nn.ctc_loss, so we reshape them to remove the extra @@ -131,27 +131,23 @@ def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_ shape = 
sparse.dense_shape return tf.sparse.reshape(sparse, [shape[0], shape[2]]) - def batch_fn(wav_filenames, features, features_len, transcripts): + def batch_fn(sample_ids, features, features_len, transcripts): features = tf.data.Dataset.zip((features, features_len)) - features = features.padded_batch(batch_size, - padded_shapes=([None, Config.n_input], [])) + features = features.padded_batch(batch_size, padded_shapes=([None, Config.n_input], [])) transcripts = transcripts.batch(batch_size).map(sparse_reshape) - wav_filenames = wav_filenames.batch(batch_size) - return tf.data.Dataset.zip((wav_filenames, features, transcripts)) + sample_ids = sample_ids.batch(batch_size) + return tf.data.Dataset.zip((sample_ids, features, transcripts)) - num_gpus = len(Config.available_devices) process_fn = partial(entry_to_features, train_phase=train_phase) - dataset = (tf.data.Dataset.from_generator(generate_values, - output_types=(tf.string, (tf.int64, tf.int32, tf.int64))) + dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), + output_types=(tf.string, tf.float32, tf.int32, + (tf.int64, tf.int32, tf.int64))) .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)) - if enable_cache: dataset = dataset.cache(cache_path) - dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn) - .prefetch(num_gpus)) - + .prefetch(len(Config.available_devices))) return dataset @@ -160,27 +156,24 @@ def split_audio_file(audio_path, batch_size=1, aggressiveness=3, outlier_duration_ms=10000, - outlier_batch_size=1): - sample_rate, _, sample_width = audio_format - multiplier = 1.0 / (1 << (8 * sample_width - 1)) - + outlier_batch_size=1, + exception_box=None): def generate_values(): frames = read_frames_from_file(audio_path) segments = vad_split(frames, aggressiveness=aggressiveness) for segment in segments: segment_buffer, time_start, time_end = segment - samples = np.frombuffer(segment_buffer, dtype=np.int16) - samples = samples * multiplier - samples = np.expand_dims(samples, axis=1) + samples = pcm_to_np(audio_format, segment_buffer) yield time_start, time_end, samples def to_mfccs(time_start, time_end, samples): - features, features_len = samples_to_mfccs(samples, sample_rate) + features, features_len = samples_to_mfccs(samples, audio_format[0]) return time_start, time_end, features, features_len def create_batch_set(bs, criteria): return (tf.data.Dataset - .from_generator(generate_values, output_types=(tf.int32, tf.int32, tf.float32)) + .from_generator(remember_exception(generate_values, exception_box), + output_types=(tf.int32, tf.int32, tf.float32)) .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE) .filter(criteria) .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], []))) @@ -192,9 +185,3 @@ def split_audio_file(audio_path, dataset = nds.concatenate(ods) dataset = dataset.prefetch(len(Config.available_devices)) return dataset - - -def secs_to_hours(secs): - hours, remainder = divmod(secs, 3600) - minutes, seconds = divmod(remainder, 60) - return '%d:%02d:%02d' % (hours, minutes, seconds) diff --git a/util/flags.py b/util/flags.py index 5057d76c..274d04ce 100644 --- a/util/flags.py +++ b/util/flags.py @@ -15,6 +15,7 @@ def create_flags(): f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.') f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. 
Multiple files will get merged. If empty, the model will not be tested.')
+    f.DEFINE_string('read_buffer', '1MB', 'buffer size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)')
     f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.')
 
     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
diff --git a/util/helpers.py b/util/helpers.py
index cd4f4b03..73cd9b74 100644
--- a/util/helpers.py
+++ b/util/helpers.py
@@ -1,10 +1,32 @@
 import os
-import semver
 import sys
+import time
+import heapq
+import semver
+
+from multiprocessing.dummy import Pool as ThreadPool
+
+KILO = 1024
+KILOBYTE = 1 * KILO
+MEGABYTE = KILO * KILOBYTE
+GIGABYTE = KILO * MEGABYTE
+TERABYTE = KILO * GIGABYTE
+SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE}
+
+
+def parse_file_size(file_size):
+    file_size = file_size.lower().strip()
+    if len(file_size) == 0:
+        return 0
+    n = int(keep_only_digits(file_size))
+    if file_size[-1] == 'b':
+        file_size = file_size[:-1]
+    e = file_size[-1]
+    return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
 
 
 def keep_only_digits(txt):
-    return ''.join(filter(lambda c: c.isdigit(), txt))
+    return ''.join(filter(str.isdigit, txt))
 
 
 def secs_to_hours(secs):
@@ -21,7 +43,8 @@ def check_ctcdecoder_version():
         from ds_ctcdecoder import __version__ as decoder_version
     except ImportError as e:
         if e.msg.find('__version__') > 0:
-            print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s))
+            print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. "
+                  "Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s))
             sys.exit(1)
         raise e
 
@@ -29,7 +52,79 @@
     rv = semver.compare(ds_version_s, decoder_version_s)
     if rv != 0:
-        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
+        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
+              "Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
         sys.exit(1)
 
     return rv
+
+
+class Interleaved:
+    """Collection that lazily combines sorted collections in an interleaving fashion.
+    During iteration the next smallest element from all the sorted collections is always picked.
+    The collections must support iter() and len()."""
+    def __init__(self, *iterables, key=lambda obj: obj):
+        self.iterables = iterables
+        self.key = key
+        self.len = sum(map(len, iterables))
+
+    def __iter__(self):
+        return heapq.merge(*self.iterables, key=self.key)
+
+    def __len__(self):
+        return self.len
+
+
+class LimitingPool:
+    """Limits unbounded ahead-processing of multiprocessing.Pool's imap method
+    before items get consumed by the iteration caller.
+    This prevents OOM issues in situations where items represent larger memory allocations."""
+    def __init__(self, processes=None, process_ahead=None, sleeping_for=0.1):
+        self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead
+        self.sleeping_for = sleeping_for
+        self.processed = 0
+        self.pool = ThreadPool(processes=processes)
+
+    def __enter__(self):
+        return self
+
+    def _limit(self, it):
+        for obj in it:
+            while self.processed >= self.process_ahead:
+                time.sleep(self.sleeping_for)
+            self.processed += 1
+            yield obj
+
+    def imap(self, fun, it):
+        for obj in self.pool.imap(fun, self._limit(it)):
+            self.processed -= 1
+            yield obj
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.pool.close()
+
+
+class ExceptionBox:
+    """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
+    Used in conjunction with `remember_exception`."""
+    def __init__(self):
+        self.exception = None
+
+    def raise_if_set(self):
+        if self.exception is not None:
+            exception = self.exception
+            self.exception = None
+            raise exception  # pylint: disable = raising-bad-type
+
+
+def remember_exception(iterable, exception_box=None):
+    """Wraps a TensorFlow dataset generator for catching its actual exceptions
+    that would otherwise just interrupt iteration without bubbling up."""
+    def do_iterate():
+        try:
+            yield from iterable()
+        except StopIteration:
+            return
+        except Exception as ex:  # pylint: disable = broad-except
+            exception_box.exception = ex
+    return iterable if exception_box is None else do_iterate
diff --git a/util/sample_collections.py b/util/sample_collections.py
new file mode 100644
index 00000000..c1e99dc1
--- /dev/null
+++ b/util/sample_collections.py
@@ -0,0 +1,262 @@
+# -*- coding: utf-8 -*-
+import os
+import csv
+import json
+
+from pathlib import Path
+from functools import partial
+from util.helpers import MEGABYTE, GIGABYTE, Interleaved
+from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
+
+BIG_ENDIAN = 'big'
+INT_SIZE = 4
+BIGINT_SIZE = 2 * INT_SIZE
+MAGIC = b'SAMPLEDB'
+
+BUFFER_SIZE = 1 * MEGABYTE
+CACHE_SIZE = 1 * GIGABYTE
+
+SCHEMA_KEY = 'schema'
+CONTENT_KEY = 'content'
+MIME_TYPE_KEY = 'mime-type'
+MIME_TYPE_TEXT = 'text/plain'
+CONTENT_TYPE_SPEECH = 'speech'
+CONTENT_TYPE_TRANSCRIPT = 'transcript'
+
+
+class LabeledSample(Sample):
+    """In-memory labeled audio sample representing an utterance.
+    Derived from util.audio.Sample and used by sample collection readers and writers."""
+    def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
+        """
+        Creates an in-memory speech sample together with a transcript of the utterance (label).
+        :param audio_type: See util.audio.Sample.__init__.
+        :param raw_data: See util.audio.Sample.__init__.
+        :param transcript: Transcript of the sample's utterance
+        :param audio_format: See util.audio.Sample.__init__.
+ :param sample_id: Tracking ID - typically assigned by collection readers + """ + super().__init__(audio_type, raw_data, audio_format=audio_format) + self.sample_id = sample_id + self.transcript = transcript + + +class DirectSDBWriter: + """Sample collection writer for creating a Sample DB (SDB) file""" + def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None): + self.sdb_filename = sdb_filename + self.id_prefix = sdb_filename if id_prefix is None else id_prefix + if audio_type not in SERIALIZABLE_AUDIO_TYPES: + raise ValueError('Audio type "{}" not supported'.format(audio_type)) + self.audio_type = audio_type + self.sdb_file = open(sdb_filename, 'wb', buffering=buffering) + self.offsets = [] + self.num_samples = 0 + + self.sdb_file.write(MAGIC) + + meta_data = { + SCHEMA_KEY: [ + {CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}, + {CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT} + ] + } + meta_data = json.dumps(meta_data).encode() + self.write_big_int(len(meta_data)) + self.sdb_file.write(meta_data) + + self.offset_samples = self.sdb_file.tell() + self.sdb_file.seek(2 * BIGINT_SIZE, 1) + + def write_int(self, n): + return self.sdb_file.write(n.to_bytes(INT_SIZE, BIG_ENDIAN)) + + def write_big_int(self, n): + return self.sdb_file.write(n.to_bytes(BIGINT_SIZE, BIG_ENDIAN)) + + def __enter__(self): + return self + + def add(self, sample): + def to_bytes(n): + return n.to_bytes(INT_SIZE, BIG_ENDIAN) + sample.change_audio_type(self.audio_type) + opus = sample.audio.getbuffer() + opus_len = to_bytes(len(opus)) + transcript = sample.transcript.encode() + transcript_len = to_bytes(len(transcript)) + entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript)) + buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript]) + self.offsets.append(self.sdb_file.tell()) + self.sdb_file.write(buffer) + sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples) + self.num_samples += 1 + return sample.sample_id + + def close(self): + if self.sdb_file is None: + return + offset_index = self.sdb_file.tell() + self.sdb_file.seek(self.offset_samples) + self.write_big_int(offset_index - self.offset_samples - BIGINT_SIZE) + self.write_big_int(self.num_samples) + + self.sdb_file.seek(offset_index + BIGINT_SIZE) + self.write_big_int(self.num_samples) + for offset in self.offsets: + self.write_big_int(offset) + offset_end = self.sdb_file.tell() + self.sdb_file.seek(offset_index) + self.write_big_int(offset_end - offset_index - BIGINT_SIZE) + self.sdb_file.close() + self.sdb_file = None + + def __len__(self): + return len(self.offsets) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class SDB: # pylint: disable=too-many-instance-attributes + """Sample collection reader for reading a Sample DB (SDB) file""" + def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None): + self.sdb_filename = sdb_filename + self.id_prefix = sdb_filename if id_prefix is None else id_prefix + self.sdb_file = open(sdb_filename, 'rb', buffering=buffering) + self.offsets = [] + if self.sdb_file.read(len(MAGIC)) != MAGIC: + raise RuntimeError('No Sample Database') + meta_chunk_len = self.read_big_int() + self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode()) + if SCHEMA_KEY not in self.meta: + raise RuntimeError('Missing schema') + self.schema = self.meta[SCHEMA_KEY] + + speech_columns = self.find_columns(content=CONTENT_TYPE_SPEECH, 
mime_type=SERIALIZABLE_AUDIO_TYPES) + if not speech_columns: + raise RuntimeError('No speech data (missing in schema)') + self.speech_index = speech_columns[0] + self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY] + + transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT) + if not transcript_columns: + raise RuntimeError('No transcript data (missing in schema)') + self.transcript_index = transcript_columns[0] + + sample_chunk_len = self.read_big_int() + self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1) + num_samples = self.read_big_int() + for _ in range(num_samples): + self.offsets.append(self.read_big_int()) + + def read_int(self): + return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN) + + def read_big_int(self): + return int.from_bytes(self.sdb_file.read(BIGINT_SIZE), BIG_ENDIAN) + + def find_columns(self, content=None, mime_type=None): + criteria = [] + if content is not None: + criteria.append((CONTENT_KEY, content)) + if mime_type is not None: + criteria.append((MIME_TYPE_KEY, mime_type)) + if len(criteria) == 0: + raise ValueError('At least one of "content" or "mime-type" has to be provided') + matches = [] + for index, column in enumerate(self.schema): + matched = 0 + for field, value in criteria: + if column[field] == value or (isinstance(value, list) and column[field] in value): + matched += 1 + if matched == len(criteria): + matches.append(index) + return matches + + def read_row(self, row_index, *columns): + columns = list(columns) + column_data = [None] * len(columns) + found = 0 + if not 0 <= row_index < len(self.offsets): + raise ValueError('Wrong sample index: {} - has to be between 0 and {}' + .format(row_index, len(self.offsets) - 1)) + self.sdb_file.seek(self.offsets[row_index] + INT_SIZE) + for index in range(len(self.schema)): + chunk_len = self.read_int() + if index in columns: + column_data[columns.index(index)] = self.sdb_file.read(chunk_len) + found += 1 + if found == len(columns): + return tuple(column_data) + else: + self.sdb_file.seek(chunk_len, 1) + return tuple(column_data) + + def __getitem__(self, i): + audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index) + transcript = transcript.decode() + sample_id = '{}:{}'.format(self.id_prefix, i) + return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id) + + def __iter__(self): + for i in range(len(self.offsets)): + yield self[i] + + def __len__(self): + return len(self.offsets) + + def close(self): + if self.sdb_file is not None: + self.sdb_file.close() + + def __del__(self): + self.close() + + +class CSV: + """Sample collection reader for reading a DeepSpeech CSV file""" + def __init__(self, csv_filename): + self.csv_filename = csv_filename + self.rows = [] + csv_dir = Path(csv_filename).parent + with open(csv_filename, 'r', encoding='utf8') as csv_file: + reader = csv.DictReader(csv_file) + for row in reader: + wav_filename = Path(row['wav_filename']) + if not wav_filename.is_absolute(): + wav_filename = csv_dir / wav_filename + self.rows.append((str(wav_filename), int(row['wav_filesize']), row['transcript'])) + self.rows.sort(key=lambda r: r[1]) + + def __getitem__(self, i): + wav_filename, _, transcript = self.rows[i] + with open(wav_filename, 'rb') as wav_file: + return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), transcript, sample_id=wav_filename) + + def __iter__(self): + for i in range(len(self.rows)): + yield self[i] + + def __len__(self): + return len(self.rows) + + +def 
samples_from_file(filename, buffering=BUFFER_SIZE):
+    """Returns an iterable of LabeledSample objects loaded from a file."""
+    ext = os.path.splitext(filename)[1].lower()
+    if ext == '.sdb':
+        return SDB(filename, buffering=buffering)
+    if ext == '.csv':
+        return CSV(filename)
+    raise ValueError('Unknown file type: "{}"'.format(ext))
+
+
+def samples_from_files(filenames, buffering=BUFFER_SIZE):
+    """Returns an iterable of LabeledSample objects from a list of files."""
+    if len(filenames) == 0:
+        raise ValueError('No files')
+    if len(filenames) == 1:
+        return samples_from_file(filenames[0], buffering=buffering)
+    cols = list(map(partial(samples_from_file, buffering=buffering), filenames))
+    return Interleaved(*cols, key=lambda s: s.duration)
diff --git a/util/text.py b/util/text.py
index d9a67b96..af958191 100644
--- a/util/text.py
+++ b/util/text.py
@@ -4,7 +4,6 @@
 import numpy as np
 import re
 import struct
 
-from util.flags import FLAGS
 from six.moves import range
 
 class Alphabet(object):
@@ -120,19 +119,22 @@ class UTF8Alphabet(object):
         return ''
 
 
-def text_to_char_array(series, alphabet):
+def text_to_char_array(transcript, alphabet, context=''):
     r"""
-    Given a Pandas Series containing transcript string, map characters to
+    Given a transcript string, map characters to
     integers and return a numpy array representing the processed string.
+    The `context` string is included in raised exceptions to identify the offending sample.
     """
     try:
-        transcript = np.asarray(alphabet.encode(series['transcript']))
+        transcript = alphabet.encode(transcript)
         if len(transcript) == 0:
-            raise ValueError('While processing: {}\nFound an empty transcript! You must include a transcript for all training data.'.format(series['wav_filename']))
+            raise ValueError('While processing {}: Found an empty transcript! '
+                             'You must include a transcript for all training data.'
+                             .format(context))
         return transcript
     except KeyError as e:
         # Provide the sample context (typically its sample ID) for alphabet errors
-        raise ValueError('While processing: {}\n{}'.format(series['wav_filename'], e))
+        raise ValueError('While processing: {}\n{}'.format(context, e))
 
 # The following code is from: http://hetland.org/coding/python/levenshtein.py
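Note for reviewers: the SDB container format introduced by this patch is only defined implicitly by DirectSDBWriter and SDB above. The following is a minimal, standalone reading sketch derived from that code. It assumes an SDB produced by DirectSDBWriter with its default two-column schema (speech first, transcript second); the helper names read_uint and inspect_sdb are illustrative and not part of the changeset.

# All integers are big-endian and unsigned. On-disk layout:
#   b'SAMPLEDB'                                                 magic
#   [8-byte meta length][JSON metadata carrying the "schema" list]
#   [8-byte sample-chunk length][8-byte sample count][sample entries...]
#   [8-byte index-chunk length][8-byte sample count][8-byte absolute offset per sample]
# Each sample entry: [4-byte entry length], then per schema column
# [4-byte chunk length][chunk bytes] - by default audio first, then transcript.
import json

def read_uint(f, num_bytes):
    return int.from_bytes(f.read(num_bytes), 'big', signed=False)

def inspect_sdb(sdb_filename):
    with open(sdb_filename, 'rb') as sdb:
        if sdb.read(8) != b'SAMPLEDB':
            raise RuntimeError('No Sample Database')
        meta = json.loads(sdb.read(read_uint(sdb, 8)).decode())
        sample_chunk_len = read_uint(sdb, 8)  # includes the 8-byte sample count field
        num_samples = read_uint(sdb, 8)
        sdb.seek(sample_chunk_len - 8, 1)     # skip over all sample entries
        index_chunk_len = read_uint(sdb, 8)
        assert index_chunk_len == 8 * (1 + num_samples)
        assert read_uint(sdb, 8) == num_samples
        offsets = [read_uint(sdb, 8) for _ in range(num_samples)]
        print('schema:', meta[SCHEMA_KEY if False else 'schema'])  # 'schema' == SCHEMA_KEY above
        print('samples:', num_samples, '- first entry at byte offset:', offsets[0] if offsets else None)

In normal use the same layout is produced and consumed through the new tooling instead, e.g. "python3 bin/build_sdb.py data/smoke_test/ldc93s1.csv data/smoke_test/ldc93s1.sdb" to build an SDB and "python3 bin/play.py data/smoke_test/ldc93s1.sdb --random" to audit its samples, as the smoke-test scripts above do.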