From d90bb60506ffd11dba3074f4b129f7b060354e06 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Wed, 1 Dec 2021 14:24:42 +0100
Subject: [PATCH] [transcribe] Fix multiprocessing hangs, clean-up target collection

---
 ci_scripts/train-extra-tests.sh            |  17 +-
 training/coqui_stt_training/transcribe.py  | 208 ++++++++++--------
 .../coqui_stt_training/util/checkpoints.py |  67 ++++--
 training/coqui_stt_training/util/config.py |  10 +-
 4 files changed, 176 insertions(+), 126 deletions(-)

diff --git a/ci_scripts/train-extra-tests.sh b/ci_scripts/train-extra-tests.sh
index 82987474..1f18ab12 100755
--- a/ci_scripts/train-extra-tests.sh
+++ b/ci_scripts/train-extra-tests.sh
@@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
   --n_hidden 100 \
   --scorer_path "data/smoke_test/pruned_lm.scorer"
 
-#TODO: investigate why this is hanging in CI
-#mkdir /tmp/transcribe_dir
-#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
-#time python -m coqui_stt_training.transcribe \
-#  --src "/tmp/transcribe_dir/" \
-#  --n_hidden 100 \
-#  --scorer_path "data/smoke_test/pruned_lm.scorer"
-#
-#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
+mkdir /tmp/transcribe_dir
+cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
+time python -m coqui_stt_training.transcribe \
+  --src "/tmp/transcribe_dir/" \
+  --n_hidden 100 \
+  --scorer_path "data/smoke_test/pruned_lm.scorer"
+
+for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done
diff --git a/training/coqui_stt_training/transcribe.py b/training/coqui_stt_training/transcribe.py
index a761d516..9bb0b252 100755
--- a/training/coqui_stt_training/transcribe.py
+++ b/training/coqui_stt_training/transcribe.py
@@ -7,25 +7,22 @@
 # restructure the code so that TensorFlow is only imported inside the child
 # processes.
 
-import os
-import sys
 import glob
 import itertools
 import json
 import multiprocessing
-from multiprocessing import Pool, cpu_count
+import os
+import sys
 from dataclasses import dataclass, field
+from multiprocessing import Pool, Lock, cpu_count
+from pathlib import Path
+from typing import Optional, List, Tuple
 
 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from tqdm import tqdm
 
 
-def fail(message, code=1):
-    print(f"E {message}")
-    sys.exit(code)
-
-
-def transcribe_file(audio_path, tlog_path):
+def transcribe_file(audio_path: Path, tlog_path: Path):
     log_level_index = (
         sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
     )
@@ -56,7 +53,7 @@ def transcribe_file(audio_path, tlog_path):
     except NotImplementedError:
         num_processes = 1
 
