Undo late-imports
This commit is contained in:
parent
479d963155
commit
ff24a8b917
|
@ -1,12 +1,5 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# This script is structured in a weird way, with delayed imports. This is due
|
|
||||||
# to the use of multiprocessing. TensorFlow cannot handle forking, and even with
|
|
||||||
# the spawn strategy set to "spawn" it still leads to weird problems, so we
|
|
||||||
# restructure the code so that TensorFlow is only imported inside the child
|
|
||||||
# processes.
|
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
|
@ -18,28 +11,31 @@ from multiprocessing import Pool, Lock, cpu_count
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
|
||||||
|
DESIRED_LOG_LEVEL = (
|
||||||
|
sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
|
||||||
|
)
|
||||||
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
||||||
|
# Hide GPUs to prevent issues with child processes trying to use the same GPU
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
|
||||||
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
|
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
|
||||||
|
from coqui_stt_training.train import create_model
|
||||||
|
from coqui_stt_training.util.audio import AudioFile
|
||||||
|
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
|
||||||
|
from coqui_stt_training.util.config import (
|
||||||
|
BaseSttConfig,
|
||||||
|
Config,
|
||||||
|
initialize_globals_from_instance,
|
||||||
|
)
|
||||||
|
from coqui_stt_training.util.feeding import split_audio_file
|
||||||
|
from coqui_stt_training.util.helpers import check_ctcdecoder_version
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def transcribe_file(audio_path: Path, tlog_path: Path):
|
def transcribe_file(audio_path: Path, tlog_path: Path):
|
||||||
log_level_index = (
|
|
||||||
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
|
|
||||||
)
|
|
||||||
desired_log_level = (
|
|
||||||
sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
|
|
||||||
)
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import tensorflow.compat.v1 as tfv1
|
|
||||||
|
|
||||||
from coqui_stt_training.train import create_model
|
|
||||||
from coqui_stt_training.util.audio import AudioFile
|
|
||||||
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
|
|
||||||
from coqui_stt_training.util.config import Config
|
|
||||||
from coqui_stt_training.util.feeding import split_audio_file
|
|
||||||
|
|
||||||
initialize_transcribe_config()
|
initialize_transcribe_config()
|
||||||
|
|
||||||
scorer = None
|
scorer = None
|
||||||
|
@ -113,8 +109,6 @@ def step_function(job):
|
||||||
|
|
||||||
|
|
||||||
def transcribe_many(src_paths, dst_paths):
|
def transcribe_many(src_paths, dst_paths):
|
||||||
from coqui_stt_training.util.config import Config, log_progress
|
|
||||||
|
|
||||||
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
|
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
|
||||||
jobs = zip(itertools.count(), src_paths, dst_paths)
|
jobs = zip(itertools.count(), src_paths, dst_paths)
|
||||||
|
|
||||||
|
@ -134,11 +128,17 @@ def transcribe_many(src_paths, dst_paths):
|
||||||
cwd = Path.cwd()
|
cwd = Path.cwd()
|
||||||
for result in process_iterable:
|
for result in process_iterable:
|
||||||
idx, src, dst = result
|
idx, src, dst = result
|
||||||
# Revert to relative to avoid spamming logs
|
# Revert to relative if possible to make logs more concise
|
||||||
if not src.is_absolute():
|
# if path is not relative to cwd, use the absolute path
|
||||||
|
# (Path.is_relative_to is only available in Python >=3.9)
|
||||||
|
try:
|
||||||
src = src.relative_to(cwd)
|
src = src.relative_to(cwd)
|
||||||
if not dst.is_absolute():
|
except ValueError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
dst = dst.relative_to(cwd)
|
dst = dst.relative_to(cwd)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
|
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
|
||||||
|
|
||||||
|
|
||||||
|
@ -187,14 +187,12 @@ def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List
|
||||||
transcription results.
|
transcription results.
|
||||||
"""
|
"""
|
||||||
glob_method = src_dir.rglob if recursive else src_dir.glob
|
glob_method = src_dir.rglob if recursive else src_dir.glob
|
||||||
src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))
|
src_paths = list(glob_method("*.wav"))
|
||||||
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
|
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
|
||||||
return src_paths, dst_paths
|
return src_paths, dst_paths
|
||||||
|
|
||||||
|
|
||||||
def transcribe():
|
def transcribe():
|
||||||
from coqui_stt_training.util.config import Config
|
|
||||||
|
|
||||||
initialize_transcribe_config()
|
initialize_transcribe_config()
|
||||||
|
|
||||||
src_path = Path(Config.src).resolve()
|
src_path = Path(Config.src).resolve()
|
||||||
|
@ -236,89 +234,85 @@ def transcribe():
|
||||||
transcribe_many(src_paths, dst_paths)
|
transcribe_many(src_paths, dst_paths)
|
||||||
|
|
||||||
|
|
||||||
def initialize_transcribe_config():
|
@dataclass
|
||||||
from coqui_stt_training.util.config import (
|
class TranscribeConfig(BaseSttConfig):
|
||||||
BaseSttConfig,
|
src: str = field(
|
||||||
initialize_globals_from_instance,
|
default="",
|
||||||
|
metadata=dict(
|
||||||
|
help="Source path to an audio file or directory or catalog file. "
|
||||||
|
"Catalog files should be formatted from DSAlign. A directory "
|
||||||
|
"will be recursively searched for audio. If --dst not set, "
|
||||||
|
"transcription logs (.tlog) will be written in-place using the "
|
||||||
|
'source filenames with suffix ".tlog" instead of the original.'
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
dst: str = field(
|
||||||
class TranscribeConfig(BaseSttConfig):
|
default="",
|
||||||
src: str = field(
|
metadata=dict(
|
||||||
default="",
|
help="path for writing the transcription log or logs (.tlog). "
|
||||||
metadata=dict(
|
"If --src is a directory, this one also has to be a directory "
|
||||||
help="Source path to an audio file or directory or catalog file. "
|
"and the required sub-dir tree of --src will get replicated."
|
||||||
"Catalog files should be formatted from DSAlign. A directory "
|
),
|
||||||
"will be recursively searched for audio. If --dst not set, "
|
)
|
||||||
"transcription logs (.tlog) will be written in-place using the "
|
|
||||||
'source filenames with suffix ".tlog" instead of the original.'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
dst: str = field(
|
recursive: bool = field(
|
||||||
default="",
|
default=False,
|
||||||
metadata=dict(
|
metadata=dict(help="scan source directory recursively for audio"),
|
||||||
help="path for writing the transcription log or logs (.tlog). "
|
)
|
||||||
"If --src is a directory, this one also has to be a directory "
|
|
||||||
"and the required sub-dir tree of --src will get replicated."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
recursive: bool = field(
|
force: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata=dict(help="scan source directory recursively for audio"),
|
metadata=dict(
|
||||||
)
|
help="Forces re-transcribing and overwriting of already existing "
|
||||||
|
"transcription logs (.tlog)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
force: bool = field(
|
vad_aggressiveness: int = field(
|
||||||
default=False,
|
default=3,
|
||||||
metadata=dict(
|
metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
|
||||||
help="Forces re-transcribing and overwriting of already existing "
|
)
|
||||||
"transcription logs (.tlog)"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
vad_aggressiveness: int = field(
|
batch_size: int = field(
|
||||||
default=3,
|
default=40,
|
||||||
metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
|
metadata=dict(help="Default batch size"),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size: int = field(
|
outlier_duration_ms: int = field(
|
||||||
default=40,
|
default=10000,
|
||||||
metadata=dict(help="Default batch size"),
|
metadata=dict(
|
||||||
)
|
help="Duration in ms after which samples are considered outliers"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
outlier_duration_ms: int = field(
|
outlier_batch_size: int = field(
|
||||||
default=10000,
|
default=1,
|
||||||
metadata=dict(
|
metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
|
||||||
help="Duration in ms after which samples are considered outliers"
|
)
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
outlier_batch_size: int = field(
|
def __post_init__(self):
|
||||||
default=1,
|
if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
|
||||||
metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
|
raise RuntimeError(
|
||||||
)
|
"Parameter --dst not supported if --src points to a catalog"
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
if os.path.isdir(self.src):
|
||||||
if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
|
if self.dst:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Parameter --dst not supported if --src points to a catalog"
|
"Destination path not supported for batch decoding jobs."
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.path.isdir(self.src):
|
super().__post_init__()
|
||||||
if self.dst:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Destination path not supported for batch decoding jobs."
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__post_init__()
|
|
||||||
|
|
||||||
|
def initialize_transcribe_config():
|
||||||
config = TranscribeConfig.init_from_argparse(arg_prefix="")
|
config = TranscribeConfig.init_from_argparse(arg_prefix="")
|
||||||
initialize_globals_from_instance(config)
|
initialize_globals_from_instance(config)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
from coqui_stt_training.util.helpers import check_ctcdecoder_version
|
assert not tf.test.is_gpu_available()
|
||||||
|
|
||||||
# Set start method to spawn on all platforms to avoid issues with TensorFlow
|
# Set start method to spawn on all platforms to avoid issues with TensorFlow
|
||||||
multiprocessing.set_start_method("spawn")
|
multiprocessing.set_start_method("spawn")
|
||||||
|
|
Loading…
Reference in New Issue