[transcribe] Fix multiprocessing hangs, clean-up target collection
parent 5cefd7069c
commit d90bb60506
@@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
     --n_hidden 100 \
     --scorer_path "data/smoke_test/pruned_lm.scorer"

-#TODO: investigate why this is hanging in CI
-#mkdir /tmp/transcribe_dir
-#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
-#time python -m coqui_stt_training.transcribe \
-#    --src "/tmp/transcribe_dir/" \
-#    --n_hidden 100 \
-#    --scorer_path "data/smoke_test/pruned_lm.scorer"
-#
-#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
+mkdir /tmp/transcribe_dir
+cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
+time python -m coqui_stt_training.transcribe \
+    --src "/tmp/transcribe_dir/" \
+    --n_hidden 100 \
+    --scorer_path "data/smoke_test/pruned_lm.scorer"
+
+for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done
@@ -7,25 +7,22 @@
 # restructure the code so that TensorFlow is only imported inside the child
 # processes.

-import os
-import sys
-import glob
 import itertools
 import json
 import multiprocessing
-from multiprocessing import Pool, cpu_count
+import os
+import sys
+from dataclasses import dataclass, field
+from multiprocessing import Pool, Lock, cpu_count
+from pathlib import Path
+from typing import Optional, List, Tuple

 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from tqdm import tqdm


 def fail(message, code=1):
     print(f"E {message}")
     sys.exit(code)


-def transcribe_file(audio_path, tlog_path):
+def transcribe_file(audio_path: Path, tlog_path: Path):
     log_level_index = (
         sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
     )
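Note on the comment at the top of this hunk: a minimal standalone sketch (not from this repo) of keeping a heavy library out of the parent process by importing it only inside the worker function, which is the idea behind importing TensorFlow "only inside the child processes":

```python
# Sketch only: the parent never imports the heavy dependency; each spawned
# worker performs the import on first use inside the task function.
import multiprocessing
from multiprocessing import Pool

def transcribe_job(path):
    # Deferred import: executed only in the worker process.
    import json  # stand-in for a heavy dependency such as tensorflow
    return json.dumps({"src": path})

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    with Pool(processes=2) as pool:
        for line in pool.imap_unordered(transcribe_job, ["a.wav", "b.wav"]):
            print(line)
```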
@@ -56,7 +53,7 @@ def transcribe_file(audio_path, tlog_path):
     except NotImplementedError:
         num_processes = 1

-    with AudioFile(audio_path, as_path=True) as wav_path:
+    with AudioFile(str(audio_path), as_path=True) as wav_path:
         data_set = split_audio_file(
             wav_path,
             batch_size=Config.batch_size,
@@ -73,7 +70,9 @@ def transcribe_file(audio_path, tlog_path):
         transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
         tf.train.get_or_create_global_step()
         with tf.Session(config=Config.session_config) as session:
-            load_graph_for_evaluation(session)
+            # Load checkpoint in a mutex way to avoid hangs in TensorFlow code
+            with lock:
+                load_graph_for_evaluation(session, silent=True)
             transcripts = []
             while True:
                 try:
@@ -101,8 +100,13 @@ def transcribe_file(audio_path, tlog_path):
         json.dump(transcripts, tlog_file, default=float)


+def init_fn(l):
+    global lock
+    lock = l
+
+
 def step_function(job):
-    """ Wrap transcribe_file to unpack arguments from a single tuple """
+    """Wrap transcribe_file to unpack arguments from a single tuple"""
     idx, src, dst = job
     transcribe_file(src, dst)
     return idx, src, dst
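For context, a minimal standalone sketch (not from this repo) of the locking pattern these hunks set up: the parent creates one Lock, the Pool initializer stores it as a module-level global in every worker, and each worker holds it around the non-reentrant section (here, checkpoint loading). The `guarded_step` task below is made up for illustration:

```python
import multiprocessing
from multiprocessing import Lock, Pool

def init_fn(l):
    # Runs once in every worker process; expose the shared lock as a global.
    global lock
    lock = l

def guarded_step(i):
    # Only one worker at a time may execute the critical section.
    with lock:
        return i * i

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    l = Lock()
    # Passing the lock through initargs is the standard way to share it with
    # pool workers; it cannot be sent as a per-task argument.
    with Pool(processes=2, initializer=init_fn, initargs=(l,)) as pool:
        print(sorted(pool.imap_unordered(guarded_step, range(5))))
```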
@@ -111,36 +115,81 @@ def step_function(job):
 def transcribe_many(src_paths, dst_paths):
     from coqui_stt_training.util.config import Config, log_progress

-    pool = Pool(processes=min(cpu_count(), len(src_paths)))
-
     # Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
     jobs = zip(itertools.count(), src_paths, dst_paths)

-    process_iterable = tqdm(
-        pool.imap_unordered(step_function, jobs),
-        desc="Transcribing files",
-        total=len(src_paths),
-        disable=not Config.show_progressbar,
-    )
-
-    for result in process_iterable:
-        idx, src, dst = result
-        log_progress(
-            f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
+    lock = Lock()
+    with Pool(
+        processes=min(cpu_count(), len(src_paths)),
+        initializer=init_fn,
+        initargs=(lock,),
+    ) as pool:
+        process_iterable = tqdm(
+            pool.imap_unordered(step_function, jobs),
+            desc="Transcribing files",
+            total=len(src_paths),
+            disable=not Config.show_progressbar,
         )

-
-def transcribe_one(src_path, dst_path):
-    transcribe_file(src_path, dst_path)
-    print(f'I Transcribed file "{src_path}" to "{dst_path}"')
+        cwd = Path.cwd()
+        for result in process_iterable:
+            idx, src, dst = result
+            # Revert to relative to avoid spamming logs
+            if not src.is_absolute():
+                src = src.relative_to(cwd)
+            if not dst.is_absolute():
+                dst = dst.relative_to(cwd)
+            tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')


-def resolve(base_path, spec_path):
-    if spec_path is None:
-        return None
-    if not os.path.isabs(spec_path):
-        spec_path = os.path.join(base_path, spec_path)
-    return spec_path
+def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
+    """Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
+    extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
+    a path to an audio file to be transcribed, and a path to a JSON file to place
+    transcription results. For .catalog file inputs, these are taken from the
+    "audio" and "tlog" properties of the entries in the catalog, with any relative
+    paths being absolutized relative to the directory containing the .catalog file.
+    """
+    assert catalog_file_path.suffix == ".catalog"
+
+    catalog_dir = catalog_file_path.parent
+    with open(catalog_file_path, "r") as catalog_file:
+        catalog_entries = json.load(catalog_file)
+
+    def resolve(spec_path: Optional[Path]):
+        if spec_path is None:
+            return None
+        if not spec_path.is_absolute():
+            spec_path = catalog_dir / spec_path
+        return spec_path
+
+    catalog_entries = [
+        (resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
+    ]
+
+    for src, dst in catalog_entries:
+        if not Config.force and dst.is_file():
+            raise RuntimeError(
+                f"Destination file already exists: {dst}. Use --force for overwriting."
+            )
+
+        if not dst.parent.is_dir():
+            dst.parent.mkdir(parents=True)
+
+    src_paths, dst_paths = zip(*catalog_entries)
+    return src_paths, dst_paths
+
+
+def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
+    """Given a directory `src_dir` containing audio files, scan it for audio files
+    and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
+    a path to an audio file to be transcribed, and a path to a JSON file to place
+    transcription results.
+    """
+    glob_method = src_dir.rglob if recursive else src_dir.glob
+    src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))
+    dst_paths = [path.with_suffix(".tlog") for path in src_paths]
+    return src_paths, dst_paths


 def transcribe():
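An illustrative example (file names made up) of the resolution rule that get_tasks_from_catalog applies: a .catalog file is a JSON list whose entries carry "audio" and "tlog" paths, and relative paths are resolved against the catalog's own directory:

```python
from pathlib import Path

# Hypothetical catalog location and entries, only to show the resolution rule.
catalog_path = Path("/data/align/books.catalog")
entries = [
    {"audio": "wav/chapter01.wav", "tlog": "tlog/chapter01.tlog"},
    {"audio": "/abs/path/chapter02.wav", "tlog": "/abs/path/chapter02.tlog"},
]

catalog_dir = catalog_path.parent

def resolve(p: Path) -> Path:
    # Relative entries are anchored at the catalog's directory.
    return p if p.is_absolute() else catalog_dir / p

tasks = [(resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in entries]
# [(/data/align/wav/chapter01.wav, /data/align/tlog/chapter01.tlog),
#  (/abs/path/chapter02.wav,       /abs/path/chapter02.tlog)]
print(tasks)
```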
@@ -148,71 +197,43 @@ def transcribe():

     initialize_transcribe_config()

-    if not Config.src or not os.path.exists(Config.src):
+    src_path = Path(Config.src).resolve()
+    if not Config.src or not src_path.exists():
         # path not given or non-existant
-        fail(
-            "You have to specify which file or catalog to transcribe via the --src flag."
+        raise RuntimeError(
+            "You have to specify which audio file, catalog file or directory to "
+            "transcribe with the --src flag."
         )
-    else:
-        # path given and exists
-        src_path = os.path.abspath(Config.src)
-        if os.path.isfile(src_path):
-            if src_path.endswith(".catalog"):
-                # Transcribe batch of files via ".catalog" file (from DSAlign)
-                catalog_dir = os.path.dirname(src_path)
-                with open(src_path, "r") as catalog_file:
-                    catalog_entries = json.load(catalog_file)
-                catalog_entries = [
-                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
-                    for e in catalog_entries
-                ]
-                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
-                    fail("Missing source file(s) in catalog")
-                if not Config.force and any(
-                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
-                ):
-                    fail(
-                        "Destination file(s) from catalog already existing, use --force for overwriting"
-                    )
-                if any(
-                    map(
-                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
-                        catalog_entries,
-                    )
-                ):
-                    fail("Missing destination directory for at least one catalog entry")
-                src_paths, dst_paths = zip(*paths)
-                transcribe_many(src_paths, dst_paths)
-            else:
+    if src_path.is_file():
+        if src_path.suffix != ".catalog":
             # Transcribe one file
             dst_path = (
-                    os.path.abspath(Config.dst)
+                Path(Config.dst).resolve()
                 if Config.dst
-                    else os.path.splitext(src_path)[0] + ".tlog"
+                else src_path.with_suffix(".tlog")
             )
-                if os.path.isfile(dst_path):
-                    if Config.force:
-                        transcribe_one(src_path, dst_path)
-                    else:
-                        fail(
-                            'Destination file "{}" already existing - use --force for overwriting'.format(
-                                dst_path
-                            ),
-                            code=0,
-                        )
-                elif os.path.isdir(os.path.dirname(dst_path)):
-                    transcribe_one(src_path, dst_path)
-                else:
-                    fail("Missing destination directory")
-        elif os.path.isdir(src_path):
-            # Transcribe all files in dir
-            print("Transcribing all WAV files in --src")
-            if Config.recursive:
-                wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
+
+            if dst_path.is_file() and not Config.force:
+                raise RuntimeError(
+                    f'Destination file "{dst_path}" already exists - use '
+                    "--force for overwriting."
+                )
+
+            if not dst_path.parent.is_dir():
+                raise RuntimeError("Missing destination directory")
+
+            transcribe_many([src_path], [dst_path])
         else:
-            wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
-            dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
-            transcribe_many(wav_paths, dst_paths)
+            # Transcribe from .catalog input
+            src_paths, dst_paths = get_tasks_from_catalog(src_path)
+            transcribe_many(src_paths, dst_paths)
+    elif src_path.is_dir():
+        # Transcribe from dir input
+        print(f"Transcribing all files in --src directory {src_path}")
+        src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
+        transcribe_many(src_paths, dst_paths)


 def initialize_transcribe_config():
@@ -230,7 +251,7 @@ def initialize_transcribe_config():
             "Catalog files should be formatted from DSAlign. A directory "
             "will be recursively searched for audio. If --dst not set, "
             "transcription logs (.tlog) will be written in-place using the "
-            'source filenames with suffix ".tlog" instead of ".wav".'
+            'source filenames with suffix ".tlog" instead of the original.'
         ),
     )

@@ -299,6 +320,9 @@ def initialize_transcribe_config():
 def main():
     from coqui_stt_training.util.helpers import check_ctcdecoder_version

+    # Set start method to spawn on all platforms to avoid issues with TensorFlow
+    multiprocessing.set_start_method("spawn")
+
     try:
         import webrtcvad
     except ImportError:
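A minimal sketch (not from this repo) of the start-method change: "spawn" gives each worker a fresh interpreter instead of forking a parent that may already hold library state, and it must be set once in the entry point before any Pool is created. The `worker` function is illustrative only:

```python
import multiprocessing

def worker(n):
    # Heavy libraries would be imported here, inside the child process.
    return n + 1

def main():
    # Must run once, in the entry point, before any Pool/Process is created;
    # calling it twice raises RuntimeError.
    multiprocessing.set_start_method("spawn")
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(worker, range(4)))

if __name__ == "__main__":
    main()
```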
@@ -7,7 +7,13 @@ import tensorflow as tf
 from .config import Config, log_error, log_info, log_warn


-def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
+def _load_checkpoint(
+    session,
+    checkpoint_path,
+    allow_drop_layers,
+    allow_lr_init=True,
+    silent: bool = False,
+):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
             init_vars.add(v)
     load_vars -= init_vars

-    log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for v in sorted(load_vars, key=lambda v: v.op.name):
-        log_info(f"Getting tensor from variable: {v.op.name}")
-        tensor = ckpt.get_tensor(v.op.name)
-        log_info(f"Loading tensor from checkpoint: {v.op.name}")
-        v.load(tensor, session=session)
+        maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
+        v.load(ckpt.get_tensor(v.op.name), session=session)

     for v in sorted(init_vars, key=lambda v: v.op.name):
-        log_info("Initializing variable: %s" % (v.op.name))
+        maybe_log_info("Initializing variable: %s" % (v.op.name))
         session.run(v.initializer)


@@ -102,31 +109,49 @@ def _initialize_all_variables(session):
         session.run(v.initializer)


-def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
+def _load_or_init_impl(
+    session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
+):
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == "best":
             ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
             if ckpt_path:
-                log_info("Loading best validating checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
+                maybe_log_info(
+                    "Loading best validating checkpoint from {}".format(ckpt_path)
                 )
-            log_info("Could not find best validating checkpoint.")
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find best validating checkpoint.")

         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
         elif method == "last":
             ckpt_path = _checkpoint_path_or_none("checkpoint")
             if ckpt_path:
-                log_info("Loading most recent checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
+                maybe_log_info(
+                    "Loading most recent checkpoint from {}".format(ckpt_path)
                 )
-            log_info("Could not find most recent checkpoint.")
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find most recent checkpoint.")

         # Initialize all variables
         elif method == "init":
-            log_info("Initializing all variables.")
+            maybe_log_info("Initializing all variables.")
             return _initialize_all_variables(session)

         else:
@@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
     _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)


-def load_or_init_graph_for_training(session):
+def load_or_init_graph_for_training(session, silent: bool = False):
     """
     Load variables from checkpoint or initialize variables. By default this will
     try to load the best validating checkpoint, then try the last checkpoint,
@@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
         methods = ["best", "last", "init"]
     else:
         methods = [Config.load_train]
-    _load_or_init_impl(session, methods, allow_drop_layers=True)
+    _load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)


-def load_graph_for_evaluation(session):
+def load_graph_for_evaluation(session, silent: bool = False):
     """
     Load variables from checkpoint. Initialization is not allowed. By default
     this will try to load the best validating checkpoint, then try the last
@@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
         methods = ["best", "last"]
     else:
         methods = [Config.load_evaluate]
-    _load_or_init_impl(session, methods, allow_drop_layers=False)
+    _load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)
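A reduced sketch (not from this repo, names carry a _sketch suffix to keep them distinct from the real functions) of the silent-flag threading introduced above: the public loader forwards `silent` down to the checkpoint loader, where a local maybe_log_info helper suppresses the per-variable log lines that every parallel worker would otherwise print:

```python
def log_info(msg):
    print(f"I {msg}")

def _load_checkpoint_sketch(var_names, silent=False):
    def maybe_log_info(msg):
        # Drop informational output when running silently.
        if not silent:
            log_info(msg)
    for name in var_names:
        maybe_log_info(f"Loading variable from checkpoint: {name}")

def load_graph_for_evaluation_sketch(var_names, silent=False):
    # The real functions forward `silent` the same way, one level at a time.
    _load_checkpoint_sketch(var_names, silent=silent)

load_graph_for_evaluation_sketch(["layer_1/bias", "layer_1/weights"])               # logs each variable
load_graph_for_evaluation_sketch(["layer_1/bias", "layer_1/weights"], silent=True)  # prints nothing
```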
@@ -217,12 +217,14 @@ class BaseSttConfig(Coqpit):
         if not is_remote_path(self.save_checkpoint_dir):
             os.makedirs(self.save_checkpoint_dir, exist_ok=True)
         flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
-        with open_remote(flags_file, "w") as fout:
-            json.dump(self.serialize(), fout, indent=2)
+        if not os.path.exists(flags_file):
+            with open_remote(flags_file, "w") as fout:
+                json.dump(self.serialize(), fout, indent=2)

         # Serialize alphabet alongside checkpoint
-        with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
-            fout.write(self.alphabet.SerializeText())
+        if not os.path.exists(saved_checkpoint_alphabet_file):
+            with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
+                fout.write(self.alphabet.SerializeText())

         # Geometric Constants
         # ===================
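A minimal sketch (not from this repo) of the guard added here: checkpoint-dir artifacts such as flags.txt are written only when missing, presumably so that re-running config initialization (for example in spawned worker processes) does not rewrite the same files. The path and payload below are made up:

```python
import json
import os

def write_once(path, payload):
    # Skip the write when the artifact already exists (e.g. created by another process).
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as fout:
            json.dump(payload, fout, indent=2)

write_once("/tmp/example_checkpoint_dir/flags.txt", {"n_hidden": 100})
```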