-    with AudioFile(audio_path, as_path=True) as wav_path:
+    with AudioFile(str(audio_path), as_path=True) as wav_path:
         data_set = split_audio_file(
             wav_path,
             batch_size=Config.batch_size,
@@ -73,7 +70,9 @@ def transcribe_file(audio_path, tlog_path):
         transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
         tf.train.get_or_create_global_step()
         with tf.Session(config=Config.session_config) as session:
-            load_graph_for_evaluation(session)
+            # Load the checkpoint under the shared lock to avoid hangs in TensorFlow
+            with lock:
+                load_graph_for_evaluation(session, silent=True)
             transcripts = []
             while True:
                 try:
@@ -101,8 +100,13 @@ def transcribe_file(audio_path, tlog_path):
         json.dump(transcripts, tlog_file, default=float)
 
 
+def init_fn(l):
+    global lock
+    lock = l
+
+
 def step_function(job):
-    """ Wrap transcribe_file to unpack arguments from a single tuple """
+    """Wrap transcribe_file to unpack arguments from a single tuple"""
     idx, src, dst = job
     transcribe_file(src, dst)
     return idx, src, dst
@@ -111,36 +115,81 @@ def step_function(job):
 
 def transcribe_many(src_paths, dst_paths):
     from coqui_stt_training.util.config import Config, log_progress
 
-    pool = Pool(processes=min(cpu_count(), len(src_paths)))
-
     # Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
     jobs = zip(itertools.count(), src_paths, dst_paths)
 
-    process_iterable = tqdm(
-        pool.imap_unordered(step_function, jobs),
-        desc="Transcribing files",
-        total=len(src_paths),
-        disable=not Config.show_progressbar,
-    )
-
-    for result in process_iterable:
-        idx, src, dst = result
-        log_progress(
-            f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
+    lock = Lock()
+    with Pool(
+        processes=min(cpu_count(), len(src_paths)),
+        initializer=init_fn,
+        initargs=(lock,),
+    ) as pool:
+        process_iterable = tqdm(
+            pool.imap_unordered(step_function, jobs),
+            desc="Transcribing files",
+            total=len(src_paths),
+            disable=not Config.show_progressbar,
         )
 
-
-def transcribe_one(src_path, dst_path):
-    transcribe_file(src_path, dst_path)
-    print(f'I Transcribed file "{src_path}" to "{dst_path}"')
+        cwd = Path.cwd()
+        for result in process_iterable:
+            idx, src, dst = result
+            # Print paths relative to cwd where possible, to avoid spamming logs
+            # (relative_to raises ValueError for paths outside of cwd)
+            try:
+                src = src.relative_to(cwd)
+            except ValueError:
+                pass
+            try:
+                dst = dst.relative_to(cwd)
+            except ValueError:
+                pass
+            tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
 
 
-def resolve(base_path, spec_path):
-    if spec_path is None:
-        return None
-    if not os.path.isabs(spec_path):
-        spec_path = os.path.join(base_path, spec_path)
-    return spec_path
+def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
+    """Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
+    extract transcription tasks, i.e. (src_path, dst_path) pairs, where src_path
+    points to an audio file to be transcribed and dst_path to a JSON file where
+    transcription results will be placed. These are taken from the "audio" and
+    "tlog" properties of the entries in the catalog, with any relative paths
+    resolved relative to the directory containing the .catalog file.
+    """
+    from coqui_stt_training.util.config import Config
+
+    assert catalog_file_path.suffix == ".catalog"
+
+    catalog_dir = catalog_file_path.parent
+    with open(catalog_file_path, "r") as catalog_file:
+        catalog_entries = json.load(catalog_file)
+
+    def resolve(spec_path: Optional[Path]):
+        if spec_path is None:
+            return None
+        if not spec_path.is_absolute():
+            spec_path = catalog_dir / spec_path
+        return spec_path
+
+    catalog_entries = [
+        (resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
+    ]
+
+    for src, dst in catalog_entries:
+        if not Config.force and dst.is_file():
+            raise RuntimeError(
+                f"Destination file already exists: {dst}. Use --force for overwriting."
+            )
+
+        if not dst.parent.is_dir():
+            dst.parent.mkdir(parents=True)
+
+    src_paths, dst_paths = zip(*catalog_entries)
+    return src_paths, dst_paths
+
+
+def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
+    """Given a directory `src_dir`, scan it for audio files (.wav and .opus) and
+    return transcription tasks, i.e. (src_path, dst_path) pairs, where src_path
+    points to an audio file to be transcribed and dst_path to a JSON file where
+    transcription results will be placed.
+    """
+    glob_method = src_dir.rglob if recursive else src_dir.glob
+    src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))
+    dst_paths = [path.with_suffix(".tlog") for path in src_paths]
+    return src_paths, dst_paths
 
 
 def transcribe():
@@ -148,71 +197,43 @@ def transcribe():
     initialize_transcribe_config()
 
-    if not Config.src or not os.path.exists(Config.src):
-        # path not given or non-existant
-        fail(
-            "You have to specify which file or catalog to transcribe via the --src flag."
+    src_path = Path(Config.src).resolve() if Config.src else None
+    if not src_path or not src_path.exists():
+        # path not given or non-existent
+        raise RuntimeError(
+            "You have to specify which audio file, catalog file or directory to "
+            "transcribe with the --src flag."
         )
     else:
         # path given and exists
-        src_path = os.path.abspath(Config.src)
-        if os.path.isfile(src_path):
-            if src_path.endswith(".catalog"):
-                # Transcribe batch of files via ".catalog" file (from DSAlign)
-                catalog_dir = os.path.dirname(src_path)
-                with open(src_path, "r") as catalog_file:
-                    catalog_entries = json.load(catalog_file)
-                catalog_entries = [
-                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
-                    for e in catalog_entries
-                ]
-                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
-                    fail("Missing source file(s) in catalog")
-                if not Config.force and any(
-                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
-                ):
-                    fail(
-                        "Destination file(s) from catalog already existing, use --force for overwriting"
-                    )
-                if any(
-                    map(
-                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
-                        catalog_entries,
-                    )
-                ):
-                    fail("Missing destination directory for at least one catalog entry")
-                src_paths, dst_paths = zip(*paths)
-                transcribe_many(src_paths, dst_paths)
-            else:
+        if src_path.is_file():
+            if src_path.suffix != ".catalog":
                 # Transcribe one file
                 dst_path = (
-                    os.path.abspath(Config.dst)
+                    Path(Config.dst).resolve()
                     if Config.dst
-                    else os.path.splitext(src_path)[0] + ".tlog"
+                    else src_path.with_suffix(".tlog")
                 )
-                if os.path.isfile(dst_path):
-                    if Config.force:
-                        transcribe_one(src_path, dst_path)
-                    else:
-                        fail(
-                            'Destination file "{}" already existing - use --force for overwriting'.format(
-                                dst_path
-                            ),
-                            code=0,
-                        )
-                elif os.path.isdir(os.path.dirname(dst_path)):
-                    transcribe_one(src_path, dst_path)
-                else:
-                    fail("Missing destination directory")
-        elif os.path.isdir(src_path):
-            # Transcribe all files in dir
-            print("Transcribing all WAV files in --src")
-            if Config.recursive:
-                wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
+
+                if dst_path.is_file() and not Config.force:
+                    raise RuntimeError(
+                        f'Destination file "{dst_path}" already exists - use '
+                        "--force for overwriting."
+                    )
+
+                if not dst_path.parent.is_dir():
+                    raise RuntimeError("Missing destination directory")
+
+                transcribe_many([src_path], [dst_path])
             else:
-                wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
-            dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
-            transcribe_many(wav_paths, dst_paths)
+                # Transcribe from .catalog input
+                src_paths, dst_paths = get_tasks_from_catalog(src_path)
+                transcribe_many(src_paths, dst_paths)
+        elif src_path.is_dir():
+            # Transcribe from dir input
+            print(f"Transcribing all files in --src directory {src_path}")
+            src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
+            transcribe_many(src_paths, dst_paths)
 
 
 def initialize_transcribe_config():
@@ -230,7 +251,7 @@ def initialize_transcribe_config():
             "Catalog files should be formatted from DSAlign. A directory "
             "will be recursively searched for audio. If --dst not set, "
            "transcription logs (.tlog) will be written in-place using the "
-            'source filenames with suffix ".tlog" instead of ".wav".'
+            'source filenames with suffix ".tlog" instead of the original.'
         ),
     )
 
