Revive transcribe.py
Update to use Coqpit-based config handling, fix multiprocessing setup, and add CI coverage.
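The core of the multiprocessing fix is keeping TensorFlow out of the parent process entirely: workers import it lazily inside the function they execute, so no TF/CUDA state is ever inherited across fork or spawn. A minimal sketch of that pattern, for illustration only (the names here are not the module's actual API):

    from multiprocessing import Pool

    def worker(job):
        # TensorFlow is imported here, inside the child process, so the
        # parent never holds TF/CUDA state that fork() would duplicate.
        import tensorflow as tf  # noqa: F401 -- delayed import on purpose

        src, dst = job
        # ... build graph, run session, write results to dst ...
        return dst

    if __name__ == "__main__":
        jobs = [("one.wav", "one.tlog"), ("two.wav", "two.tlog")]
        with Pool(processes=2) as pool:
            for done in pool.imap_unordered(worker, jobs):
                print("finished", done)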
parent 419b15b72a
commit efdaa61e2c

.github/workflows/build-and-test.yml (vendored, 2 changes)
@@ -808,7 +808,7 @@ jobs:
       - run: |
           mkdir -p ${CI_ARTIFACTS_DIR} || true
       - run: |
-          sudo apt-get install -y --no-install-recommends libopus0
+          sudo apt-get install -y --no-install-recommends libopus0 sox
       - name: Run extra training tests
         run: |
           python -m pip install coqui_stt_ctcdecoder-*.whl
@@ -8,14 +8,14 @@ from coqui_stt_training.evaluate import test
 # only one GPU for only one training sample
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"

-download_ldc("data/ldc93s1")
+download_ldc("data/smoke_test")

 initialize_globals_from_args(
     load_train="init",
     alphabet_config_path="data/alphabet.txt",
-    train_files=["data/ldc93s1/ldc93s1.csv"],
-    dev_files=["data/ldc93s1/ldc93s1.csv"],
-    test_files=["data/ldc93s1/ldc93s1.csv"],
+    train_files=["data/smoke_test/ldc93s1.csv"],
+    dev_files=["data/smoke_test/ldc93s1.csv"],
+    test_files=["data/smoke_test/ldc93s1.csv"],
     augment=["time_mask"],
     n_hidden=100,
     epochs=200,
@@ -5,9 +5,9 @@ if [ ! -f train.py ]; then
     exit 1
 fi;

-if [ ! -f "data/ldc93s1/ldc93s1.csv" ]; then
-    echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/ldc93s1."
-    python -u bin/import_ldc93s1.py ./data/ldc93s1
+if [ ! -f "data/smoke_test/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/smoke_test."
+    python -u bin/import_ldc93s1.py ./data/smoke_test
 fi;

 if [ -d "${COMPUTE_KEEP_DIR}" ]; then
@@ -23,8 +23,8 @@ export CUDA_VISIBLE_DEVICES=0
 python -m coqui_stt_training.train \
     --alphabet_config_path "data/alphabet.txt" \
     --show_progressbar false \
-    --train_files data/ldc93s1/ldc93s1.csv \
-    --test_files data/ldc93s1/ldc93s1.csv \
+    --train_files data/smoke_test/ldc93s1.csv \
+    --test_files data/smoke_test/ldc93s1.csv \
     --train_batch_size 1 \
     --test_batch_size 1 \
     --n_hidden 100 \
@@ -16,7 +16,7 @@ mkdir -p /tmp/train_tflite || true

 set -o pipefail
 python -m pip install --upgrade pip setuptools wheel | cat
-python -m pip install --upgrade . | cat
+python -m pip install --upgrade ".[transcribe]" | cat
 set +o pipefail

 # Prepare correct arguments for training
@@ -72,3 +72,20 @@ time python ./bin/run-ldc93s1.py

 # Training graph inference
 time ./bin/run-ci-ldc93s1_singleshotinference.sh
+
+# transcribe module
+time python -m coqui_stt_training.transcribe \
+    --src "data/smoke_test/LDC93S1.wav" \
+    --dst ${CI_ARTIFACTS_DIR}/transcribe.log \
+    --n_hidden 100 \
+    --scorer_path "data/smoke_test/pruned_lm.scorer"
+
+#TODO: investigate why this is hanging in CI
+#mkdir /tmp/transcribe_dir
+#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
+#time python -m coqui_stt_training.transcribe \
+#    --src "/tmp/transcribe_dir/" \
+#    --n_hidden 100 \
+#    --scorer_path "data/smoke_test/pruned_lm.scorer"
+#
+#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
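For reference, the transcript log written by the new module is plain JSON: a list of objects with "start", "end", and "transcript" keys, sorted by start offset (see transcribe_file in the new module below). A minimal consumer sketch, assuming a log written to transcribe.log:

    import json

    # Read a transcript log produced by coqui_stt_training.transcribe.
    # Each entry carries start/end offsets and the decoded text.
    with open("transcribe.log") as tlog_file:
        segments = json.load(tlog_file)

    for seg in segments:
        print(f'{seg["start"]}-{seg["end"]}: {seg["transcript"]}')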
@@ -78,8 +78,8 @@
     "def download_sample_data():\n",
     "    data_dir=\"english/\"\n",
     "    # Download data + alphabet\n",
-    "    audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav\")\n",
-    "    transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.txt\")\n",
+    "    audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.wav\")\n",
+    "    transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.txt\")\n",
     "    alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n",
     "    # Format data\n",
     "    with open(transcript_file, \"r\") as fin:\n",
setup.py (3 changes)
@@ -69,6 +69,9 @@ def main():
         python_requires=">=3.5, <4",
         install_requires=install_requires,
         include_package_data=True,
+        extras_require={
+            "transcribe": ["webrtcvad"],
+        },
     )

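With this, the VAD dependency becomes opt-in: pip install ".[transcribe]" (as the CI change above now does) pulls in webrtcvad, while a plain install stays lean. The new module guards the optional import at runtime; the same pattern in isolation, mirroring the check in its main():

    import sys

    # Optional-dependency guard: webrtcvad is only needed by the transcribe
    # entry point, so it is declared under extras_require and checked here.
    try:
        import webrtcvad  # noqa: F401
    except ImportError:
        print(
            "E transcribe module requires webrtcvad, which cannot be imported. "
            "Install with pip install webrtcvad"
        )
        sys.exit(1)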
training/coqui_stt_training/transcribe.py (new executable file, 315 lines)
@@ -0,0 +1,315 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# This script is structured in a weird way, with delayed imports. This is due
+# to the use of multiprocessing. TensorFlow cannot handle forking, and even
+# with the start method set to "spawn" it still leads to weird problems, so we
+# restructure the code so that TensorFlow is only imported inside the child
+# processes.
+
+import os
+import sys
+import glob
+import itertools
+import json
+import multiprocessing
+from multiprocessing import Pool, cpu_count
+from dataclasses import dataclass, field
+
+from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
+from tqdm import tqdm
+
+
+def fail(message, code=1):
+    print(f"E {message}")
+    sys.exit(code)
+
+
+def transcribe_file(audio_path, tlog_path):
+    log_level_index = (
+        sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
+    )
+    desired_log_level = (
+        sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
+    )
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
+
+    import tensorflow as tf
+    import tensorflow.compat.v1 as tfv1
+
+    from coqui_stt_training.train import create_model
+    from coqui_stt_training.util.audio import AudioFile
+    from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
+    from coqui_stt_training.util.config import Config
+    from coqui_stt_training.util.feeding import split_audio_file
+
+    initialize_transcribe_config()
+
+    scorer = None
+    if Config.scorer_path:
+        scorer = Scorer(
+            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
+        )
+
+    try:
+        num_processes = cpu_count()
+    except NotImplementedError:
+        num_processes = 1
+
+    with AudioFile(audio_path, as_path=True) as wav_path:
+        data_set = split_audio_file(
+            wav_path,
+            batch_size=Config.batch_size,
+            aggressiveness=Config.vad_aggressiveness,
+            outlier_duration_ms=Config.outlier_duration_ms,
+            outlier_batch_size=Config.outlier_batch_size,
+        )
+        iterator = tfv1.data.make_one_shot_iterator(data_set)
+        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
+        no_dropout = [None] * 6
+        logits, _ = create_model(
+            batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
+        )
+        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
+        tf.train.get_or_create_global_step()
+        with tf.Session(config=Config.session_config) as session:
+            load_graph_for_evaluation(session)
+            transcripts = []
+            while True:
+                try:
+                    starts, ends, batch_logits, batch_lengths = session.run(
+                        [batch_time_start, batch_time_end, transposed, batch_x_len]
+                    )
+                except tf.errors.OutOfRangeError:
+                    break
+                decoded = ctc_beam_search_decoder_batch(
+                    batch_logits,
+                    batch_lengths,
+                    Config.alphabet,
+                    Config.beam_width,
+                    num_processes=num_processes,
+                    scorer=scorer,
+                )
+                decoded = list(d[0][1] for d in decoded)
+                transcripts.extend(zip(starts, ends, decoded))
+            transcripts.sort(key=lambda t: t[0])
+            transcripts = [
+                {"start": int(start), "end": int(end), "transcript": transcript}
+                for start, end, transcript in transcripts
+            ]
+            with open(tlog_path, "w") as tlog_file:
+                json.dump(transcripts, tlog_file, default=float)
+
+
+def step_function(job):
+    """Wrap transcribe_file to unpack arguments from a single tuple."""
+    idx, src, dst = job
+    transcribe_file(src, dst)
+    return idx, src, dst
+
+
+def transcribe_many(src_paths, dst_paths):
+    from coqui_stt_training.util.config import Config, log_progress
+
+    pool = Pool(processes=min(cpu_count(), len(src_paths)))
+
+    # Create list of items to be processed: [(i, src_paths[i], dst_paths[i])]
+    jobs = zip(itertools.count(), src_paths, dst_paths)
+
+    process_iterable = tqdm(
+        pool.imap_unordered(step_function, jobs),
+        desc="Transcribing files",
+        total=len(src_paths),
+        disable=not Config.show_progressbar,
+    )
+
+    for result in process_iterable:
+        idx, src, dst = result
+        log_progress(
+            f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
+        )
+
+
+def transcribe_one(src_path, dst_path):
+    transcribe_file(src_path, dst_path)
+    print(f'I Transcribed file "{src_path}" to "{dst_path}"')
+
+
+def resolve(base_path, spec_path):
+    if spec_path is None:
+        return None
+    if not os.path.isabs(spec_path):
+        spec_path = os.path.join(base_path, spec_path)
+    return spec_path
+
+
+def transcribe():
+    from coqui_stt_training.util.config import Config
+
+    initialize_transcribe_config()
+
+    if not Config.src or not os.path.exists(Config.src):
+        # path not given or non-existent
+        fail(
+            "You have to specify which file or catalog to transcribe via the --src flag."
+        )
+    else:
+        # path given and exists
+        src_path = os.path.abspath(Config.src)
+        if os.path.isfile(src_path):
+            if src_path.endswith(".catalog"):
+                # Transcribe batch of files via ".catalog" file (from DSAlign)
+                catalog_dir = os.path.dirname(src_path)
+                with open(src_path, "r") as catalog_file:
+                    catalog_entries = json.load(catalog_file)
+                catalog_entries = [
+                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
+                    for e in catalog_entries
+                ]
+                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
+                    fail("Missing source file(s) in catalog")
+                if not Config.force and any(
+                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
+                ):
+                    fail(
+                        "Destination file(s) from catalog already exist, use --force to overwrite"
+                    )
+                if any(
+                    map(
+                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
+                        catalog_entries,
+                    )
+                ):
+                    fail("Missing destination directory for at least one catalog entry")
+                src_paths, dst_paths = zip(*catalog_entries)
+                transcribe_many(src_paths, dst_paths)
+            else:
+                # Transcribe one file
+                dst_path = (
+                    os.path.abspath(Config.dst)
+                    if Config.dst
+                    else os.path.splitext(src_path)[0] + ".tlog"
+                )
+                if os.path.isfile(dst_path):
+                    if Config.force:
+                        transcribe_one(src_path, dst_path)
+                    else:
+                        fail(
+                            'Destination file "{}" already exists - use --force to overwrite'.format(
+                                dst_path
+                            ),
+                            code=0,
+                        )
+                elif os.path.isdir(os.path.dirname(dst_path)):
+                    transcribe_one(src_path, dst_path)
+                else:
+                    fail("Missing destination directory")
+        elif os.path.isdir(src_path):
+            # Transcribe all files in dir
+            print("Transcribing all WAV files in --src")
+            if Config.recursive:
+                # recursive=True is required for "**" to match subdirectories
+                wav_paths = glob.glob(
+                    os.path.join(src_path, "**", "*.wav"), recursive=True
+                )
+            else:
+                wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
+            dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
+            transcribe_many(wav_paths, dst_paths)
+
+
+def initialize_transcribe_config():
+    from coqui_stt_training.util.config import (
+        BaseSttConfig,
+        initialize_globals_from_instance,
+    )
+
+    @dataclass
+    class TranscribeConfig(BaseSttConfig):
+        src: str = field(
+            default="",
+            metadata=dict(
+                help="Source path to an audio file or directory or catalog file. "
+                "Catalog files should be formatted from DSAlign. A directory "
+                "will be recursively searched for audio. If --dst not set, "
+                "transcription logs (.tlog) will be written in-place using the "
+                'source filenames with suffix ".tlog" instead of ".wav".'
+            ),
+        )
+
+        dst: str = field(
+            default="",
+            metadata=dict(
+                help="Path for writing the transcription log or logs (.tlog). "
+                "If --src is a directory, this one also has to be a directory "
+                "and the required sub-dir tree of --src will get replicated."
+            ),
+        )
+
+        recursive: bool = field(
+            default=False,
+            metadata=dict(help="Scan source directory recursively for audio"),
+        )
+
+        force: bool = field(
+            default=False,
+            metadata=dict(
+                help="Forces re-transcribing and overwriting of already existing "
+                "transcription logs (.tlog)"
+            ),
+        )
+
+        vad_aggressiveness: int = field(
+            default=3,
+            metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
+        )
+
+        batch_size: int = field(
+            default=40,
+            metadata=dict(help="Default batch size"),
+        )
+
+        outlier_duration_ms: int = field(
+            default=10000,
+            metadata=dict(
+                help="Duration in ms after which samples are considered outliers"
+            ),
+        )
+
+        outlier_batch_size: int = field(
+            default=1,
+            metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
+        )
+
+        def __post_init__(self):
+            if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
+                raise RuntimeError(
+                    "Parameter --dst not supported if --src points to a catalog"
+                )
+
+            if os.path.isdir(self.src):
+                if self.dst:
+                    raise RuntimeError(
+                        "Destination path not supported for batch decoding jobs."
+                    )
+
+            super().__post_init__()
+
+    config = TranscribeConfig.init_from_argparse(arg_prefix="")
+    initialize_globals_from_instance(config)
+
+
+def main():
+    from coqui_stt_training.util.helpers import check_ctcdecoder_version
+
+    try:
+        import webrtcvad
+    except ImportError:
+        print(
+            "E transcribe module requires webrtcvad, which cannot be imported. Install with pip install webrtcvad"
+        )
+        sys.exit(1)
+
+    check_ctcdecoder_version()
+    transcribe()
+
+
+if __name__ == "__main__":
+    main()
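The ".catalog" input accepted by --src is a JSON list of entries with "audio" and "tlog" keys; relative paths are resolved against the catalog's own directory (see resolve() above). A hand-written example, purely illustrative:

    import json

    # Illustrative DSAlign-style catalog: relative paths resolve against the
    # directory containing the .catalog file itself.
    entries = [
        {"audio": "clips/sample-000.wav", "tlog": "clips/sample-000.tlog"},
        {"audio": "clips/sample-001.wav", "tlog": "clips/sample-001.tlog"},
    ]
    with open("data.catalog", "w") as catalog_file:
        json.dump(entries, catalog_file, indent=2)

It would then be consumed with: python -m coqui_stt_training.transcribe --src data.catalog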
@@ -75,9 +75,12 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
             init_vars.add(v)
     load_vars -= init_vars

+    log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
     for v in sorted(load_vars, key=lambda v: v.op.name):
-        log_info("Loading variable from checkpoint: %s" % (v.op.name))
-        v.load(ckpt.get_tensor(v.op.name), session=session)
+        log_info(f"Getting tensor from variable: {v.op.name}")
+        tensor = ckpt.get_tensor(v.op.name)
+        log_info(f"Loading tensor from checkpoint: {v.op.name}")
+        v.load(tensor, session=session)

     for v in sorted(init_vars, key=lambda v: v.op.name):
         log_info("Initializing variable: %s" % (v.op.name))
@@ -37,7 +37,7 @@ Config = _ConfigSingleton()  # pylint: disable=invalid-name


 @dataclass
-class _SttConfig(Coqpit):
+class BaseSttConfig(Coqpit):
     def __post_init__(self):
         # Augmentations
         self.augmentations = parse_augmentations(self.augment)
@@ -835,16 +835,22 @@ class _SttConfig(Coqpit):


 def initialize_globals_from_cli():
-    c = _SttConfig.init_from_argparse(arg_prefix="")
+    c = BaseSttConfig.init_from_argparse(arg_prefix="")
     _ConfigSingleton._config = c  # pylint: disable=protected-access


 def initialize_globals_from_args(**override_args):
     # Update Config with new args
-    c = _SttConfig(**override_args)
+    c = BaseSttConfig(**override_args)
     _ConfigSingleton._config = c  # pylint: disable=protected-access


+def initialize_globals_from_instance(config):
+    """Initialize Config singleton from an existing Config instance (or subclass)."""
+    assert isinstance(config, BaseSttConfig)
+    _ConfigSingleton._config = config  # pylint: disable=protected-access
+
+
 # Logging functions
 # =================
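initialize_globals_from_instance is what lets a tool ship its own Coqpit config subclass and still populate the global Config singleton; TranscribeConfig above is its first user. A minimal sketch of the same pattern (MyToolConfig and report_path are hypothetical, invented for illustration):

    from dataclasses import dataclass, field

    from coqui_stt_training.util.config import (
        BaseSttConfig,
        initialize_globals_from_instance,
    )

    @dataclass
    class MyToolConfig(BaseSttConfig):
        # Hypothetical extra flag; Coqpit turns field metadata into --help text.
        report_path: str = field(
            default="",
            metadata=dict(help="Where to write the tool's report"),
        )

    config = MyToolConfig.init_from_argparse(arg_prefix="")
    initialize_globals_from_instance(config)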
transcribe.py (251 changes)
@@ -2,246 +2,15 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, division, print_function

-import json
-import os
-import sys
-
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-import tensorflow.compat.v1.logging as tflogging
-
-import tensorflow as tf
-
-tflogging.set_verbosity(tflogging.ERROR)
-import logging
-
-logging.getLogger("sox").setLevel(logging.ERROR)
-import glob
-from multiprocessing import Process, cpu_count
-
-from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
-from coqui_stt_training.util.audio import AudioFile
-from coqui_stt_training.util.config import Config, initialize_globals_from_cli
-from coqui_stt_training.util.feeding import split_audio_file
-from coqui_stt_training.util.flags import FLAGS, create_flags
-from coqui_stt_training.util.logging import (
-    create_progressbar,
-    log_error,
-    log_info,
-    log_progress,
-)
-
-
-def fail(message, code=1):
-    log_error(message)
-    sys.exit(code)
-
-
-def transcribe_file(audio_path, tlog_path):
-    from coqui_stt_training.train import (  # pylint: disable=cyclic-import,import-outside-toplevel
-        create_model,
-    )
-    from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
-
-    initialize_globals_from_cli()
-
-    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
-    try:
-        num_processes = cpu_count()
-    except NotImplementedError:
-        num_processes = 1
-    with AudioFile(audio_path, as_path=True) as wav_path:
-        data_set = split_audio_file(
-            wav_path,
-            batch_size=FLAGS.batch_size,
-            aggressiveness=FLAGS.vad_aggressiveness,
-            outlier_duration_ms=FLAGS.outlier_duration_ms,
-            outlier_batch_size=FLAGS.outlier_batch_size,
-        )
-        iterator = tf.data.Iterator.from_structure(
-            data_set.output_types,
-            data_set.output_shapes,
-            output_classes=data_set.output_classes,
-        )
-        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
-        no_dropout = [None] * 6
-        logits, _ = create_model(
-            batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
-        )
-        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
-        tf.train.get_or_create_global_step()
-        with tf.Session(config=Config.session_config) as session:
-            load_graph_for_evaluation(session)
-            session.run(iterator.make_initializer(data_set))
-            transcripts = []
-            while True:
-                try:
-                    starts, ends, batch_logits, batch_lengths = session.run(
-                        [batch_time_start, batch_time_end, transposed, batch_x_len]
-                    )
-                except tf.errors.OutOfRangeError:
-                    break
-                decoded = ctc_beam_search_decoder_batch(
-                    batch_logits,
-                    batch_lengths,
-                    Config.alphabet,
-                    FLAGS.beam_width,
-                    num_processes=num_processes,
-                    scorer=scorer,
-                )
-                decoded = list(d[0][1] for d in decoded)
-                transcripts.extend(zip(starts, ends, decoded))
-            transcripts.sort(key=lambda t: t[0])
-            transcripts = [
-                {"start": int(start), "end": int(end), "transcript": transcript}
-                for start, end, transcript in transcripts
-            ]
-            with open(tlog_path, "w") as tlog_file:
-                json.dump(transcripts, tlog_file, default=float)
-
-
-def transcribe_many(src_paths, dst_paths):
-    pbar = create_progressbar(
-        prefix="Transcribing files | ", max_value=len(src_paths)
-    ).start()
-    for i in range(len(src_paths)):
-        p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
-        p.start()
-        p.join()
-        log_progress(
-            'Transcribed file {} of {} from "{}" to "{}"'.format(
-                i + 1, len(src_paths), src_paths[i], dst_paths[i]
-            )
-        )
-        pbar.update(i)
-    pbar.finish()
-
-
-def transcribe_one(src_path, dst_path):
-    transcribe_file(src_path, dst_path)
-    log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
-
-
-def resolve(base_path, spec_path):
-    if spec_path is None:
-        return None
-    if not os.path.isabs(spec_path):
-        spec_path = os.path.join(base_path, spec_path)
-    return spec_path
-
-
-def main(_):
-    if not FLAGS.src or not os.path.exists(FLAGS.src):
-        # path not given or non-existant
-        fail(
-            "You have to specify which file or catalog to transcribe via the --src flag."
-        )
-    else:
-        # path given and exists
-        src_path = os.path.abspath(FLAGS.src)
-        if os.path.isfile(src_path):
-            if src_path.endswith(".catalog"):
-                # Transcribe batch of files via ".catalog" file (from DSAlign)
-                if FLAGS.dst:
-                    fail("Parameter --dst not supported if --src points to a catalog")
-                catalog_dir = os.path.dirname(src_path)
-                with open(src_path, "r") as catalog_file:
-                    catalog_entries = json.load(catalog_file)
-                catalog_entries = [
-                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
-                    for e in catalog_entries
-                ]
-                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
-                    fail("Missing source file(s) in catalog")
-                if not FLAGS.force and any(
-                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
-                ):
-                    fail(
-                        "Destination file(s) from catalog already existing, use --force for overwriting"
-                    )
-                if any(
-                    map(
-                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
-                        catalog_entries,
-                    )
-                ):
-                    fail("Missing destination directory for at least one catalog entry")
-                src_paths, dst_paths = zip(*paths)
-                transcribe_many(src_paths, dst_paths)
-            else:
-                # Transcribe one file
-                dst_path = (
-                    os.path.abspath(FLAGS.dst)
-                    if FLAGS.dst
-                    else os.path.splitext(src_path)[0] + ".tlog"
-                )
-                if os.path.isfile(dst_path):
-                    if FLAGS.force:
-                        transcribe_one(src_path, dst_path)
-                    else:
-                        fail(
-                            'Destination file "{}" already existing - use --force for overwriting'.format(
-                                dst_path
-                            ),
-                            code=0,
-                        )
-                elif os.path.isdir(os.path.dirname(dst_path)):
-                    transcribe_one(src_path, dst_path)
-                else:
-                    fail("Missing destination directory")
-        elif os.path.isdir(src_path):
-            # Transcribe all files in dir
-            print("Transcribing all WAV files in --src")
-            if FLAGS.dst:
-                fail("Destination file not supported for batch decoding jobs.")
-            else:
-                if not FLAGS.recursive:
-                    print(
-                        "If you wish to recursively scan --src, then you must use --recursive"
-                    )
-                    wav_paths = glob.glob(src_path + "/*.wav")
-                else:
-                    wav_paths = glob.glob(src_path + "/**/*.wav")
-                dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
-                transcribe_many(wav_paths, dst_paths)
-
-
 if __name__ == "__main__":
-    create_flags()
-    tf.app.flags.DEFINE_string(
-        "src",
-        "",
-        "Source path to an audio file or directory or catalog file."
-        "Catalog files should be formatted from DSAlign. A directory will"
-        "be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be "
-        "written in-place using the source filenames with "
-        'suffix ".tlog" instead of ".wav".',
-    )
-    tf.app.flags.DEFINE_string(
-        "dst",
-        "",
-        "path for writing the transcription log or logs (.tlog). "
-        "If --src is a directory, this one also has to be a directory "
-        "and the required sub-dir tree of --src will get replicated.",
-    )
-    tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively")
-    tf.app.flags.DEFINE_boolean(
-        "force",
-        False,
-        "Forces re-transcribing and overwriting of already existing "
-        "transcription logs (.tlog)",
-    )
-    tf.app.flags.DEFINE_integer(
-        "vad_aggressiveness",
-        3,
-        "How aggressive (0=lowest, 3=highest) the VAD should " "split audio",
-    )
-    tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size")
-    tf.app.flags.DEFINE_float(
-        "outlier_duration_ms",
-        10000,
-        "Duration in ms after which samples are considered outliers",
-    )
-    tf.app.flags.DEFINE_integer(
-        "outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)"
-    )
-    tf.app.run(main)
+    print(
+        "Using the top level transcribe.py script is deprecated and will be removed "
+        "in a future release. Instead use: python -m coqui_stt_training.transcribe"
+    )
+    try:
+        from coqui_stt_training import transcribe as stt_transcribe
+    except ImportError:
+        print("Training package is not installed. See training documentation.")
+        raise
+
+    stt_transcribe.main()