Revive transcribe.py
Update to use Coqpit-based config handling, fix multiprocessing setup, and add CI coverage.
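
For reference, a minimal sketch of driving the revived module the same way the new CI step below does. The flags are taken from the diff; the artifact path is illustrative:

    import subprocess

    # Invoke the module entry point added in this commit; flags mirror the CI step.
    subprocess.run(
        [
            "python", "-m", "coqui_stt_training.transcribe",
            "--src", "data/smoke_test/LDC93S1.wav",
            "--dst", "/tmp/artifacts/transcribe.log",
            "--n_hidden", "100",
            "--scorer_path", "data/smoke_test/pruned_lm.scorer",
        ],
        check=True,
    )
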
This commit is contained in:
parent 419b15b72a
commit efdaa61e2c

.github/workflows/build-and-test.yml (vendored, 2 changed lines)
@@ -808,7 +808,7 @@ jobs:
       - run: |
           mkdir -p ${CI_ARTIFACTS_DIR} || true
       - run: |
-          sudo apt-get install -y --no-install-recommends libopus0
+          sudo apt-get install -y --no-install-recommends libopus0 sox
       - name: Run extra training tests
         run: |
           python -m pip install coqui_stt_ctcdecoder-*.whl
@@ -8,14 +8,14 @@ from coqui_stt_training.evaluate import test
 # only one GPU for only one training sample
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"

-download_ldc("data/ldc93s1")
+download_ldc("data/smoke_test")

 initialize_globals_from_args(
     load_train="init",
     alphabet_config_path="data/alphabet.txt",
-    train_files=["data/ldc93s1/ldc93s1.csv"],
-    dev_files=["data/ldc93s1/ldc93s1.csv"],
-    test_files=["data/ldc93s1/ldc93s1.csv"],
+    train_files=["data/smoke_test/ldc93s1.csv"],
+    dev_files=["data/smoke_test/ldc93s1.csv"],
+    test_files=["data/smoke_test/ldc93s1.csv"],
     augment=["time_mask"],
     n_hidden=100,
     epochs=200,
@@ -5,9 +5,9 @@ if [ ! -f train.py ]; then
     exit 1
 fi;

-if [ ! -f "data/ldc93s1/ldc93s1.csv" ]; then
-    echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/ldc93s1."
-    python -u bin/import_ldc93s1.py ./data/ldc93s1
+if [ ! -f "data/smoke_test/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/smoke_test."
+    python -u bin/import_ldc93s1.py ./data/smoke_test
 fi;

 if [ -d "${COMPUTE_KEEP_DIR}" ]; then
@@ -23,8 +23,8 @@ export CUDA_VISIBLE_DEVICES=0
 python -m coqui_stt_training.train \
     --alphabet_config_path "data/alphabet.txt" \
     --show_progressbar false \
-    --train_files data/ldc93s1/ldc93s1.csv \
-    --test_files data/ldc93s1/ldc93s1.csv \
+    --train_files data/smoke_test/ldc93s1.csv \
+    --test_files data/smoke_test/ldc93s1.csv \
     --train_batch_size 1 \
     --test_batch_size 1 \
     --n_hidden 100 \
@@ -16,7 +16,7 @@ mkdir -p /tmp/train_tflite || true

 set -o pipefail
 python -m pip install --upgrade pip setuptools wheel | cat
-python -m pip install --upgrade . | cat
+python -m pip install --upgrade ".[transcribe]" | cat
 set +o pipefail

 # Prepare correct arguments for training
@@ -72,3 +72,20 @@ time python ./bin/run-ldc93s1.py

 # Training graph inference
 time ./bin/run-ci-ldc93s1_singleshotinference.sh
+
+# transcribe module
+time python -m coqui_stt_training.transcribe \
+    --src "data/smoke_test/LDC93S1.wav" \
+    --dst ${CI_ARTIFACTS_DIR}/transcribe.log \
+    --n_hidden 100 \
+    --scorer_path "data/smoke_test/pruned_lm.scorer"
+
+#TODO: investigate why this is hanging in CI
+#mkdir /tmp/transcribe_dir
+#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
+#time python -m coqui_stt_training.transcribe \
+#    --src "/tmp/transcribe_dir/" \
+#    --n_hidden 100 \
+#    --scorer_path "data/smoke_test/pruned_lm.scorer"
+#
+#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
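
As a note on the artifact this CI step produces: per the transcribe_file implementation added below, a .tlog is a JSON array of {"start", "end", "transcript"} objects, so it can be inspected with a few lines of Python (the path here is illustrative):

    import json

    # Each entry holds one VAD-split segment's start/end offsets and decoded text.
    with open("/tmp/artifacts/transcribe.log") as f:
        for segment in json.load(f):
            print(segment["start"], segment["end"], segment["transcript"])
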
@@ -78,8 +78,8 @@
     "def download_sample_data():\n",
     "    data_dir=\"english/\"\n",
     "    # Download data + alphabet\n",
-    "    audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav\")\n",
-    "    transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.txt\")\n",
+    "    audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.wav\")\n",
+    "    transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.txt\")\n",
     "    alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n",
     "    # Format data\n",
     "    with open(transcript_file, \"r\") as fin:\n",
setup.py (3 changed lines)
@@ -69,6 +69,9 @@ def main():
         python_requires=">=3.5, <4",
         install_requires=install_requires,
         include_package_data=True,
+        extras_require={
+            "transcribe": ["webrtcvad"],
+        },
     )


training/coqui_stt_training/transcribe.py (new executable file, 315 lines)
@@ -0,0 +1,315 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# This script is structured in a weird way, with delayed imports. This is due
+# to the use of multiprocessing. TensorFlow cannot handle forking, and even
+# with the start method set to "spawn" it still leads to weird problems, so we
+# restructure the code so that TensorFlow is only imported inside the child
+# processes.
+
+import os
+import sys
+import glob
+import itertools
+import json
+import multiprocessing
+from multiprocessing import Pool, cpu_count
+from dataclasses import dataclass, field
+
+from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
+from tqdm import tqdm
+
+
+def fail(message, code=1):
+    print(f"E {message}")
+    sys.exit(code)
+
+
+def transcribe_file(audio_path, tlog_path):
+    log_level_index = (
+        sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
+    )
+    desired_log_level = (
+        sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
+    )
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
+
+    import tensorflow as tf
+    import tensorflow.compat.v1 as tfv1
+
+    from coqui_stt_training.train import create_model
+    from coqui_stt_training.util.audio import AudioFile
+    from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
+    from coqui_stt_training.util.config import Config
+    from coqui_stt_training.util.feeding import split_audio_file
+
+    initialize_transcribe_config()
+
+    scorer = None
+    if Config.scorer_path:
+        scorer = Scorer(
+            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
+        )
+
+    try:
+        num_processes = cpu_count()
+    except NotImplementedError:
+        num_processes = 1
+
+    with AudioFile(audio_path, as_path=True) as wav_path:
+        data_set = split_audio_file(
+            wav_path,
+            batch_size=Config.batch_size,
+            aggressiveness=Config.vad_aggressiveness,
+            outlier_duration_ms=Config.outlier_duration_ms,
+            outlier_batch_size=Config.outlier_batch_size,
+        )
+        iterator = tfv1.data.make_one_shot_iterator(data_set)
+        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
+        no_dropout = [None] * 6
+        logits, _ = create_model(
+            batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
+        )
+        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
+        tf.train.get_or_create_global_step()
+        with tf.Session(config=Config.session_config) as session:
+            load_graph_for_evaluation(session)
+            transcripts = []
+            while True:
+                try:
+                    starts, ends, batch_logits, batch_lengths = session.run(
+                        [batch_time_start, batch_time_end, transposed, batch_x_len]
+                    )
+                except tf.errors.OutOfRangeError:
+                    break
+                decoded = ctc_beam_search_decoder_batch(
+                    batch_logits,
+                    batch_lengths,
+                    Config.alphabet,
+                    Config.beam_width,
+                    num_processes=num_processes,
+                    scorer=scorer,
+                )
+                decoded = list(d[0][1] for d in decoded)
+                transcripts.extend(zip(starts, ends, decoded))
+            transcripts.sort(key=lambda t: t[0])
+            transcripts = [
+                {"start": int(start), "end": int(end), "transcript": transcript}
+                for start, end, transcript in transcripts
+            ]
+            with open(tlog_path, "w") as tlog_file:
+                json.dump(transcripts, tlog_file, default=float)
+
+
+def step_function(job):
+    """ Wrap transcribe_file to unpack arguments from a single tuple """
+    idx, src, dst = job
+    transcribe_file(src, dst)
+    return idx, src, dst
+
+
+def transcribe_many(src_paths, dst_paths):
+    from coqui_stt_training.util.config import Config, log_progress
+
+    pool = Pool(processes=min(cpu_count(), len(src_paths)))
+
+    # Create list of items to be processed: [(i, src_paths[i], dst_paths[i])]
+    jobs = zip(itertools.count(), src_paths, dst_paths)
+
+    process_iterable = tqdm(
+        pool.imap_unordered(step_function, jobs),
+        desc="Transcribing files",
+        total=len(src_paths),
+        disable=not Config.show_progressbar,
+    )
+
+    for result in process_iterable:
+        idx, src, dst = result
+        log_progress(
+            f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
+        )
+
+
+def transcribe_one(src_path, dst_path):
+    transcribe_file(src_path, dst_path)
+    print(f'I Transcribed file "{src_path}" to "{dst_path}"')
+
+
+def resolve(base_path, spec_path):
+    if spec_path is None:
+        return None
+    if not os.path.isabs(spec_path):
+        spec_path = os.path.join(base_path, spec_path)
+    return spec_path
+
+
+def transcribe():
+    from coqui_stt_training.util.config import Config
+
+    initialize_transcribe_config()
+
+    if not Config.src or not os.path.exists(Config.src):
+        # path not given or non-existent
+        fail(
+            "You have to specify which file or catalog to transcribe via the --src flag."
+        )
+    else:
+        # path given and exists
+        src_path = os.path.abspath(Config.src)
+        if os.path.isfile(src_path):
+            if src_path.endswith(".catalog"):
+                # Transcribe batch of files via ".catalog" file (from DSAlign)
+                catalog_dir = os.path.dirname(src_path)
+                with open(src_path, "r") as catalog_file:
+                    catalog_entries = json.load(catalog_file)
+                catalog_entries = [
+                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
+                    for e in catalog_entries
+                ]
+                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
+                    fail("Missing source file(s) in catalog")
+                if not Config.force and any(
+                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
+                ):
+                    fail(
+                        "Destination file(s) from catalog already existing, use --force for overwriting"
+                    )
+                if any(
+                    map(
+                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
+                        catalog_entries,
+                    )
+                ):
+                    fail("Missing destination directory for at least one catalog entry")
+                src_paths, dst_paths = zip(*catalog_entries)
+                transcribe_many(src_paths, dst_paths)
+            else:
+                # Transcribe one file
+                dst_path = (
+                    os.path.abspath(Config.dst)
+                    if Config.dst
+                    else os.path.splitext(src_path)[0] + ".tlog"
+                )
+                if os.path.isfile(dst_path):
+                    if Config.force:
+                        transcribe_one(src_path, dst_path)
+                    else:
+                        fail(
+                            'Destination file "{}" already existing - use --force for overwriting'.format(
+                                dst_path
+                            ),
+                            code=0,
+                        )
+                elif os.path.isdir(os.path.dirname(dst_path)):
+                    transcribe_one(src_path, dst_path)
+                else:
+                    fail("Missing destination directory")
+        elif os.path.isdir(src_path):
+            # Transcribe all files in dir
+            print("Transcribing all WAV files in --src")
+            if Config.recursive:
+                wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
+            else:
+                wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
+            dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
+            transcribe_many(wav_paths, dst_paths)
+
+
+def initialize_transcribe_config():
+    from coqui_stt_training.util.config import (
+        BaseSttConfig,
+        initialize_globals_from_instance,
+    )
+
+    @dataclass
+    class TranscribeConfig(BaseSttConfig):
+        src: str = field(
+            default="",
+            metadata=dict(
+                help="Source path to an audio file or directory or catalog file. "
+                "Catalog files should be formatted from DSAlign. A directory "
+                "will be recursively searched for audio. If --dst not set, "
+                "transcription logs (.tlog) will be written in-place using the "
+                'source filenames with suffix ".tlog" instead of ".wav".'
+            ),
+        )
+
+        dst: str = field(
+            default="",
+            metadata=dict(
+                help="path for writing the transcription log or logs (.tlog). "
+                "If --src is a directory, this one also has to be a directory "
+                "and the required sub-dir tree of --src will get replicated."
+            ),
+        )
+
+        recursive: bool = field(
+            default=False,
+            metadata=dict(help="scan source directory recursively for audio"),
+        )
+
+        force: bool = field(
+            default=False,
+            metadata=dict(
+                help="Forces re-transcribing and overwriting of already existing "
+                "transcription logs (.tlog)"
+            ),
+        )
+
+        vad_aggressiveness: int = field(
+            default=3,
+            metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
+        )
+
+        batch_size: int = field(
+            default=40,
+            metadata=dict(help="Default batch size"),
+        )
+
+        outlier_duration_ms: int = field(
+            default=10000,
+            metadata=dict(
+                help="Duration in ms after which samples are considered outliers"
+            ),
+        )
+
+        outlier_batch_size: int = field(
+            default=1,
+            metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
+        )
+
+        def __post_init__(self):
+            if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
+                raise RuntimeError(
+                    "Parameter --dst not supported if --src points to a catalog"
+                )
+
+            if os.path.isdir(self.src):
+                if self.dst:
+                    raise RuntimeError(
+                        "Destination path not supported for batch decoding jobs."
+                    )
+
+            super().__post_init__()
+
+    config = TranscribeConfig.init_from_argparse(arg_prefix="")
+    initialize_globals_from_instance(config)
+
+
+def main():
+    from coqui_stt_training.util.helpers import check_ctcdecoder_version
+
+    try:
+        import webrtcvad
+    except ImportError:
+        print(
+            "E transcribe module requires webrtcvad, which cannot be imported. Install with pip install webrtcvad"
+        )
+        sys.exit(1)
+
+    check_ctcdecoder_version()
+    transcribe()
+
+
+if __name__ == "__main__":
+    main()
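
The "delayed imports" note at the top of this new file is the multiprocessing fix from the commit message: heavy modules stay out of the parent process and are imported only inside the pool workers. A minimal, self-contained sketch of the pattern (the worker and job list are illustrative, and json stands in for the per-child TensorFlow import):

    from multiprocessing import Pool

    def worker(job):
        # Heavy modules are imported here, inside the child process, so the
        # parent never loads them and fork-related breakage is avoided.
        import json
        return json.dumps({"job": job})

    if __name__ == "__main__":
        with Pool(processes=2) as pool:
            for result in pool.imap_unordered(worker, range(4)):
                print(result)
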
training/coqui_stt_training/util/checkpoints.py
@@ -75,9 +75,12 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
             init_vars.add(v)
     load_vars -= init_vars

+    log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
     for v in sorted(load_vars, key=lambda v: v.op.name):
-        log_info("Loading variable from checkpoint: %s" % (v.op.name))
-        v.load(ckpt.get_tensor(v.op.name), session=session)
+        log_info(f"Getting tensor from variable: {v.op.name}")
+        tensor = ckpt.get_tensor(v.op.name)
+        log_info(f"Loading tensor from checkpoint: {v.op.name}")
+        v.load(tensor, session=session)

     for v in sorted(init_vars, key=lambda v: v.op.name):
         log_info("Initializing variable: %s" % (v.op.name))
training/coqui_stt_training/util/config.py
@@ -37,7 +37,7 @@ Config = _ConfigSingleton()  # pylint: disable=invalid-name


 @dataclass
-class _SttConfig(Coqpit):
+class BaseSttConfig(Coqpit):
     def __post_init__(self):
         # Augmentations
         self.augmentations = parse_augmentations(self.augment)

@@ -835,16 +835,22 @@ class _SttConfig(Coqpit):


 def initialize_globals_from_cli():
-    c = _SttConfig.init_from_argparse(arg_prefix="")
+    c = BaseSttConfig.init_from_argparse(arg_prefix="")
     _ConfigSingleton._config = c  # pylint: disable=protected-access


 def initialize_globals_from_args(**override_args):
     # Update Config with new args
-    c = _SttConfig(**override_args)
+    c = BaseSttConfig(**override_args)
     _ConfigSingleton._config = c  # pylint: disable=protected-access


+def initialize_globals_from_instance(config):
+    """ Initialize Config singleton from an existing Config instance (or subclass) """
+    assert isinstance(config, BaseSttConfig)
+    _ConfigSingleton._config = config  # pylint: disable=protected-access
+
+
 # Logging functions
 # =================
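
The new initialize_globals_from_instance entry point is what lets a tool define its own Coqpit config on top of BaseSttConfig, as the transcribe module above does with TranscribeConfig. A minimal sketch of the extension pattern (ToolConfig and its dry_run field are illustrative, not part of this commit):

    from dataclasses import dataclass, field

    from coqui_stt_training.util.config import (
        BaseSttConfig,
        Config,
        initialize_globals_from_instance,
    )

    @dataclass
    class ToolConfig(BaseSttConfig):
        # Illustrative extra flag, parsed from the CLI alongside the base flags.
        dry_run: bool = field(default=False, metadata=dict(help="parse config only"))

    config = ToolConfig.init_from_argparse(arg_prefix="")
    initialize_globals_from_instance(config)
    print(Config.dry_run)  # the singleton now exposes the subclass fields
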
transcribe.py (251 changed lines)
@@ -2,246 +2,15 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, division, print_function

-import json
-import os
-import sys
-
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-import tensorflow.compat.v1.logging as tflogging
-
-import tensorflow as tf
-
-tflogging.set_verbosity(tflogging.ERROR)
-import logging
-
-logging.getLogger("sox").setLevel(logging.ERROR)
-import glob
-from multiprocessing import Process, cpu_count
-
-from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
-from coqui_stt_training.util.audio import AudioFile
-from coqui_stt_training.util.config import Config, initialize_globals_from_cli
-from coqui_stt_training.util.feeding import split_audio_file
-from coqui_stt_training.util.flags import FLAGS, create_flags
-from coqui_stt_training.util.logging import (
-    create_progressbar,
-    log_error,
-    log_info,
-    log_progress,
-)
-
-
-def fail(message, code=1):
-    log_error(message)
-    sys.exit(code)
-
-
-def transcribe_file(audio_path, tlog_path):
-    from coqui_stt_training.train import (  # pylint: disable=cyclic-import,import-outside-toplevel
-        create_model,
-    )
-    from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
-
-    initialize_globals_from_cli()
-
-    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
-    try:
-        num_processes = cpu_count()
-    except NotImplementedError:
-        num_processes = 1
-    with AudioFile(audio_path, as_path=True) as wav_path:
-        data_set = split_audio_file(
-            wav_path,
-            batch_size=FLAGS.batch_size,
-            aggressiveness=FLAGS.vad_aggressiveness,
-            outlier_duration_ms=FLAGS.outlier_duration_ms,
-            outlier_batch_size=FLAGS.outlier_batch_size,
-        )
-        iterator = tf.data.Iterator.from_structure(
-            data_set.output_types,
-            data_set.output_shapes,
-            output_classes=data_set.output_classes,
-        )
-        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
-        no_dropout = [None] * 6
-        logits, _ = create_model(
-            batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
-        )
-        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
-        tf.train.get_or_create_global_step()
-        with tf.Session(config=Config.session_config) as session:
-            load_graph_for_evaluation(session)
-            session.run(iterator.make_initializer(data_set))
-            transcripts = []
-            while True:
-                try:
-                    starts, ends, batch_logits, batch_lengths = session.run(
-                        [batch_time_start, batch_time_end, transposed, batch_x_len]
-                    )
-                except tf.errors.OutOfRangeError:
-                    break
-                decoded = ctc_beam_search_decoder_batch(
-                    batch_logits,
-                    batch_lengths,
-                    Config.alphabet,
-                    FLAGS.beam_width,
-                    num_processes=num_processes,
-                    scorer=scorer,
-                )
-                decoded = list(d[0][1] for d in decoded)
-                transcripts.extend(zip(starts, ends, decoded))
-            transcripts.sort(key=lambda t: t[0])
-            transcripts = [
-                {"start": int(start), "end": int(end), "transcript": transcript}
-                for start, end, transcript in transcripts
-            ]
-            with open(tlog_path, "w") as tlog_file:
-                json.dump(transcripts, tlog_file, default=float)
-
-
-def transcribe_many(src_paths, dst_paths):
-    pbar = create_progressbar(
-        prefix="Transcribing files | ", max_value=len(src_paths)
-    ).start()
-    for i in range(len(src_paths)):
-        p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
-        p.start()
-        p.join()
-        log_progress(
-            'Transcribed file {} of {} from "{}" to "{}"'.format(
-                i + 1, len(src_paths), src_paths[i], dst_paths[i]
-            )
-        )
-        pbar.update(i)
-    pbar.finish()
-
-
-def transcribe_one(src_path, dst_path):
-    transcribe_file(src_path, dst_path)
-    log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
-
-
-def resolve(base_path, spec_path):
-    if spec_path is None:
-        return None
-    if not os.path.isabs(spec_path):
-        spec_path = os.path.join(base_path, spec_path)
-    return spec_path
-
-
-def main(_):
-    if not FLAGS.src or not os.path.exists(FLAGS.src):
-        # path not given or non-existant
-        fail(
-            "You have to specify which file or catalog to transcribe via the --src flag."
-        )
-    else:
-        # path given and exists
-        src_path = os.path.abspath(FLAGS.src)
-        if os.path.isfile(src_path):
-            if src_path.endswith(".catalog"):
-                # Transcribe batch of files via ".catalog" file (from DSAlign)
-                if FLAGS.dst:
-                    fail("Parameter --dst not supported if --src points to a catalog")
-                catalog_dir = os.path.dirname(src_path)
-                with open(src_path, "r") as catalog_file:
-                    catalog_entries = json.load(catalog_file)
-                catalog_entries = [
-                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
-                    for e in catalog_entries
-                ]
-                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
-                    fail("Missing source file(s) in catalog")
-                if not FLAGS.force and any(
-                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
-                ):
-                    fail(
-                        "Destination file(s) from catalog already existing, use --force for overwriting"
-                    )
-                if any(
-                    map(
-                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
-                        catalog_entries,
-                    )
-                ):
-                    fail("Missing destination directory for at least one catalog entry")
-                src_paths, dst_paths = zip(*paths)
-                transcribe_many(src_paths, dst_paths)
-            else:
-                # Transcribe one file
-                dst_path = (
-                    os.path.abspath(FLAGS.dst)
-                    if FLAGS.dst
-                    else os.path.splitext(src_path)[0] + ".tlog"
-                )
-                if os.path.isfile(dst_path):
-                    if FLAGS.force:
-                        transcribe_one(src_path, dst_path)
-                    else:
-                        fail(
-                            'Destination file "{}" already existing - use --force for overwriting'.format(
-                                dst_path
-                            ),
-                            code=0,
-                        )
-                elif os.path.isdir(os.path.dirname(dst_path)):
-                    transcribe_one(src_path, dst_path)
-                else:
-                    fail("Missing destination directory")
-        elif os.path.isdir(src_path):
-            # Transcribe all files in dir
-            print("Transcribing all WAV files in --src")
-            if FLAGS.dst:
-                fail("Destination file not supported for batch decoding jobs.")
-            else:
-                if not FLAGS.recursive:
-                    print(
-                        "If you wish to recursively scan --src, then you must use --recursive"
-                    )
-                    wav_paths = glob.glob(src_path + "/*.wav")
-                else:
-                    wav_paths = glob.glob(src_path + "/**/*.wav")
-                dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
-                transcribe_many(wav_paths, dst_paths)
-
-
-if __name__ == "__main__":
-    create_flags()
-    tf.app.flags.DEFINE_string(
-        "src",
-        "",
-        "Source path to an audio file or directory or catalog file."
-        "Catalog files should be formatted from DSAlign. A directory will"
-        "be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be "
-        "written in-place using the source filenames with "
-        'suffix ".tlog" instead of ".wav".',
-    )
-    tf.app.flags.DEFINE_string(
-        "dst",
-        "",
-        "path for writing the transcription log or logs (.tlog). "
-        "If --src is a directory, this one also has to be a directory "
-        "and the required sub-dir tree of --src will get replicated.",
-    )
-    tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively")
-    tf.app.flags.DEFINE_boolean(
-        "force",
-        False,
-        "Forces re-transcribing and overwriting of already existing "
-        "transcription logs (.tlog)",
-    )
-    tf.app.flags.DEFINE_integer(
-        "vad_aggressiveness",
-        3,
-        "How aggressive (0=lowest, 3=highest) the VAD should " "split audio",
-    )
-    tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size")
-    tf.app.flags.DEFINE_float(
-        "outlier_duration_ms",
-        10000,
-        "Duration in ms after which samples are considered outliers",
-    )
-    tf.app.flags.DEFINE_integer(
-        "outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)"
-    )
-    tf.app.run(main)
+print(
+    "Using the top level transcribe.py script is deprecated and will be removed "
+    "in a future release. Instead use: python -m coqui_stt_training.transcribe"
+)
+
+try:
+    from coqui_stt_training import transcribe as stt_transcribe
+except ImportError:
+    print("Training package is not installed. See training documentation.")
+    raise
+
+stt_transcribe.main()