@@ -299,6 +320,9 @@ def main():
     from coqui_stt_training.util.helpers import check_ctcdecoder_version
 
+    # Use the spawn start method on all platforms: forking a process after
+    # TensorFlow has been initialized leads to hangs
+    multiprocessing.set_start_method("spawn")
+
     try:
         import webrtcvad
     except ImportError:
diff --git a/training/coqui_stt_training/util/checkpoints.py b/training/coqui_stt_training/util/checkpoints.py
index 9d5452ae..638eadfe 100644
--- a/training/coqui_stt_training/util/checkpoints.py
+++ b/training/coqui_stt_training/util/checkpoints.py
@@ -7,7 +7,13 @@
 import tensorflow as tf
 
 from .config import Config, log_error, log_info, log_warn
 
 
-def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
+def _load_checkpoint(
+    session,
+    checkpoint_path,
+    allow_drop_layers,
+    allow_lr_init=True,
+    silent: bool = False,
+):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
             init_vars.add(v)
 
     load_vars -= init_vars
 
-    log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for v in sorted(load_vars, key=lambda v: v.op.name):
-        log_info(f"Getting tensor from variable: {v.op.name}")
-        tensor = ckpt.get_tensor(v.op.name)
-        log_info(f"Loading tensor from checkpoint: {v.op.name}")
-        v.load(tensor, session=session)
+        maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
+        v.load(ckpt.get_tensor(v.op.name), session=session)
 
     for v in sorted(init_vars, key=lambda v: v.op.name):
