Undo late-imports

This commit is contained in:
Reuben Morais 2021-12-03 15:43:42 +01:00
parent 479d963155
commit ff24a8b917
1 changed files with 89 additions and 95 deletions

View File

@ -1,12 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# This script is structured in a weird way, with delayed imports. This is due
# to the use of multiprocessing. TensorFlow cannot handle forking, and even with
# the multiprocessing start method set to "spawn" it still leads to weird
# problems, so we restructure the code so that TensorFlow is only imported
# inside the child processes.
import glob import glob
import itertools import itertools
import json import json
@ -18,28 +11,31 @@ from multiprocessing import Pool, Lock, cpu_count
from pathlib import Path from pathlib import Path
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
# Hide GPUs to prevent issues with child processes trying to use the same GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import (
BaseSttConfig,
Config,
initialize_globals_from_instance,
)
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.helpers import check_ctcdecoder_version
from tqdm import tqdm from tqdm import tqdm
def transcribe_file(audio_path: Path, tlog_path: Path): def transcribe_file(audio_path: Path, tlog_path: Path):
log_level_index = (
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
)
desired_log_level = (
sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import Config
from coqui_stt_training.util.feeding import split_audio_file
initialize_transcribe_config() initialize_transcribe_config()
scorer = None scorer = None
@ -113,8 +109,6 @@ def step_function(job):
def transcribe_many(src_paths, dst_paths): def transcribe_many(src_paths, dst_paths):
from coqui_stt_training.util.config import Config, log_progress
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])] # Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths) jobs = zip(itertools.count(), src_paths, dst_paths)
@ -134,11 +128,17 @@ def transcribe_many(src_paths, dst_paths):
cwd = Path.cwd() cwd = Path.cwd()
for result in process_iterable: for result in process_iterable:
idx, src, dst = result idx, src, dst = result
# Revert to relative to avoid spamming logs # Revert to relative if possible to make logs more concise
if not src.is_absolute(): # if path is not relative to cwd, use the absolute path
# (Path.is_relative_to is only available in Python >=3.9)
try:
src = src.relative_to(cwd) src = src.relative_to(cwd)
if not dst.is_absolute(): except ValueError:
pass
try:
dst = dst.relative_to(cwd) dst = dst.relative_to(cwd)
except ValueError:
pass
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"') tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
@ -187,14 +187,12 @@ def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List
transcription results. transcription results.
""" """
glob_method = src_dir.rglob if recursive else src_dir.glob glob_method = src_dir.rglob if recursive else src_dir.glob
src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus"))) src_paths = list(glob_method("*.wav"))
dst_paths = [path.with_suffix(".tlog") for path in src_paths] dst_paths = [path.with_suffix(".tlog") for path in src_paths]
return src_paths, dst_paths return src_paths, dst_paths
def transcribe(): def transcribe():
from coqui_stt_training.util.config import Config
initialize_transcribe_config() initialize_transcribe_config()
src_path = Path(Config.src).resolve() src_path = Path(Config.src).resolve()
@ -236,12 +234,6 @@ def transcribe():
transcribe_many(src_paths, dst_paths) transcribe_many(src_paths, dst_paths)
def initialize_transcribe_config():
from coqui_stt_training.util.config import (
BaseSttConfig,
initialize_globals_from_instance,
)
@dataclass @dataclass
class TranscribeConfig(BaseSttConfig): class TranscribeConfig(BaseSttConfig):
src: str = field( src: str = field(
@ -313,12 +305,14 @@ def initialize_transcribe_config():
super().__post_init__() super().__post_init__()
def initialize_transcribe_config():
config = TranscribeConfig.init_from_argparse(arg_prefix="") config = TranscribeConfig.init_from_argparse(arg_prefix="")
initialize_globals_from_instance(config) initialize_globals_from_instance(config)
def main(): def main():
from coqui_stt_training.util.helpers import check_ctcdecoder_version assert not tf.test.is_gpu_available()
# Set start method to spawn on all platforms to avoid issues with TensorFlow # Set start method to spawn on all platforms to avoid issues with TensorFlow
multiprocessing.set_start_method("spawn") multiprocessing.set_start_method("spawn")