[transcribe] Fix multiprocessing hangs, clean-up target collection

Reuben Morais 2021-12-01 14:24:42 +01:00
parent 5cefd7069c
commit d90bb60506
4 changed files with 176 additions and 126 deletions
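The fix in this commit combines two multiprocessing changes: the worker pool is started with the "spawn" method so children never inherit TensorFlow state from the parent, and checkpoint loading inside each worker is serialized through a Lock handed to the workers via the Pool initializer. A minimal, self-contained sketch of that pattern, with illustrative names (init_fn, work) rather than the project code:

# Minimal sketch of the pattern used in this commit (illustrative names,
# not the project code): a Lock created in the parent is handed to every
# worker through the Pool initializer, and each worker wraps its critical
# section in that lock.
import multiprocessing
from multiprocessing import Lock, Pool


def init_fn(l):
    # Stash the shared lock in a module-level global of the worker process.
    global lock
    lock = l


def work(item):
    with lock:
        # Critical section that must not run concurrently across workers,
        # e.g. loading a TensorFlow checkpoint into the worker's session.
        pass
    return item


if __name__ == "__main__":
    # "spawn" starts each worker with a fresh interpreter, so nothing from
    # the parent process (e.g. TensorFlow state) is inherited via fork.
    multiprocessing.set_start_method("spawn")
    shared_lock = Lock()
    with Pool(processes=2, initializer=init_fn, initargs=(shared_lock,)) as pool:
        print(list(pool.imap_unordered(work, range(4))))

Passing the lock through initargs matters because synchronization primitives cannot be sent to workers through the task arguments themselves.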

View File

@@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
     --n_hidden 100 \
     --scorer_path "data/smoke_test/pruned_lm.scorer"
 
-#TODO: investigate why this is hanging in CI
-#mkdir /tmp/transcribe_dir
-#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
-#time python -m coqui_stt_training.transcribe \
-#    --src "/tmp/transcribe_dir/" \
-#    --n_hidden 100 \
-#    --scorer_path "data/smoke_test/pruned_lm.scorer"
-#for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done
-#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
+mkdir /tmp/transcribe_dir
+cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
+time python -m coqui_stt_training.transcribe \
+    --src "/tmp/transcribe_dir/" \
+    --n_hidden 100 \
+    --scorer_path "data/smoke_test/pruned_lm.scorer"
+
+for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done

View File

@@ -7,25 +7,22 @@
 # restructure the code so that TensorFlow is only imported inside the child
 # processes.
-import os
-import sys
 import glob
 import itertools
 import json
 import multiprocessing
-from multiprocessing import Pool, cpu_count
+import os
+import sys
 from dataclasses import dataclass, field
+from multiprocessing import Pool, Lock, cpu_count
+from pathlib import Path
+from typing import Optional, List, Tuple
 
 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from tqdm import tqdm
 
-def fail(message, code=1):
-    print(f"E {message}")
-    sys.exit(code)
-
-def transcribe_file(audio_path, tlog_path):
+
+def transcribe_file(audio_path: Path, tlog_path: Path):
     log_level_index = (
         sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
     )
@@ -56,7 +53,7 @@ def transcribe_file(audio_path, tlog_path):
     except NotImplementedError:
         num_processes = 1
 
-    with AudioFile(audio_path, as_path=True) as wav_path:
+    with AudioFile(str(audio_path), as_path=True) as wav_path:
         data_set = split_audio_file(
             wav_path,
             batch_size=Config.batch_size,
@@ -73,7 +70,9 @@ def transcribe_file(audio_path, tlog_path):
         transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
         tf.train.get_or_create_global_step()
         with tf.Session(config=Config.session_config) as session:
-            load_graph_for_evaluation(session)
+            # Load checkpoint in a mutex way to avoid hangs in TensorFlow code
+            with lock:
+                load_graph_for_evaluation(session, silent=True)
             transcripts = []
             while True:
                 try:
@@ -101,8 +100,13 @@ def transcribe_file(audio_path, tlog_path):
         json.dump(transcripts, tlog_file, default=float)
 
 
+def init_fn(l):
+    global lock
+    lock = l
+
+
 def step_function(job):
-    """ Wrap transcribe_file to unpack arguments from a single tuple """
+    """Wrap transcribe_file to unpack arguments from a single tuple"""
     idx, src, dst = job
     transcribe_file(src, dst)
     return idx, src, dst
@@ -111,36 +115,81 @@ def step_function(job):
 def transcribe_many(src_paths, dst_paths):
     from coqui_stt_training.util.config import Config, log_progress
 
-    pool = Pool(processes=min(cpu_count(), len(src_paths)))
-
     # Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
     jobs = zip(itertools.count(), src_paths, dst_paths)
 
-    process_iterable = tqdm(
-        pool.imap_unordered(step_function, jobs),
-        desc="Transcribing files",
-        total=len(src_paths),
-        disable=not Config.show_progressbar,
-    )
-
-    for result in process_iterable:
-        idx, src, dst = result
-        log_progress(
-            f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
-        )
-
-
-def transcribe_one(src_path, dst_path):
-    transcribe_file(src_path, dst_path)
-    print(f'I Transcribed file "{src_path}" to "{dst_path}"')
-
-
-def resolve(base_path, spec_path):
-    if spec_path is None:
-        return None
-    if not os.path.isabs(spec_path):
-        spec_path = os.path.join(base_path, spec_path)
-    return spec_path
+    lock = Lock()
+    with Pool(
+        processes=min(cpu_count(), len(src_paths)),
+        initializer=init_fn,
+        initargs=(lock,),
+    ) as pool:
+        process_iterable = tqdm(
+            pool.imap_unordered(step_function, jobs),
+            desc="Transcribing files",
+            total=len(src_paths),
+            disable=not Config.show_progressbar,
+        )
+
+        cwd = Path.cwd()
+        for result in process_iterable:
+            idx, src, dst = result
+            # Revert to relative to avoid spamming logs
+            if not src.is_absolute():
+                src = src.relative_to(cwd)
+            if not dst.is_absolute():
+                dst = dst.relative_to(cwd)
+            tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
+
+
+def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
+    """Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
+    extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
+    a path to an audio file to be transcribed, and a path to a JSON file to place
+    transcription results. For .catalog file inputs, these are taken from the
+    "audio" and "tlog" properties of the entries in the catalog, with any relative
+    paths being absolutized relative to the directory containing the .catalog file.
+    """
+    assert catalog_file_path.suffix == ".catalog"
+    catalog_dir = catalog_file_path.parent
+    with open(catalog_file_path, "r") as catalog_file:
+        catalog_entries = json.load(catalog_file)
+
+    def resolve(spec_path: Optional[Path]):
+        if spec_path is None:
+            return None
+        if not spec_path.is_absolute():
+            spec_path = catalog_dir / spec_path
+        return spec_path
+
+    catalog_entries = [
+        (resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
+    ]
+    for src, dst in catalog_entries:
+        if not Config.force and dst.is_file():
+            raise RuntimeError(
+                f"Destination file already exists: {dst}. Use --force for overwriting."
+            )
+        if not dst.parent.is_dir():
+            dst.parent.mkdir(parents=True)
+    src_paths, dst_paths = zip(*catalog_entries)
+    return src_paths, dst_paths
+
+
+def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
+    """Given a directory `src_dir` containing audio files, scan it for audio files
+    and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
+    a path to an audio file to be transcribed, and a path to a JSON file to place
+    transcription results.
+    """
+    glob_method = src_dir.rglob if recursive else src_dir.glob
+    src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))
+    dst_paths = [path.with_suffix(".tlog") for path in src_paths]
+    return src_paths, dst_paths
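For reference, the .catalog input consumed by get_tasks_from_catalog is a JSON array; this code reads only the "audio" and "tlog" fields of each entry and resolves relative paths against the directory containing the .catalog file. A made-up two-entry example:

[
  {"audio": "audio/part-0001.wav", "tlog": "transcripts/part-0001.tlog"},
  {"audio": "audio/part-0002.wav", "tlog": "transcripts/part-0002.tlog"}
]

Any additional fields DSAlign writes into the entries are simply ignored by this code path.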
@@ -148,71 +197,43 @@ def transcribe():
     initialize_transcribe_config()
 
-    if not Config.src or not os.path.exists(Config.src):
+    src_path = Path(Config.src).resolve()
+    if not Config.src or not src_path.exists():
         # path not given or non-existant
-        fail(
-            "You have to specify which file or catalog to transcribe via the --src flag."
+        raise RuntimeError(
+            "You have to specify which audio file, catalog file or directory to "
+            "transcribe with the --src flag."
         )
     else:
         # path given and exists
-        src_path = os.path.abspath(Config.src)
-        if os.path.isfile(src_path):
-            if src_path.endswith(".catalog"):
-                # Transcribe batch of files via ".catalog" file (from DSAlign)
-                catalog_dir = os.path.dirname(src_path)
-                with open(src_path, "r") as catalog_file:
-                    catalog_entries = json.load(catalog_file)
-                catalog_entries = [
-                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
-                    for e in catalog_entries
-                ]
-                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
-                    fail("Missing source file(s) in catalog")
-                if not Config.force and any(
-                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
-                ):
-                    fail(
-                        "Destination file(s) from catalog already existing, use --force for overwriting"
-                    )
-                if any(
-                    map(
-                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
-                        catalog_entries,
-                    )
-                ):
-                    fail("Missing destination directory for at least one catalog entry")
-                src_paths, dst_paths = zip(*paths)
-                transcribe_many(src_paths, dst_paths)
-            else:
+        if src_path.is_file():
+            if src_path.suffix != ".catalog":
                 # Transcribe one file
                 dst_path = (
-                    os.path.abspath(Config.dst)
+                    Path(Config.dst).resolve()
                     if Config.dst
-                    else os.path.splitext(src_path)[0] + ".tlog"
+                    else src_path.with_suffix(".tlog")
                 )
-                if os.path.isfile(dst_path):
-                    if Config.force:
-                        transcribe_one(src_path, dst_path)
-                    else:
-                        fail(
-                            'Destination file "{}" already existing - use --force for overwriting'.format(
-                                dst_path
-                            ),
-                            code=0,
-                        )
-                elif os.path.isdir(os.path.dirname(dst_path)):
-                    transcribe_one(src_path, dst_path)
-                else:
-                    fail("Missing destination directory")
-        elif os.path.isdir(src_path):
-            # Transcribe all files in dir
-            print("Transcribing all WAV files in --src")
-            if Config.recursive:
-                wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
-            else:
-                wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
-            dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
-            transcribe_many(wav_paths, dst_paths)
+
+                if dst_path.is_file() and not Config.force:
+                    raise RuntimeError(
+                        f'Destination file "{dst_path}" already exists - use '
+                        "--force for overwriting."
+                    )
+
+                if not dst_path.parent.is_dir():
+                    raise RuntimeError("Missing destination directory")
+
+                transcribe_many([src_path], [dst_path])
+            else:
+                # Transcribe from .catalog input
+                src_paths, dst_paths = get_tasks_from_catalog(src_path)
+                transcribe_many(src_paths, dst_paths)
+        elif src_path.is_dir():
+            # Transcribe from dir input
+            print(f"Transcribing all files in --src directory {src_path}")
+            src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
+            transcribe_many(src_paths, dst_paths)
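With the target collection cleaned up, --src now selects one of three modes depending on what the path points to; roughly (placeholder paths, see the flag help in initialize_transcribe_config for the full set of options):

# Single audio file: the transcript log is written next to the input as
# sample.tlog unless --dst points somewhere else.
python -m coqui_stt_training.transcribe --src /data/sample.wav

# DSAlign .catalog file: sources and destinations come from its entries.
python -m coqui_stt_training.transcribe --src /data/alignments.catalog

# Directory: every *.wav and *.opus file found in it is transcribed.
python -m coqui_stt_training.transcribe --src /data/clips/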
@@ -230,7 +251,7 @@ def initialize_transcribe_config():
             "Catalog files should be formatted from DSAlign. A directory "
             "will be recursively searched for audio. If --dst not set, "
             "transcription logs (.tlog) will be written in-place using the "
-            'source filenames with suffix ".tlog" instead of ".wav".'
+            'source filenames with suffix ".tlog" instead of the original.'
         ),
     )
@@ -299,6 +320,9 @@ def initialize_transcribe_config():
 def main():
     from coqui_stt_training.util.helpers import check_ctcdecoder_version
 
+    # Set start method to spawn on all platforms to avoid issues with TensorFlow
+    multiprocessing.set_start_method("spawn")
+
     try:
         import webrtcvad
     except ImportError:

View File

@@ -7,7 +7,13 @@ import tensorflow as tf
 from .config import Config, log_error, log_info, log_warn
 
 
-def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
+def _load_checkpoint(
+    session,
+    checkpoint_path,
+    allow_drop_layers,
+    allow_lr_init=True,
+    silent: bool = False,
+):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
             init_vars.add(v)
     load_vars -= init_vars
 
-    log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for v in sorted(load_vars, key=lambda v: v.op.name):
-        log_info(f"Getting tensor from variable: {v.op.name}")
-        tensor = ckpt.get_tensor(v.op.name)
-        log_info(f"Loading tensor from checkpoint: {v.op.name}")
-        v.load(tensor, session=session)
+        maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
+        v.load(ckpt.get_tensor(v.op.name), session=session)
 
     for v in sorted(init_vars, key=lambda v: v.op.name):
-        log_info("Initializing variable: %s" % (v.op.name))
+        maybe_log_info("Initializing variable: %s" % (v.op.name))
         session.run(v.initializer)
@@ -102,31 +109,49 @@ def _initialize_all_variables(session):
         session.run(v.initializer)
 
 
-def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
+def _load_or_init_impl(
+    session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
+):
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == "best":
             ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
             if ckpt_path:
-                log_info("Loading best validating checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
-                )
-            log_info("Could not find best validating checkpoint.")
+                maybe_log_info(
+                    "Loading best validating checkpoint from {}".format(ckpt_path)
+                )
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find best validating checkpoint.")
 
         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
         elif method == "last":
             ckpt_path = _checkpoint_path_or_none("checkpoint")
             if ckpt_path:
-                log_info("Loading most recent checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
-                )
-            log_info("Could not find most recent checkpoint.")
+                maybe_log_info(
+                    "Loading most recent checkpoint from {}".format(ckpt_path)
+                )
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find most recent checkpoint.")
 
         # Initialize all variables
         elif method == "init":
-            log_info("Initializing all variables.")
+            maybe_log_info("Initializing all variables.")
             return _initialize_all_variables(session)
 
         else:
@@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
     _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
 
 
-def load_or_init_graph_for_training(session):
+def load_or_init_graph_for_training(session, silent: bool = False):
     """
     Load variables from checkpoint or initialize variables. By default this will
     try to load the best validating checkpoint, then try the last checkpoint,
@@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
         methods = ["best", "last", "init"]
     else:
         methods = [Config.load_train]
-    _load_or_init_impl(session, methods, allow_drop_layers=True)
+    _load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
 
 
-def load_graph_for_evaluation(session):
+def load_graph_for_evaluation(session, silent: bool = False):
    """
    Load variables from checkpoint. Initialization is not allowed. By default
    this will try to load the best validating checkpoint, then try the last
@@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
         methods = ["best", "last"]
     else:
         methods = [Config.load_evaluate]
-    _load_or_init_impl(session, methods, allow_drop_layers=False)
+    _load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)

View File

@@ -217,12 +217,14 @@ class BaseSttConfig(Coqpit):
         if not is_remote_path(self.save_checkpoint_dir):
             os.makedirs(self.save_checkpoint_dir, exist_ok=True)
         flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
-        with open_remote(flags_file, "w") as fout:
-            json.dump(self.serialize(), fout, indent=2)
+        if not os.path.exists(flags_file):
+            with open_remote(flags_file, "w") as fout:
+                json.dump(self.serialize(), fout, indent=2)
 
         # Serialize alphabet alongside checkpoint
-        with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
-            fout.write(self.alphabet.SerializeText())
+        if not os.path.exists(saved_checkpoint_alphabet_file):
+            with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
+                fout.write(self.alphabet.SerializeText())
 
         # Geometric Constants
         # ===================
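The config change above makes the flags.txt and alphabet writes idempotent: only the first initialization creates the files, so a re-initialization (for example in a spawned worker process) no longer rewrites the checkpoint directory. The guard is a plain existence check, sketched here with open() standing in for the project's open_remote and an illustrative helper name:

import json
import os


def write_once(path, payload):
    # Later re-initializations find the file and leave it untouched.
    if not os.path.exists(path):
        with open(path, "w") as fout:
            json.dump(payload, fout, indent=2)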