[transcribe] Fix multiprocessing hangs, clean-up target collection

This commit is contained in:
Reuben Morais 2021-12-01 14:24:42 +01:00
parent 5cefd7069c
commit d90bb60506
4 changed files with 176 additions and 126 deletions

View File

@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
--n_hidden 100 \ --n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer" --scorer_path "data/smoke_test/pruned_lm.scorer"
#TODO: investigate why this is hanging in CI mkdir /tmp/transcribe_dir
#mkdir /tmp/transcribe_dir cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir time python -m coqui_stt_training.transcribe \
#time python -m coqui_stt_training.transcribe \ --src "/tmp/transcribe_dir/" \
# --src "/tmp/transcribe_dir/" \ --n_hidden 100 \
# --n_hidden 100 \ --scorer_path "data/smoke_test/pruned_lm.scorer"
# --scorer_path "data/smoke_test/pruned_lm.scorer"
# for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done
#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done

View File

@ -7,25 +7,22 @@
# restructure the code so that TensorFlow is only imported inside the child # restructure the code so that TensorFlow is only imported inside the child
# processes. # processes.
import os
import sys
import glob import glob
import itertools import itertools
import json import json
import multiprocessing import multiprocessing
from multiprocessing import Pool, cpu_count import os
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from multiprocessing import Pool, Lock, cpu_count
from pathlib import Path
from typing import Optional, List, Tuple
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from tqdm import tqdm from tqdm import tqdm
def fail(message, code=1): def transcribe_file(audio_path: Path, tlog_path: Path):
print(f"E {message}")
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
log_level_index = ( log_level_index = (
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0 sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
) )
@ -56,7 +53,7 @@ def transcribe_file(audio_path, tlog_path):
except NotImplementedError: except NotImplementedError:
num_processes = 1 num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path: with AudioFile(str(audio_path), as_path=True) as wav_path:
data_set = split_audio_file( data_set = split_audio_file(
wav_path, wav_path,
batch_size=Config.batch_size, batch_size=Config.batch_size,
@ -73,7 +70,9 @@ def transcribe_file(audio_path, tlog_path):
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step() tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session: with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session) # Load checkpoint in a mutex way to avoid hangs in TensorFlow code
with lock:
load_graph_for_evaluation(session, silent=True)
transcripts = [] transcripts = []
while True: while True:
try: try:
@ -101,8 +100,13 @@ def transcribe_file(audio_path, tlog_path):
json.dump(transcripts, tlog_file, default=float) json.dump(transcripts, tlog_file, default=float)
def init_fn(l):
global lock
lock = l
def step_function(job): def step_function(job):
""" Wrap transcribe_file to unpack arguments from a single tuple """ """Wrap transcribe_file to unpack arguments from a single tuple"""
idx, src, dst = job idx, src, dst = job
transcribe_file(src, dst) transcribe_file(src, dst)
return idx, src, dst return idx, src, dst
@ -111,11 +115,15 @@ def step_function(job):
def transcribe_many(src_paths, dst_paths): def transcribe_many(src_paths, dst_paths):
from coqui_stt_training.util.config import Config, log_progress from coqui_stt_training.util.config import Config, log_progress
pool = Pool(processes=min(cpu_count(), len(src_paths)))
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])] # Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths) jobs = zip(itertools.count(), src_paths, dst_paths)
lock = Lock()
with Pool(
processes=min(cpu_count(), len(src_paths)),
initializer=init_fn,
initargs=(lock,),
) as pool:
process_iterable = tqdm( process_iterable = tqdm(
pool.imap_unordered(step_function, jobs), pool.imap_unordered(step_function, jobs),
desc="Transcribing files", desc="Transcribing files",
@ -123,96 +131,109 @@ def transcribe_many(src_paths, dst_paths):
disable=not Config.show_progressbar, disable=not Config.show_progressbar,
) )
cwd = Path.cwd()
for result in process_iterable: for result in process_iterable:
idx, src, dst = result idx, src, dst = result
log_progress( # Revert to relative to avoid spamming logs
f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"' if not src.is_absolute():
) src = src.relative_to(cwd)
if not dst.is_absolute():
dst = dst.relative_to(cwd)
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
def transcribe_one(src_path, dst_path): def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
transcribe_file(src_path, dst_path) """Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
print(f'I Transcribed file "{src_path}" to "{dst_path}"') extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results. For .catalog file inputs, these are taken from the
"audio" and "tlog" properties of the entries in the catalog, with any relative
paths being absolutized relative to the directory containing the .catalog file.
"""
assert catalog_file_path.suffix == ".catalog"
catalog_dir = catalog_file_path.parent
with open(catalog_file_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
def resolve(base_path, spec_path): def resolve(spec_path: Optional[Path]):
if spec_path is None: if spec_path is None:
return None return None
if not os.path.isabs(spec_path): if not spec_path.is_absolute():
spec_path = os.path.join(base_path, spec_path) spec_path = catalog_dir / spec_path
return spec_path return spec_path
catalog_entries = [
(resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
]
for src, dst in catalog_entries:
if not Config.force and dst.is_file():
raise RuntimeError(
f"Destination file already exists: {dst}. Use --force for overwriting."
)
if not dst.parent.is_dir():
dst.parent.mkdir(parents=True)
src_paths, dst_paths = zip(*catalog_entries)
return src_paths, dst_paths
def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
"""Given a directory `src_dir` containing audio files, scan it for audio files
and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results.
"""
glob_method = src_dir.rglob if recursive else src_dir.glob
src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
return src_paths, dst_paths
def transcribe(): def transcribe():
from coqui_stt_training.util.config import Config from coqui_stt_training.util.config import Config
initialize_transcribe_config() initialize_transcribe_config()
if not Config.src or not os.path.exists(Config.src): src_path = Path(Config.src).resolve()
if not Config.src or not src_path.exists():
# path not given or non-existant # path not given or non-existant
fail( raise RuntimeError(
"You have to specify which file or catalog to transcribe via the --src flag." "You have to specify which audio file, catalog file or directory to "
"transcribe with the --src flag."
) )
else: else:
# path given and exists # path given and exists
src_path = os.path.abspath(Config.src) if src_path.is_file():
if os.path.isfile(src_path): if src_path.suffix != ".catalog":
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not Config.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
# Transcribe one file # Transcribe one file
dst_path = ( dst_path = (
os.path.abspath(Config.dst) Path(Config.dst).resolve()
if Config.dst if Config.dst
else os.path.splitext(src_path)[0] + ".tlog" else src_path.with_suffix(".tlog")
) )
if os.path.isfile(dst_path):
if Config.force: if dst_path.is_file() and not Config.force:
transcribe_one(src_path, dst_path) raise RuntimeError(
else: f'Destination file "{dst_path}" already exists - use '
fail( "--force for overwriting."
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
) )
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path) if not dst_path.parent.is_dir():
raise RuntimeError("Missing destination directory")
transcribe_many([src_path], [dst_path])
else: else:
fail("Missing destination directory") # Transcribe from .catalog input
elif os.path.isdir(src_path): src_paths, dst_paths = get_tasks_from_catalog(src_path)
# Transcribe all files in dir transcribe_many(src_paths, dst_paths)
print("Transcribing all WAV files in --src") elif src_path.is_dir():
if Config.recursive: # Transcribe from dir input
wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav")) print(f"Transcribing all files in --src directory {src_path}")
else: src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
wav_paths = glob.glob(os.path.join(src_path, "*.wav")) transcribe_many(src_paths, dst_paths)
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
def initialize_transcribe_config(): def initialize_transcribe_config():
@ -230,7 +251,7 @@ def initialize_transcribe_config():
"Catalog files should be formatted from DSAlign. A directory " "Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, " "will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the " "transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of ".wav".' 'source filenames with suffix ".tlog" instead of the original.'
), ),
) )
@ -299,6 +320,9 @@ def initialize_transcribe_config():
def main(): def main():
from coqui_stt_training.util.helpers import check_ctcdecoder_version from coqui_stt_training.util.helpers import check_ctcdecoder_version
# Set start method to spawn on all platforms to avoid issues with TensorFlow
multiprocessing.set_start_method("spawn")
try: try:
import webrtcvad import webrtcvad
except ImportError: except ImportError:

