#!/usr/bin/env python
# -*- coding: utf-8 -*-
import glob
import itertools
import json
import multiprocessing
import os
import sys
from dataclasses import dataclass, field
from multiprocessing import Pool, Lock, cpu_count
from pathlib import Path
from typing import Optional, List, Tuple

LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
    sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
# Hide GPUs to prevent issues with child processes trying to use the same GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
import tensorflow.compat.v1 as tfv1

from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import (
    BaseSttConfig,
    Config,
    initialize_globals_from_instance,
)
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.helpers import check_ctcdecoder_version
from tqdm import tqdm


def transcribe_file(audio_path: Path, tlog_path: Path):
    initialize_transcribe_config()

    scorer = None
    if Config.scorer_path:
        scorer = Scorer(
            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
        )

    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with AudioFile(str(audio_path), as_path=True) as wav_path:
        data_set = split_audio_file(
            wav_path,
            batch_size=Config.batch_size,
            aggressiveness=Config.vad_aggressiveness,
            outlier_duration_ms=Config.outlier_duration_ms,
            outlier_batch_size=Config.outlier_batch_size,
        )
        iterator = tfv1.data.make_one_shot_iterator(data_set)
        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
        no_dropout = [None] * 6
        logits, _ = create_model(
            batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
        )
        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
        tfv1.train.get_or_create_global_step()
        with tfv1.Session(config=Config.session_config) as session:
            # Load checkpoint in a mutex way to avoid hangs in TensorFlow code.
            # `lock` is a global set in each pool worker by init_fn below.
            with lock:
                load_graph_for_evaluation(session, silent=True)
            transcripts = []
            while True:
                try:
                    starts, ends, batch_logits, batch_lengths = session.run(
                        [batch_time_start, batch_time_end, transposed, batch_x_len]
                    )
                except tf.errors.OutOfRangeError:
                    break
                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    Config.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                )
                # Keep only the most likely transcript from each beam search result
                decoded = list(d[0][1] for d in decoded)
                transcripts.extend(zip(starts, ends, decoded))
            # Sort segments by start time and serialize them to the .tlog file
            transcripts.sort(key=lambda t: t[0])
            transcripts = [
                {"start": int(start), "end": int(end), "transcript": transcript}
                for start, end, transcript in transcripts
            ]
            with open(tlog_path, "w") as tlog_file:
                json.dump(transcripts, tlog_file, default=float)


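# For reference, the .tlog file written above is a JSON array of segment
# objects sorted by start time, e.g. (illustrative values):
#
#   [{"start": 0, "end": 2740, "transcript": "first segment"},
#    {"start": 2740, "end": 5210, "transcript": "second segment"}]

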
def init_fn(l):
    """Store the shared multiprocessing lock as a global in each pool worker."""
    global lock
    lock = l


def step_function(job):
    """Wrap transcribe_file to unpack arguments from a single tuple"""
    idx, src, dst = job
    transcribe_file(src, dst)
    return idx, src, dst


def transcribe_many(src_paths, dst_paths):
    # Create the items to be processed: [(i, src_paths[i], dst_paths[i])]
    jobs = zip(itertools.count(), src_paths, dst_paths)

    lock = Lock()
    with Pool(
        processes=min(cpu_count(), len(src_paths)),
        initializer=init_fn,
        initargs=(lock,),
    ) as pool:
        process_iterable = tqdm(
            pool.imap_unordered(step_function, jobs),
            desc="Transcribing files",
            total=len(src_paths),
            disable=not Config.show_progressbar,
        )

        cwd = Path.cwd()
        for result in process_iterable:
            idx, src, dst = result
            # Revert to relative paths if possible to make logs more concise;
            # if a path is not relative to cwd, use the absolute path.
            # (Path.is_relative_to is only available in Python >= 3.9)
            try:
                src = src.relative_to(cwd)
            except ValueError:
                pass
            try:
                dst = dst.relative_to(cwd)
            except ValueError:
                pass
            tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')


def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
    """Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
    extract transcription tasks, i.e. (src_path, dst_path) pairs corresponding to
    a path to an audio file to be transcribed, and a path to a JSON file to place
    transcription results. For .catalog file inputs, these are taken from the
    "audio" and "tlog" properties of the entries in the catalog, with any relative
    paths being absolutized relative to the directory containing the .catalog file.
    """
    assert catalog_file_path.suffix == ".catalog"

    catalog_dir = catalog_file_path.parent
    with open(catalog_file_path, "r") as catalog_file:
        catalog_entries = json.load(catalog_file)

    def resolve(spec_path: Optional[Path]):
        if spec_path is None:
            return None
        if not spec_path.is_absolute():
            spec_path = catalog_dir / spec_path
        return spec_path

    catalog_entries = [
        (resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
    ]

    for src, dst in catalog_entries:
        if not Config.force and dst.is_file():
            raise RuntimeError(
                f"Destination file already exists: {dst}. Use --force for overwriting."
            )

        if not dst.parent.is_dir():
            dst.parent.mkdir(parents=True)

    src_paths, dst_paths = zip(*catalog_entries)
    return src_paths, dst_paths


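# For illustration, the .catalog input consumed above is a JSON array whose
# entries carry "audio" and "tlog" paths, either absolute or relative to the
# catalog's own directory (file names below are hypothetical):
#
#   [
#     {"audio": "audio/episode-01.wav", "tlog": "transcripts/episode-01.tlog"},
#     {"audio": "/data/episode-02.wav", "tlog": "/data/episode-02.tlog"}
#   ]

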
def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
    """Given a directory `src_dir` containing audio files, scan it for audio files
    and return transcription tasks, i.e. (src_path, dst_path) pairs corresponding to
    a path to an audio file to be transcribed, and a path to a JSON file to place
    transcription results.
    """
    glob_method = src_dir.rglob if recursive else src_dir.glob
    src_paths = list(glob_method("*.wav"))
    dst_paths = [path.with_suffix(".tlog") for path in src_paths]
    return src_paths, dst_paths


def transcribe():
    initialize_transcribe_config()

    src_path = Path(Config.src).resolve()
    if not Config.src or not src_path.exists():
        # path not given or non-existent
        raise RuntimeError(
            "You have to specify which audio file, catalog file or directory to "
            "transcribe with the --src flag."
        )
    else:
        # path given and exists
        if src_path.is_file():
            if src_path.suffix != ".catalog":
                # Transcribe one file
                dst_path = (
                    Path(Config.dst).resolve()
                    if Config.dst
                    else src_path.with_suffix(".tlog")
                )

                if dst_path.is_file() and not Config.force:
                    raise RuntimeError(
                        f'Destination file "{dst_path}" already exists - use '
                        "--force for overwriting."
                    )

                if not dst_path.parent.is_dir():
                    raise RuntimeError("Missing destination directory")

                transcribe_many([src_path], [dst_path])
            else:
                # Transcribe from .catalog input
                src_paths, dst_paths = get_tasks_from_catalog(src_path)
                transcribe_many(src_paths, dst_paths)
        elif src_path.is_dir():
            # Transcribe from dir input
            print(f"Transcribing all files in --src directory {src_path}")
            src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
            transcribe_many(src_paths, dst_paths)


@dataclass
class TranscribeConfig(BaseSttConfig):
    src: str = field(
        default="",
        metadata=dict(
            help="Source path to an audio file or directory or catalog file. "
            "Catalog files should be formatted from DSAlign. A directory "
            "will be searched for audio files (pass --recursive to include "
            "sub-directories). If --dst is not set, transcription logs "
            "(.tlog) will be written in-place using the source filenames "
            'with suffix ".tlog" instead of the original.'
        ),
    )

    dst: str = field(
        default="",
        metadata=dict(
            help="Path for writing the transcription log (.tlog). Only "
            "supported if --src points to a single audio file; for directory "
            "or catalog inputs, log paths are derived from the sources."
        ),
    )

    recursive: bool = field(
        default=False,
        metadata=dict(help="Scan the source directory recursively for audio files"),
    )

    force: bool = field(
        default=False,
        metadata=dict(
            help="Forces re-transcribing and overwriting of already existing "
            "transcription logs (.tlog)"
        ),
    )

    vad_aggressiveness: int = field(
        default=3,
        metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
    )

    batch_size: int = field(
        default=40,
        metadata=dict(help="Default batch size"),
    )

    outlier_duration_ms: int = field(
        default=10000,
        metadata=dict(
            help="Duration in ms after which samples are considered outliers"
        ),
    )

    outlier_batch_size: int = field(
        default=1,
        metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
    )

    def __post_init__(self):
        if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
            raise RuntimeError(
                "Parameter --dst not supported if --src points to a catalog"
            )

        if os.path.isdir(self.src) and self.dst:
            raise RuntimeError(
                "Destination path not supported for batch decoding jobs."
            )

        super().__post_init__()


def initialize_transcribe_config():
    config = TranscribeConfig.init_from_argparse(arg_prefix="")
    initialize_globals_from_instance(config)


def main():
    # GPUs were hidden via CUDA_VISIBLE_DEVICES above; verify that took effect
    assert not tf.test.is_gpu_available()

    # Set start method to spawn on all platforms to avoid issues with TensorFlow
    multiprocessing.set_start_method("spawn")

    try:
        import webrtcvad  # noqa: F401 -- availability check only
    except ImportError:
        print(
            "E transcribe module requires webrtcvad, which cannot be imported. "
            "Install with pip install webrtcvad"
        )
        sys.exit(1)

    check_ctcdecoder_version()
    transcribe()


if __name__ == "__main__":
    main()
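
# Example invocations, assuming the script is run directly (paths below are
# hypothetical; the flags map to the TranscribeConfig fields above):
#
#   python transcribe.py --src recording.wav
#   python transcribe.py --src alignments.catalog --force
#   python transcribe.py --src audio_dir/ --recursive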
|