Revive transcribe.py

Update to use Coqpit based config handling, fix multiprocesing setup, and add CI coverage.
This commit is contained in:
Reuben Morais 2021-11-18 12:37:12 +01:00
parent 419b15b72a
commit efdaa61e2c
10 changed files with 372 additions and 259 deletions

View File

@ -808,7 +808,7 @@ jobs:
- run: |
mkdir -p ${CI_ARTIFACTS_DIR} || true
- run: |
sudo apt-get install -y --no-install-recommends libopus0
sudo apt-get install -y --no-install-recommends libopus0 sox
- name: Run extra training tests
run: |
python -m pip install coqui_stt_ctcdecoder-*.whl

View File

@ -8,14 +8,14 @@ from coqui_stt_training.evaluate import test
# only one GPU for only one training sample
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
download_ldc("data/ldc93s1")
download_ldc("data/smoke_test")
initialize_globals_from_args(
load_train="init",
alphabet_config_path="data/alphabet.txt",
train_files=["data/ldc93s1/ldc93s1.csv"],
dev_files=["data/ldc93s1/ldc93s1.csv"],
test_files=["data/ldc93s1/ldc93s1.csv"],
train_files=["data/smoke_test/ldc93s1.csv"],
dev_files=["data/smoke_test/ldc93s1.csv"],
test_files=["data/smoke_test/ldc93s1.csv"],
augment=["time_mask"],
n_hidden=100,
epochs=200,

View File

@ -5,9 +5,9 @@ if [ ! -f train.py ]; then
exit 1
fi;
if [ ! -f "data/ldc93s1/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/ldc93s1."
python -u bin/import_ldc93s1.py ./data/ldc93s1
if [ ! -f "data/smoke_test/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/smoke_test."
python -u bin/import_ldc93s1.py ./data/smoke_test
fi;
if [ -d "${COMPUTE_KEEP_DIR}" ]; then
@ -23,8 +23,8 @@ export CUDA_VISIBLE_DEVICES=0
python -m coqui_stt_training.train \
--alphabet_config_path "data/alphabet.txt" \
--show_progressbar false \
--train_files data/ldc93s1/ldc93s1.csv \
--test_files data/ldc93s1/ldc93s1.csv \
--train_files data/smoke_test/ldc93s1.csv \
--test_files data/smoke_test/ldc93s1.csv \
--train_batch_size 1 \
--test_batch_size 1 \
--n_hidden 100 \

View File

@ -16,7 +16,7 @@ mkdir -p /tmp/train_tflite || true
set -o pipefail
python -m pip install --upgrade pip setuptools wheel | cat
python -m pip install --upgrade . | cat
python -m pip install --upgrade ".[transcribe]" | cat
set +o pipefail
# Prepare correct arguments for training
@ -72,3 +72,20 @@ time python ./bin/run-ldc93s1.py
# Training graph inference
time ./bin/run-ci-ldc93s1_singleshotinference.sh
# transcribe module
time python -m coqui_stt_training.transcribe \
--src "data/smoke_test/LDC93S1.wav" \
--dst ${CI_ARTIFACTS_DIR}/transcribe.log \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
#TODO: investigate why this is hanging in CI
#mkdir /tmp/transcribe_dir
#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
#time python -m coqui_stt_training.transcribe \
# --src "/tmp/transcribe_dir/" \
# --n_hidden 100 \
# --scorer_path "data/smoke_test/pruned_lm.scorer"
#
#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done

View File

@ -78,8 +78,8 @@
"def download_sample_data():\n",
" data_dir=\"english/\"\n",
" # Download data + alphabet\n",
" audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav\")\n",
" transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.txt\")\n",
" audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.wav\")\n",
" transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.txt\")\n",
" alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n",
" # Format data\n",
" with open(transcript_file, \"r\") as fin:\n",

View File

@ -69,6 +69,9 @@ def main():
python_requires=">=3.5, <4",
install_requires=install_requires,
include_package_data=True,
extras_require={
"transcribe": ["webrtcvad"],
},
)

View File

@ -0,0 +1,315 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This script is structured in a weird way, with delayed imports. This is due
# to the use of multiprocessing. TensorFlow cannot handle forking, and even with
# the spawn strategy set to "spawn" it still leads to weird problems, so we
# restructure the code so that TensorFlow is only imported inside the child
# processes.
import os
import sys
import glob
import itertools
import json
import multiprocessing
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, field
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from tqdm import tqdm
def fail(message, code=1):
print(f"E {message}")
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
log_level_index = (
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
)
desired_log_level = (
sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import Config
from coqui_stt_training.util.feeding import split_audio_file
initialize_transcribe_config()
scorer = None
if Config.scorer_path:
scorer = Scorer(
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
)
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=Config.batch_size,
aggressiveness=Config.vad_aggressiveness,
outlier_duration_ms=Config.outlier_duration_ms,
outlier_batch_size=Config.outlier_batch_size,
)
iterator = tfv1.data.make_one_shot_iterator(data_set)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
transcripts = []
while True:
try:
starts, ends, batch_logits, batch_lengths = session.run(
[batch_time_start, batch_time_end, transposed, batch_x_len]
)
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(
batch_logits,
batch_lengths,
Config.alphabet,
Config.beam_width,
num_processes=num_processes,
scorer=scorer,
)
decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded))
transcripts.sort(key=lambda t: t[0])
transcripts = [
{"start": int(start), "end": int(end), "transcript": transcript}
for start, end, transcript in transcripts
]
with open(tlog_path, "w") as tlog_file:
json.dump(transcripts, tlog_file, default=float)
def step_function(job):
""" Wrap transcribe_file to unpack arguments from a single tuple """
idx, src, dst = job
transcribe_file(src, dst)
return idx, src, dst
def transcribe_many(src_paths, dst_paths):
from coqui_stt_training.util.config import Config, log_progress
pool = Pool(processes=min(cpu_count(), len(src_paths)))
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths)
process_iterable = tqdm(
pool.imap_unordered(step_function, jobs),
desc="Transcribing files",
total=len(src_paths),
disable=not Config.show_progressbar,
)
for result in process_iterable:
idx, src, dst = result
log_progress(
f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
)
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
print(f'I Transcribed file "{src_path}" to "{dst_path}"')
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not os.path.isabs(spec_path):
spec_path = os.path.join(base_path, spec_path)
return spec_path
def transcribe():
from coqui_stt_training.util.config import Config
initialize_transcribe_config()
if not Config.src or not os.path.exists(Config.src):
# path not given or non-existant
fail(
"You have to specify which file or catalog to transcribe via the --src flag."
)
else:
# path given and exists
src_path = os.path.abspath(Config.src)
if os.path.isfile(src_path):
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not Config.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
# Transcribe one file
dst_path = (
os.path.abspath(Config.dst)
if Config.dst
else os.path.splitext(src_path)[0] + ".tlog"
)
if os.path.isfile(dst_path):
if Config.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if Config.recursive:
wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
else:
wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
def initialize_transcribe_config():
from coqui_stt_training.util.config import (
BaseSttConfig,
initialize_globals_from_instance,
)
@dataclass
class TranscribeConfig(BaseSttConfig):
src: str = field(
default="",
metadata=dict(
help="Source path to an audio file or directory or catalog file. "
"Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of ".wav".'
),
)
dst: str = field(
default="",
metadata=dict(
help="path for writing the transcription log or logs (.tlog). "
"If --src is a directory, this one also has to be a directory "
"and the required sub-dir tree of --src will get replicated."
),
)
recursive: bool = field(
default=False,
metadata=dict(help="scan source directory recursively for audio"),
)
force: bool = field(
default=False,
metadata=dict(
help="Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)"
),
)
vad_aggressiveness: int = field(
default=3,
metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
)
batch_size: int = field(
default=40,
metadata=dict(help="Default batch size"),
)
outlier_duration_ms: int = field(
default=10000,
metadata=dict(
help="Duration in ms after which samples are considered outliers"
),
)
outlier_batch_size: int = field(
default=1,
metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
)
def __post_init__(self):
if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
raise RuntimeError(
"Parameter --dst not supported if --src points to a catalog"
)
if os.path.isdir(self.src):
if self.dst:
raise RuntimeError(
"Destination path not supported for batch decoding jobs."
)
super().__post_init__()
config = TranscribeConfig.init_from_argparse(arg_prefix="")
initialize_globals_from_instance(config)
def main():
from coqui_stt_training.util.helpers import check_ctcdecoder_version
try:
import webrtcvad
except ImportError:
print(
"E transcribe module requires webrtcvad, which cannot be imported. Install with pip install webrtcvad"
)
sys.exit(1)
check_ctcdecoder_version()
transcribe()
if __name__ == "__main__":
main()

