[transcribe] Fix multiprocessing hangs, clean up target collection

Reuben Morais 2021-12-01 14:24:42 +01:00
parent 5cefd7069c
commit d90bb60506
4 changed files with 176 additions and 126 deletions


@@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
#TODO: investigate why this is hanging in CI
#mkdir /tmp/transcribe_dir
#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
#time python -m coqui_stt_training.transcribe \
# --src "/tmp/transcribe_dir/" \
# --n_hidden 100 \
# --scorer_path "data/smoke_test/pruned_lm.scorer"
#
#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
mkdir /tmp/transcribe_dir
cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
time python -m coqui_stt_training.transcribe \
--src "/tmp/transcribe_dir/" \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done


@@ -7,25 +7,22 @@
# restructure the code so that TensorFlow is only imported inside the child
# processes.
import os
import sys
import glob
import itertools
import json
import multiprocessing
from multiprocessing import Pool, cpu_count
import os
import sys
from dataclasses import dataclass, field
from multiprocessing import Pool, Lock, cpu_count
from pathlib import Path
from typing import Optional, List, Tuple
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from tqdm import tqdm
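As the comment at the top of this file notes, the hang fix depends on TensorFlow being imported only inside the child processes, together with the spawn start method set in main() below. A minimal sketch of that pattern, separate from the actual module code and with hypothetical names, assuming TensorFlow is installed:

import multiprocessing


def _worker(wav_path):
    # TensorFlow is imported here, inside the spawned child, so the parent
    # process never holds TF state before the pool is created.
    import tensorflow as tf

    return wav_path, tf.__version__


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(_worker, ["a.wav", "b.wav"]))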
def fail(message, code=1):
print(f"E {message}")
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
def transcribe_file(audio_path: Path, tlog_path: Path):
log_level_index = (
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
)
@@ -56,7 +53,7 @@ def transcribe_file(audio_path, tlog_path):
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
with AudioFile(str(audio_path), as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=Config.batch_size,
@@ -73,7 +70,9 @@ def transcribe_file(audio_path, tlog_path):
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
# Load the checkpoint while holding a lock to avoid hangs in TensorFlow code
with lock:
load_graph_for_evaluation(session, silent=True)
transcripts = []
while True:
try:
@@ -101,8 +100,13 @@ def transcribe_file(audio_path, tlog_path):
json.dump(transcripts, tlog_file, default=float)
def init_fn(l):
global lock
lock = l
def step_function(job):
""" Wrap transcribe_file to unpack arguments from a single tuple """
"""Wrap transcribe_file to unpack arguments from a single tuple"""
idx, src, dst = job
transcribe_file(src, dst)
return idx, src, dst
@@ -111,36 +115,81 @@ def step_function(job):
def transcribe_many(src_paths, dst_paths):
from coqui_stt_training.util.config import Config, log_progress
pool = Pool(processes=min(cpu_count(), len(src_paths)))
# Create list of items to be processed: [(i, src_paths[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths)
process_iterable = tqdm(
pool.imap_unordered(step_function, jobs),
desc="Transcribing files",
total=len(src_paths),
disable=not Config.show_progressbar,
)
for result in process_iterable:
idx, src, dst = result
log_progress(
f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
lock = Lock()
with Pool(
processes=min(cpu_count(), len(src_paths)),
initializer=init_fn,
initargs=(lock,),
) as pool:
process_iterable = tqdm(
pool.imap_unordered(step_function, jobs),
desc="Transcribing files",
total=len(src_paths),
disable=not Config.show_progressbar,
)
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
print(f'I Transcribed file "{src_path}" to "{dst_path}"')
cwd = Path.cwd()
for result in process_iterable:
idx, src, dst = result
# Revert to relative to avoid spamming logs
if not src.is_absolute():
src = src.relative_to(cwd)
if not dst.is_absolute():
dst = dst.relative_to(cwd)
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
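The lock is handed to the workers through the pool initializer rather than through the job tuples: a multiprocessing Lock cannot be pickled into the pool's task queue, but it can be inherited at worker startup via initializer/initargs, which is what init_fn above does. A self-contained sketch of the idea, with hypothetical names:

from multiprocessing import Lock, Pool


def _init(shared_lock):
    # Each worker stores the inherited lock as a module-level global.
    global _lock
    _lock = shared_lock


def _job(i):
    # Only one worker at a time runs the critical section.
    with _lock:
        return i * i


if __name__ == "__main__":
    lock = Lock()
    with Pool(processes=2, initializer=_init, initargs=(lock,)) as pool:
        print(sorted(pool.imap_unordered(_job, range(4))))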
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not os.path.isabs(spec_path):
spec_path = os.path.join(base_path, spec_path)
return spec_path
def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
"""Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results. For .catalog file inputs, these are taken from the
"audio" and "tlog" properties of the entries in the catalog, with any relative
paths being absolutized relative to the directory containing the .catalog file.
"""
assert catalog_file_path.suffix == ".catalog"
catalog_dir = catalog_file_path.parent
with open(catalog_file_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
def resolve(spec_path: Optional[Path]):
if spec_path is None:
return None
if not spec_path.is_absolute():
spec_path = catalog_dir / spec_path
return spec_path
catalog_entries = [
(resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
]
for src, dst in catalog_entries:
if not Config.force and dst.is_file():
raise RuntimeError(
f"Destination file already exists: {dst}. Use --force for overwriting."
)
if not dst.parent.is_dir():
dst.parent.mkdir(parents=True)
src_paths, dst_paths = zip(*catalog_entries)
return src_paths, dst_paths
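For reference, the .catalog files consumed here are JSON arrays of entries whose "audio" and "tlog" fields give the source and destination paths, with relative paths resolved against the catalog's own directory. A small illustrative catalog built in Python, using hypothetical file names:

import json
from pathlib import Path

# Hypothetical example data; only the "audio" and "tlog" fields are read above.
entries = [
    {"audio": "clips/utt1.wav", "tlog": "clips/utt1.tlog"},   # resolved against the catalog dir
    {"audio": "/data/utt2.opus", "tlog": "/data/utt2.tlog"},  # absolute, used as-is
]
catalog_path = Path("/tmp/example.catalog")
catalog_path.write_text(json.dumps(entries, indent=2))
# src_paths, dst_paths = get_tasks_from_catalog(catalog_path)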
def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
"""Given a directory `src_dir` containing audio files, scan it for audio files
and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results.
"""
glob_method = src_dir.rglob if recursive else src_dir.glob
src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
return src_paths, dst_paths
def transcribe():
@@ -148,71 +197,43 @@ def transcribe():
initialize_transcribe_config()
if not Config.src or not os.path.exists(Config.src):
src_path = Path(Config.src).resolve()
if not Config.src or not src_path.exists():
# path not given or non-existent
fail(
"You have to specify which file or catalog to transcribe via the --src flag."
raise RuntimeError(
"You have to specify which audio file, catalog file or directory to "
"transcribe with the --src flag."
)
else:
# path given and exists
src_path = os.path.abspath(Config.src)
if os.path.isfile(src_path):
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not Config.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
if src_path.is_file():
if src_path.suffix != ".catalog":
# Transcribe one file
dst_path = (
os.path.abspath(Config.dst)
Path(Config.dst).resolve()
if Config.dst
else os.path.splitext(src_path)[0] + ".tlog"
else src_path.with_suffix(".tlog")
)
if os.path.isfile(dst_path):
if Config.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if Config.recursive:
wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
if dst_path.is_file() and not Config.force:
raise RuntimeError(
f'Destination file "{dst_path}" already exists - use '
"--force for overwriting."
)
if not dst_path.parent.is_dir():
raise RuntimeError("Missing destination directory")
transcribe_many([src_path], [dst_path])
else:
wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
# Transcribe from .catalog input
src_paths, dst_paths = get_tasks_from_catalog(src_path)
transcribe_many(src_paths, dst_paths)
elif src_path.is_dir():
# Transcribe from dir input
print(f"Transcribing all files in --src directory {src_path}")
src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
transcribe_many(src_paths, dst_paths)
def initialize_transcribe_config():
@@ -230,7 +251,7 @@ def initialize_transcribe_config():
"Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of ".wav".'
'source filenames with suffix ".tlog" instead of the original.'
),
)
@@ -299,6 +320,9 @@ def initialize_transcribe_config():
def main():
from coqui_stt_training.util.helpers import check_ctcdecoder_version
# Set start method to spawn on all platforms to avoid issues with TensorFlow
multiprocessing.set_start_method("spawn")
try:
import webrtcvad
except ImportError:


@@ -7,7 +7,13 @@ import tensorflow as tf
from .config import Config, log_error, log_info, log_warn
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
def _load_checkpoint(
session,
checkpoint_path,
allow_drop_layers,
allow_lr_init=True,
silent: bool = False,
):
# Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then
# we will initialize them instead
@@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
init_vars.add(v)
load_vars -= init_vars
log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for v in sorted(load_vars, key=lambda v: v.op.name):
log_info(f"Getting tensor from variable: {v.op.name}")
tensor = ckpt.get_tensor(v.op.name)
log_info(f"Loading tensor from checkpoint: {v.op.name}")
v.load(tensor, session=session)
maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
v.load(ckpt.get_tensor(v.op.name), session=session)
for v in sorted(init_vars, key=lambda v: v.op.name):
log_info("Initializing variable: %s" % (v.op.name))
maybe_log_info("Initializing variable: %s" % (v.op.name))
session.run(v.initializer)
@@ -102,31 +109,49 @@ def _initialize_all_variables(session):
session.run(v.initializer)
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
def _load_or_init_impl(
session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
):
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == "best":
ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
if ckpt_path:
log_info("Loading best validating checkpoint from {}".format(ckpt_path))
return _load_checkpoint(
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
maybe_log_info(
"Loading best validating checkpoint from {}".format(ckpt_path)
)
log_info("Could not find best validating checkpoint.")
return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find best validating checkpoint.")
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
elif method == "last":
ckpt_path = _checkpoint_path_or_none("checkpoint")
if ckpt_path:
log_info("Loading most recent checkpoint from {}".format(ckpt_path))
return _load_checkpoint(
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
maybe_log_info(
"Loading most recent checkpoint from {}".format(ckpt_path)
)
log_info("Could not find most recent checkpoint.")
return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find most recent checkpoint.")
# Initialize all variables
elif method == "init":
log_info("Initializing all variables.")
maybe_log_info("Initializing all variables.")
return _initialize_all_variables(session)
else:
@@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
_load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
def load_or_init_graph_for_training(session):
def load_or_init_graph_for_training(session, silent: bool = False):
"""
Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint,
@@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
methods = ["best", "last", "init"]
else:
methods = [Config.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=True)
_load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
def load_graph_for_evaluation(session):
def load_graph_for_evaluation(session, silent: bool = False):
"""
Load variables from checkpoint. Initialization is not allowed. By default
this will try to load the best validating checkpoint, then try the last
@@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
methods = ["best", "last"]
else:
methods = [Config.load_evaluate]
_load_or_init_impl(session, methods, allow_drop_layers=False)
_load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)


@@ -217,12 +217,14 @@ class BaseSttConfig(Coqpit):
if not is_remote_path(self.save_checkpoint_dir):
os.makedirs(self.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2)
if not os.path.exists(flags_file):
with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText())
if not os.path.exists(saved_checkpoint_alphabet_file):
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText())
# Geometric Constants
# ===================