View File

@ -7,7 +7,13 @@ import tensorflow as tf
from .config import Config, log_error, log_info, log_warn from .config import Config, log_error, log_info, log_warn
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True): def _load_checkpoint(
session,
checkpoint_path,
allow_drop_layers,
allow_lr_init=True,
silent: bool = False,
):
# Load the checkpoint and put all variables into loading list # Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then # we will exclude variables we do not wish to load and then
# we will initialize them instead # we will initialize them instead
@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
init_vars.add(v) init_vars.add(v)
load_vars -= init_vars load_vars -= init_vars
log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}") def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for v in sorted(load_vars, key=lambda v: v.op.name): for v in sorted(load_vars, key=lambda v: v.op.name):
log_info(f"Getting tensor from variable: {v.op.name}") maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
tensor = ckpt.get_tensor(v.op.name) v.load(ckpt.get_tensor(v.op.name), session=session)
log_info(f"Loading tensor from checkpoint: {v.op.name}")
v.load(tensor, session=session)
for v in sorted(init_vars, key=lambda v: v.op.name): for v in sorted(init_vars, key=lambda v: v.op.name):
log_info("Initializing variable: %s" % (v.op.name)) maybe_log_info("Initializing variable: %s" % (v.op.name))
session.run(v.initializer) session.run(v.initializer)
@ -102,31 +109,49 @@ def _initialize_all_variables(session):
session.run(v.initializer) session.run(v.initializer)
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True): def _load_or_init_impl(
session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
):
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for method in method_order: for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint' # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == "best": if method == "best":
ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint") ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
if ckpt_path: if ckpt_path:
log_info("Loading best validating checkpoint from {}".format(ckpt_path)) maybe_log_info(
return _load_checkpoint( "Loading best validating checkpoint from {}".format(ckpt_path)
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
) )
log_info("Could not find best validating checkpoint.") return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find best validating checkpoint.")
# Load most recent checkpoint, saved in checkpoint file 'checkpoint' # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
elif method == "last": elif method == "last":
ckpt_path = _checkpoint_path_or_none("checkpoint") ckpt_path = _checkpoint_path_or_none("checkpoint")
if ckpt_path: if ckpt_path:
log_info("Loading most recent checkpoint from {}".format(ckpt_path)) maybe_log_info(
return _load_checkpoint( "Loading most recent checkpoint from {}".format(ckpt_path)
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
) )
log_info("Could not find most recent checkpoint.") return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find most recent checkpoint.")
# Initialize all variables # Initialize all variables
elif method == "init": elif method == "init":
log_info("Initializing all variables.") maybe_log_info("Initializing all variables.")
return _initialize_all_variables(session) return _initialize_all_variables(session)
else: else:
@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
_load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False) _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
def load_or_init_graph_for_training(session): def load_or_init_graph_for_training(session, silent: bool = False):
""" """
Load variables from checkpoint or initialize variables. By default this will Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint, try to load the best validating checkpoint, then try the last checkpoint,
@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
methods = ["best", "last", "init"] methods = ["best", "last", "init"]
else: else:
methods = [Config.load_train] methods = [Config.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=True) _load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
def load_graph_for_evaluation(session): def load_graph_for_evaluation(session, silent: bool = False):
""" """
Load variables from checkpoint. Initialization is not allowed. By default Load variables from checkpoint. Initialization is not allowed. By default
this will try to load the best validating checkpoint, then try the last this will try to load the best validating checkpoint, then try the last
@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
methods = ["best", "last"] methods = ["best", "last"]
else: else:
methods = [Config.load_evaluate] methods = [Config.load_evaluate]
_load_or_init_impl(session, methods, allow_drop_layers=False) _load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)

View File

@ -217,10 +217,12 @@ class BaseSttConfig(Coqpit):
if not is_remote_path(self.save_checkpoint_dir): if not is_remote_path(self.save_checkpoint_dir):
os.makedirs(self.save_checkpoint_dir, exist_ok=True) os.makedirs(self.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt") flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
if not os.path.exists(flags_file):
with open_remote(flags_file, "w") as fout: with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2) json.dump(self.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint # Serialize alphabet alongside checkpoint
if not os.path.exists(saved_checkpoint_alphabet_file):
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout: with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText()) fout.write(self.alphabet.SerializeText())