View File

@ -75,9 +75,12 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
init_vars.add(v)
load_vars -= init_vars
log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
for v in sorted(load_vars, key=lambda v: v.op.name):
log_info("Loading variable from checkpoint: %s" % (v.op.name))
v.load(ckpt.get_tensor(v.op.name), session=session)
log_info(f"Getting tensor from variable: {v.op.name}")
tensor = ckpt.get_tensor(v.op.name)
log_info(f"Loading tensor from checkpoint: {v.op.name}")
v.load(tensor, session=session)
for v in sorted(init_vars, key=lambda v: v.op.name):
log_info("Initializing variable: %s" % (v.op.name))

View File

@ -37,7 +37,7 @@ Config = _ConfigSingleton() # pylint: disable=invalid-name
@dataclass
class _SttConfig(Coqpit):
class BaseSttConfig(Coqpit):
def __post_init__(self):
# Augmentations
self.augmentations = parse_augmentations(self.augment)
@ -835,16 +835,22 @@ class _SttConfig(Coqpit):
def initialize_globals_from_cli():
c = _SttConfig.init_from_argparse(arg_prefix="")
c = BaseSttConfig.init_from_argparse(arg_prefix="")
_ConfigSingleton._config = c # pylint: disable=protected-access
def initialize_globals_from_args(**override_args):
# Update Config with new args
c = _SttConfig(**override_args)
c = BaseSttConfig(**override_args)
_ConfigSingleton._config = c # pylint: disable=protected-access
def initialize_globals_from_instance(config):
""" Initialize Config singleton from an existing Config instance (or subclass) """
assert isinstance(config, BaseSttConfig)
_ConfigSingleton._config = config # pylint: disable=protected-access
# Logging functions
# =================