-        log_info("Initializing variable: %s" % (v.op.name))
+        maybe_log_info("Initializing variable: %s" % (v.op.name))
         session.run(v.initializer)
@@ -102,31 +109,49 @@ def _initialize_all_variables(session):
         session.run(v.initializer)
 
 
-def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
+def _load_or_init_impl(
+    session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
+):
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == "best":
             ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
             if ckpt_path:
-                log_info("Loading best validating checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
+                maybe_log_info(
+                    "Loading best validating checkpoint from {}".format(ckpt_path)
                 )
-            log_info("Could not find best validating checkpoint.")
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find best validating checkpoint.")
 
         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
        elif method == "last":
             ckpt_path = _checkpoint_path_or_none("checkpoint")
             if ckpt_path:
-                log_info("Loading most recent checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
+                maybe_log_info(
+                    "Loading most recent checkpoint from {}".format(ckpt_path)
                 )
-            log_info("Could not find most recent checkpoint.")
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find most recent checkpoint.")
 
         # Initialize all variables
         elif method == "init":
-            log_info("Initializing all variables.")
+            maybe_log_info("Initializing all variables.")
             return _initialize_all_variables(session)
 
         else:
@@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
     _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
 
 
-def load_or_init_graph_for_training(session):
+def load_or_init_graph_for_training(session, silent: bool = False):
     """
     Load variables from checkpoint or initialize variables. By default this will
     try to load the best validating checkpoint, then try the last checkpoint,
@@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
         methods = ["best", "last", "init"]
     else:
         methods = [Config.load_train]
-    _load_or_init_impl(session, methods, allow_drop_layers=True)
+    _load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
 
 
-def load_graph_for_evaluation(session):
+def load_graph_for_evaluation(session, silent: bool = False):
     """
     Load variables from checkpoint. Initialization is not allowed. By default
     this will try to load the best validating checkpoint, then try the last
@@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
         methods = ["best", "last"]
     else:
         methods = [Config.load_evaluate]
-    _load_or_init_impl(session, methods, allow_drop_layers=False)
+    _load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)
diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py
index 5c9a189b..2bd88dbe 100644
--- a/training/coqui_stt_training/util/config.py
+++ b/training/coqui_stt_training/util/config.py
@@ -217,12 +217,14 @@ class BaseSttConfig(Coqpit):
         if not is_remote_path(self.save_checkpoint_dir):
             os.makedirs(self.save_checkpoint_dir, exist_ok=True)
         flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
-        with open_remote(flags_file, "w") as fout:
-            json.dump(self.serialize(), fout, indent=2)
+        if not os.path.exists(flags_file):
+            with open_remote(flags_file, "w") as fout:
+                json.dump(self.serialize(), fout, indent=2)
 
         # Serialize alphabet alongside checkpoint
-        with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
-            fout.write(self.alphabet.SerializeText())
+        if not os.path.exists(saved_checkpoint_alphabet_file):
+            with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
+                fout.write(self.alphabet.SerializeText())
 
         # Geometric Constants
         # ===================