diff --git a/training/coqui_stt_training/transcribe.py b/training/coqui_stt_training/transcribe.py index 9bb0b252..5d57b03d 100755 --- a/training/coqui_stt_training/transcribe.py +++ b/training/coqui_stt_training/transcribe.py @@ -1,12 +1,5 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - -# This script is structured in a weird way, with delayed imports. This is due -# to the use of multiprocessing. TensorFlow cannot handle forking, and even with -# the spawn strategy set to "spawn" it still leads to weird problems, so we -# restructure the code so that TensorFlow is only imported inside the child -# processes. - import glob import itertools import json @@ -18,28 +11,31 @@ from multiprocessing import Pool, Lock, cpu_count from pathlib import Path from typing import Optional, List, Tuple +LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0 +DESIRED_LOG_LEVEL = ( + sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3" +) +os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL +# Hide GPUs to prevent issues with child processes trying to use the same GPU +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +import tensorflow as tf +import tensorflow.compat.v1 as tfv1 + from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch +from coqui_stt_training.train import create_model +from coqui_stt_training.util.audio import AudioFile +from coqui_stt_training.util.checkpoints import load_graph_for_evaluation +from coqui_stt_training.util.config import ( + BaseSttConfig, + Config, + initialize_globals_from_instance, +) +from coqui_stt_training.util.feeding import split_audio_file +from coqui_stt_training.util.helpers import check_ctcdecoder_version from tqdm import tqdm def transcribe_file(audio_path: Path, tlog_path: Path): - log_level_index = ( - sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0 - ) - desired_log_level = ( - sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3" - ) - os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level - - import tensorflow as tf - import tensorflow.compat.v1 as tfv1 - - from coqui_stt_training.train import create_model - from coqui_stt_training.util.audio import AudioFile - from coqui_stt_training.util.checkpoints import load_graph_for_evaluation - from coqui_stt_training.util.config import Config - from coqui_stt_training.util.feeding import split_audio_file - initialize_transcribe_config() scorer = None @@ -113,8 +109,6 @@ def step_function(job): def transcribe_many(src_paths, dst_paths): - from coqui_stt_training.util.config import Config, log_progress - # Create list of items to be processed: [(i, src_path[i], dst_paths[i])] jobs = zip(itertools.count(), src_paths, dst_paths) @@ -134,11 +128,17 @@ def transcribe_many(src_paths, dst_paths): cwd = Path.cwd() for result in process_iterable: idx, src, dst = result - # Revert to relative to avoid spamming logs - if not src.is_absolute(): + # Revert to relative if possible to make logs more concise + # if path is not relative to cwd, use the absolute path + # (Path.is_relative_to is only available in Python >=3.9) + try: src = src.relative_to(cwd) - if not dst.is_absolute(): + except ValueError: + pass + try: dst = dst.relative_to(cwd) + except ValueError: + pass tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"') @@ -187,14 +187,12 @@ def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List transcription results. """ glob_method = src_dir.rglob if recursive else src_dir.glob - src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus"))) + src_paths = list(glob_method("*.wav")) dst_paths = [path.with_suffix(".tlog") for path in src_paths] return src_paths, dst_paths def transcribe(): - from coqui_stt_training.util.config import Config - initialize_transcribe_config() src_path = Path(Config.src).resolve() @@ -236,89 +234,85 @@ def transcribe(): transcribe_many(src_paths, dst_paths) -def initialize_transcribe_config(): - from coqui_stt_training.util.config import ( - BaseSttConfig, - initialize_globals_from_instance, +@dataclass +class TranscribeConfig(BaseSttConfig): + src: str = field( + default="", + metadata=dict( + help="Source path to an audio file or directory or catalog file. " + "Catalog files should be formatted from DSAlign. A directory " + "will be recursively searched for audio. If --dst not set, " + "transcription logs (.tlog) will be written in-place using the " + 'source filenames with suffix ".tlog" instead of the original.' + ), ) - @dataclass - class TranscribeConfig(BaseSttConfig): - src: str = field( - default="", - metadata=dict( - help="Source path to an audio file or directory or catalog file. " - "Catalog files should be formatted from DSAlign. A directory " - "will be recursively searched for audio. If --dst not set, " - "transcription logs (.tlog) will be written in-place using the " - 'source filenames with suffix ".tlog" instead of the original.' - ), - ) + dst: str = field( + default="", + metadata=dict( + help="path for writing the transcription log or logs (.tlog). " + "If --src is a directory, this one also has to be a directory " + "and the required sub-dir tree of --src will get replicated." + ), + ) - dst: str = field( - default="", - metadata=dict( - help="path for writing the transcription log or logs (.tlog). " - "If --src is a directory, this one also has to be a directory " - "and the required sub-dir tree of --src will get replicated." - ), - ) + recursive: bool = field( + default=False, + metadata=dict(help="scan source directory recursively for audio"), + ) - recursive: bool = field( - default=False, - metadata=dict(help="scan source directory recursively for audio"), - ) + force: bool = field( + default=False, + metadata=dict( + help="Forces re-transcribing and overwriting of already existing " + "transcription logs (.tlog)" + ), + ) - force: bool = field( - default=False, - metadata=dict( - help="Forces re-transcribing and overwriting of already existing " - "transcription logs (.tlog)" - ), - ) + vad_aggressiveness: int = field( + default=3, + metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"), + ) - vad_aggressiveness: int = field( - default=3, - metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"), - ) + batch_size: int = field( + default=40, + metadata=dict(help="Default batch size"), + ) - batch_size: int = field( - default=40, - metadata=dict(help="Default batch size"), - ) + outlier_duration_ms: int = field( + default=10000, + metadata=dict( + help="Duration in ms after which samples are considered outliers" + ), + ) - outlier_duration_ms: int = field( - default=10000, - metadata=dict( - help="Duration in ms after which samples are considered outliers" - ), - ) + outlier_batch_size: int = field( + default=1, + metadata=dict(help="Batch size for duration outliers (defaults to 1)"), + ) - outlier_batch_size: int = field( - default=1, - metadata=dict(help="Batch size for duration outliers (defaults to 1)"), - ) + def __post_init__(self): + if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst: + raise RuntimeError( + "Parameter --dst not supported if --src points to a catalog" + ) - def __post_init__(self): - if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst: + if os.path.isdir(self.src): + if self.dst: raise RuntimeError( - "Parameter --dst not supported if --src points to a catalog" + "Destination path not supported for batch decoding jobs." ) - if os.path.isdir(self.src): - if self.dst: - raise RuntimeError( - "Destination path not supported for batch decoding jobs." - ) + super().__post_init__() - super().__post_init__() +def initialize_transcribe_config(): config = TranscribeConfig.init_from_argparse(arg_prefix="") initialize_globals_from_instance(config) def main(): - from coqui_stt_training.util.helpers import check_ctcdecoder_version + assert not tf.test.is_gpu_available() # Set start method to spawn on all platforms to avoid issues with TensorFlow multiprocessing.set_start_method("spawn")