View File

@ -2,246 +2,15 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import os
import sys
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow.compat.v1.logging as tflogging
import tensorflow as tf
tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger("sox").setLevel(logging.ERROR)
import glob
from multiprocessing import Process, cpu_count
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.config import Config, initialize_globals_from_cli
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import (
create_progressbar,
log_error,
log_info,
log_progress,
)
def fail(message, code=1):
log_error(message)
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
from coqui_stt_training.train import ( # pylint: disable=cyclic-import,import-outside-toplevel
create_model,
)
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
initialize_globals_from_cli()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=FLAGS.batch_size,
aggressiveness=FLAGS.vad_aggressiveness,
outlier_duration_ms=FLAGS.outlier_duration_ms,
outlier_batch_size=FLAGS.outlier_batch_size,
)
iterator = tf.data.Iterator.from_structure(
data_set.output_types,
data_set.output_shapes,
output_classes=data_set.output_classes,
)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
session.run(iterator.make_initializer(data_set))
transcripts = []
while True:
try:
starts, ends, batch_logits, batch_lengths = session.run(
[batch_time_start, batch_time_end, transposed, batch_x_len]
)
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(
batch_logits,
batch_lengths,
Config.alphabet,
FLAGS.beam_width,
num_processes=num_processes,
scorer=scorer,
)
decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded))
transcripts.sort(key=lambda t: t[0])
transcripts = [
{"start": int(start), "end": int(end), "transcript": transcript}
for start, end, transcript in transcripts
]
with open(tlog_path, "w") as tlog_file:
json.dump(transcripts, tlog_file, default=float)
def transcribe_many(src_paths, dst_paths):
pbar = create_progressbar(
prefix="Transcribing files | ", max_value=len(src_paths)
).start()
for i in range(len(src_paths)):
p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
p.start()
p.join()
log_progress(
'Transcribed file {} of {} from "{}" to "{}"'.format(
i + 1, len(src_paths), src_paths[i], dst_paths[i]
)
)
pbar.update(i)
pbar.finish()
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not os.path.isabs(spec_path):
spec_path = os.path.join(base_path, spec_path)
return spec_path
def main(_):
if not FLAGS.src or not os.path.exists(FLAGS.src):
# path not given or non-existant
fail(
"You have to specify which file or catalog to transcribe via the --src flag."
)
else:
# path given and exists
src_path = os.path.abspath(FLAGS.src)
if os.path.isfile(src_path):
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
if FLAGS.dst:
fail("Parameter --dst not supported if --src points to a catalog")
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not FLAGS.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
# Transcribe one file
dst_path = (
os.path.abspath(FLAGS.dst)
if FLAGS.dst
else os.path.splitext(src_path)[0] + ".tlog"
)
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if FLAGS.dst:
fail("Destination file not supported for batch decoding jobs.")
else:
if not FLAGS.recursive:
print(
"If you wish to recursively scan --src, then you must use --recursive"
)
wav_paths = glob.glob(src_path + "/*.wav")
else:
wav_paths = glob.glob(src_path + "/**/*.wav")
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
if __name__ == "__main__":
create_flags()
tf.app.flags.DEFINE_string(
"src",
"",
"Source path to an audio file or directory or catalog file."
"Catalog files should be formatted from DSAlign. A directory will"
"be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be "
"written in-place using the source filenames with "
'suffix ".tlog" instead of ".wav".',
print(
"Using the top level transcribe.py script is deprecated and will be removed "
"in a future release. Instead use: python -m coqui_stt_training.transcribe"
)
tf.app.flags.DEFINE_string(
"dst",
"",
"path for writing the transcription log or logs (.tlog). "
"If --src is a directory, this one also has to be a directory "
"and the required sub-dir tree of --src will get replicated.",
)
tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively")
tf.app.flags.DEFINE_boolean(
"force",
False,
"Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)",
)
tf.app.flags.DEFINE_integer(
"vad_aggressiveness",
3,
"How aggressive (0=lowest, 3=highest) the VAD should " "split audio",
)
tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size")
tf.app.flags.DEFINE_float(
"outlier_duration_ms",
10000,
"Duration in ms after which samples are considered outliers",
)
tf.app.flags.DEFINE_integer(
"outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)"
)
tf.app.run(main)
try:
from coqui_stt_training import transcribe as stt_transcribe
except ImportError:
print("Training package is not installed. See training documentation.")
raise
stt_transcribe.main()