Merge pull request #2032 from coqui-ai/transcription-scripts-docs
[transcribe] Fix multiprocessing hangs, clean-up target collection, write docs
This commit is contained in:
commit
dbd38c3a89
@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
|
||||
--n_hidden 100 \
|
||||
--scorer_path "data/smoke_test/pruned_lm.scorer"
|
||||
|
||||
#TODO: investigate why this is hanging in CI
|
||||
#mkdir /tmp/transcribe_dir
|
||||
#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
|
||||
#time python -m coqui_stt_training.transcribe \
|
||||
# --src "/tmp/transcribe_dir/" \
|
||||
# --n_hidden 100 \
|
||||
# --scorer_path "data/smoke_test/pruned_lm.scorer"
|
||||
#
|
||||
#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
|
||||
mkdir /tmp/transcribe_dir
|
||||
cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
|
||||
time python -m coqui_stt_training.transcribe \
|
||||
--src "/tmp/transcribe_dir/" \
|
||||
--n_hidden 100 \
|
||||
--scorer_path "data/smoke_test/pruned_lm.scorer"
|
||||
|
||||
for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done
|
||||
|
79
doc/Checkpoint-Inference.rst
Normal file
79
doc/Checkpoint-Inference.rst
Normal file
@ -0,0 +1,79 @@
|
||||
.. _checkpoint-inference:
|
||||
|
||||
Inference tools in the training package
|
||||
=======================================
|
||||
|
||||
The standard deployment options for 🐸STT use highly optimized packages for deployment in real time, single-stream, low latency use cases. They take as input exported models which are also optimized, leading to further space and runtime gains. On the other hand, for the development of new features, it might be easier to use the training code for prototyping, which will allow you to test your changes without needing to recompile source code.
|
||||
|
||||
The training package contains options for performing inference directly from a checkpoint (and optionally a scorer), without needing to export a model. They are documented below, and all require a working :ref:`training environment <intro-training-docs>` before they can be used. Additionally, they require the Python ``webrtcvad`` package to be installed. This can either be done by specifying the "transcribe" extra when installing the training package, or by installing it manually in your training environment:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python -m pip install webrtcvad
|
||||
|
||||
Note that if your goal is to evaluate a trained model and obtain accuracy metrics, you should use the evaluation module: ``python -m coqui_stt_training.evaluate``, which calculates character and word error rates, from a properly formatted CSV file (specified with the ``--test_files`` flag. See the :ref:`training docs <intro-training-docs>` for more information).
|
||||
|
||||
Single file (aka one-shot) inference
|
||||
------------------------------------
|
||||
|
||||
This is the simplest way to perform inference from a checkpoint. It takes a single WAV file as input with the ``--one_shot_infer`` flag, and outputs the predicted transcription for that file.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python -m coqui_stt_training.training_graph_inference --checkpoint_dir coqui-stt-1.0.0-checkpoint --scorer_path huge-vocabulary.scorer --n_hidden 2048 --one_shot_infer audio/2830-3980-0043.wav
|
||||
I --alphabet_config_path not specified, but found an alphabet file alongside specified checkpoint (coqui-stt-1.0.0-checkpoint/alphabet.txt). Will use this alphabet file for this run.
|
||||
I Loading best validating checkpoint from coqui-stt-1.0.0-checkpoint/best_dev-3663881
|
||||
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias
|
||||
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel
|
||||
I Loading variable from checkpoint: layer_1/bias
|
||||
I Loading variable from checkpoint: layer_1/weights
|
||||
I Loading variable from checkpoint: layer_2/bias
|
||||
I Loading variable from checkpoint: layer_2/weights
|
||||
I Loading variable from checkpoint: layer_3/bias
|
||||
I Loading variable from checkpoint: layer_3/weights
|
||||
I Loading variable from checkpoint: layer_5/bias
|
||||
I Loading variable from checkpoint: layer_5/weights
|
||||
I Loading variable from checkpoint: layer_6/bias
|
||||
I Loading variable from checkpoint: layer_6/weights
|
||||
experience proves this
|
||||
|
||||
Transcription of longer audio files
|
||||
-----------------------------------
|
||||
|
||||
If you have longer audio files to transcribe, we offer a script which uses Voice Activity Detection (VAD) to split audio files in chunks and perform batched inference on said files. This can speed-up the transcription time significantly. The transcription script will also output the results in JSON format, allowing for easier programmatic usage of the outputs.
|
||||
|
||||
There are two main usage modes: transcribing a single file, or scanning a directory for audio files and transcribing all of them.
|
||||
|
||||
Transcribing a single file
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
For a single audio file, you can specify it directly in the ``--src`` flag of the ``python -m coqui_stt_training.transcribe`` script:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/2830-3980-0043.wav
|
||||
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
|
||||
Transcribing files: 100%|███████████████████████████████████| 1/1 [00:05<00:00, 5.40s/it]
|
||||
$ cat audio/2830-3980-0043.tlog
|
||||
[{"start": 150, "end": 1950, "transcript": "experience proves this"}]
|
||||
|
||||
Note the use of the ``--vad_aggressiveness`` flag above to control the behavior of the VAD process used to find silent sections of the audio file for splitting into chunks. You can run ``python -m coqui_stt_training.transcribe --help`` to see the full listing of options, the last ones are specific to the transcribe module.
|
||||
|
||||
By default the transcription results are put in a ``.tlog`` file next to the audio file that was transcribed, but you can specify a different location with the ``--dst path/to/some/file.tlog`` flag. This only works when trancribing a single file.
|
||||
|
||||
Scanning a directory for audio files
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Alternatively you can also specify a directory in the ``--src`` flag, in which case the directory will be scanned for any WAV files to be transcribed. If you specify ``--recursive true``, it'll scan the directory recursively, going into any subdirectories as well. Transcription results will be placed in a ``.tlog`` file alongside every audio file that was found by the process.
|
||||
|
||||
Multiple processes will be used to distribute the transcription work among available CPUs.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/ --recursive true
|
||||
Transcribing all files in --src directory audio
|
||||
Transcribing files: 0%| | 0/3 [00:00<?, ?it/s]
|
||||
[3]: "audio/8455-210777-0068.wav" -> "audio/8455-210777-0068.tlog"
|
||||
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
|
||||
[2]: "audio/4507-16021-0012.wav" -> "audio/4507-16021-0012.tlog"
|
||||
Transcribing files: 100%|███████████████████████████████████| 3/3 [00:07<00:00, 2.50s/it]
|
@ -18,6 +18,8 @@ You can deploy 🐸STT models either via a command-line client or a language bin
|
||||
* :ref:`The command-line client <cli-usage>`
|
||||
* :ref:`The C API <c-usage>`
|
||||
|
||||
In some use cases, you might want to use the inference facilities built into the training code, for example for faster prototyping of new features. They are not production-ready, but because it's all Python code you won't need to recompile in order to test code changes, which can be much faster. See :ref:`checkpoint-inference` for more details.
|
||||
|
||||
.. _download-models:
|
||||
|
||||
Download trained Coqui STT models
|
||||
@ -103,14 +105,6 @@ The following command assumes you :ref:`downloaded the pre-trained models <downl
|
||||
|
||||
See :ref:`the Python client <py-api-example>` for an example of how to use the package programatically.
|
||||
|
||||
*GPUs will soon be supported:* If you have a supported NVIDIA GPU on Linux, you can install the GPU specific package as follows:
|
||||
|
||||
.. code-block::
|
||||
|
||||
(coqui-stt-venv)$ python -m pip install -U pip && python -m pip install stt-gpu
|
||||
|
||||
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
|
||||
|
||||
.. _nodejs-usage:
|
||||
|
||||
Using the Node.JS / Electron.JS package
|
||||
@ -132,14 +126,6 @@ Please note that as of now, we support:
|
||||
|
||||
TypeScript support is also provided.
|
||||
|
||||
If you're using Linux and have a supported NVIDIA GPU, you can install the GPU specific package as follows:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
npm install stt-gpu
|
||||
|
||||
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
|
||||
|
||||
See the :ref:`TypeScript client <js-api-example>` for an example of how to use the bindings programatically.
|
||||
|
||||
.. _android-usage:
|
||||
@ -232,11 +218,6 @@ Running ``stt`` may require runtime dependencies. Please refer to your system's
|
||||
* ``libpthread`` - Reported dependency on Linux. On Ubuntu, ``libpthread`` is part of the ``libpthread-stubs0-dev`` package
|
||||
* ``Redistribuable Visual C++ 2015 Update 3 (64-bits)`` - Reported dependency on Windows. Please `download from Microsoft <https://www.microsoft.com/download/details.aspx?id=53587>`_
|
||||
|
||||
CUDA Dependency
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
The GPU capable builds (Python, NodeJS, C++, etc) depend on CUDA 10.1 and CuDNN v7.6.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
|
@ -27,3 +27,5 @@ This document contains more advanced topics with regard to training models with
|
||||
PARALLLEL_OPTIMIZATION
|
||||
|
||||
DATASET_IMPORTERS
|
||||
|
||||
Checkpoint-Inference
|
||||
|
2
doc/static/custom.css
vendored
2
doc/static/custom.css
vendored
@ -1,3 +1,3 @@
|
||||
#flags pre {
|
||||
#flags pre, #inference-tools-in-the-training-package pre {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
2
setup.py
2
setup.py
@ -66,7 +66,7 @@ def main():
|
||||
],
|
||||
package_dir={"": "training"},
|
||||
packages=find_packages(where="training"),
|
||||
python_requires=">=3.5, <4",
|
||||
python_requires=">=3.5, <3.8",
|
||||
install_requires=install_requires,
|
||||
include_package_data=True,
|
||||
extras_require={
|
||||
|
@ -1,48 +1,41 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# This script is structured in a weird way, with delayed imports. This is due
|
||||
# to the use of multiprocessing. TensorFlow cannot handle forking, and even with
|
||||
# the spawn strategy set to "spawn" it still leads to weird problems, so we
|
||||
# restructure the code so that TensorFlow is only imported inside the child
|
||||
# processes.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import itertools
|
||||
import json
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, cpu_count
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import Pool, Lock, cpu_count
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
|
||||
DESIRED_LOG_LEVEL = (
|
||||
sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
|
||||
)
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
||||
# Hide GPUs to prevent issues with child processes trying to use the same GPU
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
|
||||
from coqui_stt_training.train import create_model
|
||||
from coqui_stt_training.util.audio import AudioFile
|
||||
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
|
||||
from coqui_stt_training.util.config import (
|
||||
BaseSttConfig,
|
||||
Config,
|
||||
initialize_globals_from_instance,
|
||||
)
|
||||
from coqui_stt_training.util.feeding import split_audio_file
|
||||
from coqui_stt_training.util.helpers import check_ctcdecoder_version
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def fail(message, code=1):
|
||||
print(f"E {message}")
|
||||
sys.exit(code)
|
||||
|
||||
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
log_level_index = (
|
||||
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
|
||||
)
|
||||
desired_log_level = (
|
||||
sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
|
||||
)
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from coqui_stt_training.train import create_model
|
||||
from coqui_stt_training.util.audio import AudioFile
|
||||
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
|
||||
from coqui_stt_training.util.config import Config
|
||||
from coqui_stt_training.util.feeding import split_audio_file
|
||||
|
||||
def transcribe_file(audio_path: Path, tlog_path: Path):
|
||||
initialize_transcribe_config()
|
||||
|
||||
scorer = None
|
||||
@ -56,7 +49,7 @@ def transcribe_file(audio_path, tlog_path):
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
|
||||
with AudioFile(audio_path, as_path=True) as wav_path:
|
||||
with AudioFile(str(audio_path), as_path=True) as wav_path:
|
||||
data_set = split_audio_file(
|
||||
wav_path,
|
||||
batch_size=Config.batch_size,
|
||||
@ -73,7 +66,9 @@ def transcribe_file(audio_path, tlog_path):
|
||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||
tf.train.get_or_create_global_step()
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
load_graph_for_evaluation(session)
|
||||
# Load checkpoint in a mutex way to avoid hangs in TensorFlow code
|
||||
with lock:
|
||||
load_graph_for_evaluation(session, silent=True)
|
||||
transcripts = []
|
||||
while True:
|
||||
try:
|
||||
@ -101,21 +96,28 @@ def transcribe_file(audio_path, tlog_path):
|
||||
json.dump(transcripts, tlog_file, default=float)
|
||||
|
||||
|
||||
def init_fn(l):
|
||||
global lock
|
||||
lock = l
|
||||
|
||||
|
||||
def step_function(job):
|
||||
""" Wrap transcribe_file to unpack arguments from a single tuple """
|
||||
"""Wrap transcribe_file to unpack arguments from a single tuple"""
|
||||
idx, src, dst = job
|
||||
transcribe_file(src, dst)
|
||||
return idx, src, dst
|
||||
|
||||
|
||||
def transcribe_many(src_paths, dst_paths):
|
||||
from coqui_stt_training.util.config import Config, log_progress
|
||||
|
||||
pool = Pool(processes=min(cpu_count(), len(src_paths)))
|
||||
|
||||
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
|
||||
jobs = zip(itertools.count(), src_paths, dst_paths)
|
||||
|
||||
lock = Lock()
|
||||
with Pool(
|
||||
processes=min(cpu_count(), len(src_paths)),
|
||||
initializer=init_fn,
|
||||
initargs=(lock,),
|
||||
) as pool:
|
||||
process_iterable = tqdm(
|
||||
pool.imap_unordered(step_function, jobs),
|
||||
desc="Transcribing files",
|
||||
@ -123,106 +125,117 @@ def transcribe_many(src_paths, dst_paths):
|
||||
disable=not Config.show_progressbar,
|
||||
)
|
||||
|
||||
cwd = Path.cwd()
|
||||
for result in process_iterable:
|
||||
idx, src, dst = result
|
||||
log_progress(
|
||||
f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
|
||||
)
|
||||
# Revert to relative if possible to make logs more concise
|
||||
# if path is not relative to cwd, use the absolute path
|
||||
# (Path.is_relative_to is only available in Python >=3.9)
|
||||
try:
|
||||
src = src.relative_to(cwd)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
dst = dst.relative_to(cwd)
|
||||
except ValueError:
|
||||
pass
|
||||
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
|
||||
|
||||
|
||||
def transcribe_one(src_path, dst_path):
|
||||
transcribe_file(src_path, dst_path)
|
||||
print(f'I Transcribed file "{src_path}" to "{dst_path}"')
|
||||
def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
|
||||
"""Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
|
||||
extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
|
||||
a path to an audio file to be transcribed, and a path to a JSON file to place
|
||||
transcription results. For .catalog file inputs, these are taken from the
|
||||
"audio" and "tlog" properties of the entries in the catalog, with any relative
|
||||
paths being absolutized relative to the directory containing the .catalog file.
|
||||
"""
|
||||
assert catalog_file_path.suffix == ".catalog"
|
||||
|
||||
catalog_dir = catalog_file_path.parent
|
||||
with open(catalog_file_path, "r") as catalog_file:
|
||||
catalog_entries = json.load(catalog_file)
|
||||
|
||||
def resolve(base_path, spec_path):
|
||||
def resolve(spec_path: Optional[Path]):
|
||||
if spec_path is None:
|
||||
return None
|
||||
if not os.path.isabs(spec_path):
|
||||
spec_path = os.path.join(base_path, spec_path)
|
||||
if not spec_path.is_absolute():
|
||||
spec_path = catalog_dir / spec_path
|
||||
return spec_path
|
||||
|
||||
catalog_entries = [
|
||||
(resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
|
||||
]
|
||||
|
||||
for src, dst in catalog_entries:
|
||||
if not Config.force and dst.is_file():
|
||||
raise RuntimeError(
|
||||
f"Destination file already exists: {dst}. Use --force for overwriting."
|
||||
)
|
||||
|
||||
if not dst.parent.is_dir():
|
||||
dst.parent.mkdir(parents=True)
|
||||
|
||||
src_paths, dst_paths = zip(*catalog_entries)
|
||||
return src_paths, dst_paths
|
||||
|
||||
|
||||
def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
|
||||
"""Given a directory `src_dir` containing audio files, scan it for audio files
|
||||
and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
|
||||
a path to an audio file to be transcribed, and a path to a JSON file to place
|
||||
transcription results.
|
||||
"""
|
||||
glob_method = src_dir.rglob if recursive else src_dir.glob
|
||||
src_paths = list(glob_method("*.wav"))
|
||||
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
|
||||
return src_paths, dst_paths
|
||||
|
||||
|
||||
def transcribe():
|
||||
from coqui_stt_training.util.config import Config
|
||||
|
||||
initialize_transcribe_config()
|
||||
|
||||
if not Config.src or not os.path.exists(Config.src):
|
||||
src_path = Path(Config.src).resolve()
|
||||
if not Config.src or not src_path.exists():
|
||||
# path not given or non-existant
|
||||
fail(
|
||||
"You have to specify which file or catalog to transcribe via the --src flag."
|
||||
raise RuntimeError(
|
||||
"You have to specify which audio file, catalog file or directory to "
|
||||
"transcribe with the --src flag."
|
||||
)
|
||||
else:
|
||||
# path given and exists
|
||||
src_path = os.path.abspath(Config.src)
|
||||
if os.path.isfile(src_path):
|
||||
if src_path.endswith(".catalog"):
|
||||
# Transcribe batch of files via ".catalog" file (from DSAlign)
|
||||
catalog_dir = os.path.dirname(src_path)
|
||||
with open(src_path, "r") as catalog_file:
|
||||
catalog_entries = json.load(catalog_file)
|
||||
catalog_entries = [
|
||||
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
|
||||
for e in catalog_entries
|
||||
]
|
||||
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
|
||||
fail("Missing source file(s) in catalog")
|
||||
if not Config.force and any(
|
||||
map(lambda e: os.path.isfile(e[1]), catalog_entries)
|
||||
):
|
||||
fail(
|
||||
"Destination file(s) from catalog already existing, use --force for overwriting"
|
||||
)
|
||||
if any(
|
||||
map(
|
||||
lambda e: not os.path.isdir(os.path.dirname(e[1])),
|
||||
catalog_entries,
|
||||
)
|
||||
):
|
||||
fail("Missing destination directory for at least one catalog entry")
|
||||
src_paths, dst_paths = zip(*paths)
|
||||
transcribe_many(src_paths, dst_paths)
|
||||
else:
|
||||
if src_path.is_file():
|
||||
if src_path.suffix != ".catalog":
|
||||
# Transcribe one file
|
||||
dst_path = (
|
||||
os.path.abspath(Config.dst)
|
||||
Path(Config.dst).resolve()
|
||||
if Config.dst
|
||||
else os.path.splitext(src_path)[0] + ".tlog"
|
||||
)
|
||||
if os.path.isfile(dst_path):
|
||||
if Config.force:
|
||||
transcribe_one(src_path, dst_path)
|
||||
else:
|
||||
fail(
|
||||
'Destination file "{}" already existing - use --force for overwriting'.format(
|
||||
dst_path
|
||||
),
|
||||
code=0,
|
||||
)
|
||||
elif os.path.isdir(os.path.dirname(dst_path)):
|
||||
transcribe_one(src_path, dst_path)
|
||||
else:
|
||||
fail("Missing destination directory")
|
||||
elif os.path.isdir(src_path):
|
||||
# Transcribe all files in dir
|
||||
print("Transcribing all WAV files in --src")
|
||||
if Config.recursive:
|
||||
wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
|
||||
else:
|
||||
wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
|
||||
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
|
||||
transcribe_many(wav_paths, dst_paths)
|
||||
|
||||
|
||||
def initialize_transcribe_config():
|
||||
from coqui_stt_training.util.config import (
|
||||
BaseSttConfig,
|
||||
initialize_globals_from_instance,
|
||||
else src_path.with_suffix(".tlog")
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class TranscribeConfig(BaseSttConfig):
|
||||
if dst_path.is_file() and not Config.force:
|
||||
raise RuntimeError(
|
||||
f'Destination file "{dst_path}" already exists - use '
|
||||
"--force for overwriting."
|
||||
)
|
||||
|
||||
if not dst_path.parent.is_dir():
|
||||
raise RuntimeError("Missing destination directory")
|
||||
|
||||
transcribe_many([src_path], [dst_path])
|
||||
else:
|
||||
# Transcribe from .catalog input
|
||||
src_paths, dst_paths = get_tasks_from_catalog(src_path)
|
||||
transcribe_many(src_paths, dst_paths)
|
||||
elif src_path.is_dir():
|
||||
# Transcribe from dir input
|
||||
print(f"Transcribing all files in --src directory {src_path}")
|
||||
src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
|
||||
transcribe_many(src_paths, dst_paths)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscribeConfig(BaseSttConfig):
|
||||
src: str = field(
|
||||
default="",
|
||||
metadata=dict(
|
||||
@ -230,7 +243,7 @@ def initialize_transcribe_config():
|
||||
"Catalog files should be formatted from DSAlign. A directory "
|
||||
"will be recursively searched for audio. If --dst not set, "
|
||||
"transcription logs (.tlog) will be written in-place using the "
|
||||
'source filenames with suffix ".tlog" instead of ".wav".'
|
||||
'source filenames with suffix ".tlog" instead of the original.'
|
||||
),
|
||||
)
|
||||
|
||||
@ -292,12 +305,17 @@ def initialize_transcribe_config():
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
def initialize_transcribe_config():
|
||||
config = TranscribeConfig.init_from_argparse(arg_prefix="")
|
||||
initialize_globals_from_instance(config)
|
||||
|
||||
|
||||
def main():
|
||||
from coqui_stt_training.util.helpers import check_ctcdecoder_version
|
||||
assert not tf.test.is_gpu_available()
|
||||
|
||||
# Set start method to spawn on all platforms to avoid issues with TensorFlow
|
||||
multiprocessing.set_start_method("spawn")
|
||||
|
||||
try:
|
||||
import webrtcvad
|
||||
|
@ -7,7 +7,13 @@ import tensorflow as tf
|
||||
from .config import Config, log_error, log_info, log_warn
|
||||
|
||||
|
||||
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
|
||||
def _load_checkpoint(
|
||||
session,
|
||||
checkpoint_path,
|
||||
allow_drop_layers,
|
||||
allow_lr_init=True,
|
||||
silent: bool = False,
|
||||
):
|
||||
# Load the checkpoint and put all variables into loading list
|
||||
# we will exclude variables we do not wish to load and then
|
||||
# we will initialize them instead
|
||||
@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
|
||||
init_vars.add(v)
|
||||
load_vars -= init_vars
|
||||
|
||||
log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
|
||||
def maybe_log_info(*args, **kwargs):
|
||||
if not silent:
|
||||
log_info(*args, **kwargs)
|
||||
|
||||
for v in sorted(load_vars, key=lambda v: v.op.name):
|
||||
log_info(f"Getting tensor from variable: {v.op.name}")
|
||||
tensor = ckpt.get_tensor(v.op.name)
|
||||
log_info(f"Loading tensor from checkpoint: {v.op.name}")
|
||||
v.load(tensor, session=session)
|
||||
maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
|
||||
v.load(ckpt.get_tensor(v.op.name), session=session)
|
||||
|
||||
for v in sorted(init_vars, key=lambda v: v.op.name):
|
||||
log_info("Initializing variable: %s" % (v.op.name))
|
||||
maybe_log_info("Initializing variable: %s" % (v.op.name))
|
||||
session.run(v.initializer)
|
||||
|
||||
|
||||
@ -102,31 +109,49 @@ def _initialize_all_variables(session):
|
||||
session.run(v.initializer)
|
||||
|
||||
|
||||
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
|
||||
def _load_or_init_impl(
|
||||
session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
|
||||
):
|
||||
def maybe_log_info(*args, **kwargs):
|
||||
if not silent:
|
||||
log_info(*args, **kwargs)
|
||||
|
||||
for method in method_order:
|
||||
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
|
||||
if method == "best":
|
||||
ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
|
||||
if ckpt_path:
|
||||
log_info("Loading best validating checkpoint from {}".format(ckpt_path))
|
||||
return _load_checkpoint(
|
||||
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
|
||||
maybe_log_info(
|
||||
"Loading best validating checkpoint from {}".format(ckpt_path)
|
||||
)
|
||||
log_info("Could not find best validating checkpoint.")
|
||||
return _load_checkpoint(
|
||||
session,
|
||||
ckpt_path,
|
||||
allow_drop_layers,
|
||||
allow_lr_init=allow_lr_init,
|
||||
silent=silent,
|
||||
)
|
||||
maybe_log_info("Could not find best validating checkpoint.")
|
||||
|
||||
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
|
||||
elif method == "last":
|
||||
ckpt_path = _checkpoint_path_or_none("checkpoint")
|
||||
if ckpt_path:
|
||||
log_info("Loading most recent checkpoint from {}".format(ckpt_path))
|
||||
return _load_checkpoint(
|
||||
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
|
||||
maybe_log_info(
|
||||
"Loading most recent checkpoint from {}".format(ckpt_path)
|
||||
)
|
||||
log_info("Could not find most recent checkpoint.")
|
||||
return _load_checkpoint(
|
||||
session,
|
||||
ckpt_path,
|
||||
allow_drop_layers,
|
||||
allow_lr_init=allow_lr_init,
|
||||
silent=silent,
|
||||
)
|
||||
maybe_log_info("Could not find most recent checkpoint.")
|
||||
|
||||
# Initialize all variables
|
||||
elif method == "init":
|
||||
log_info("Initializing all variables.")
|
||||
maybe_log_info("Initializing all variables.")
|
||||
return _initialize_all_variables(session)
|
||||
|
||||
else:
|
||||
@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
|
||||
_load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
|
||||
|
||||
|
||||
def load_or_init_graph_for_training(session):
|
||||
def load_or_init_graph_for_training(session, silent: bool = False):
|
||||
"""
|
||||
Load variables from checkpoint or initialize variables. By default this will
|
||||
try to load the best validating checkpoint, then try the last checkpoint,
|
||||
@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
|
||||
methods = ["best", "last", "init"]
|
||||
else:
|
||||
methods = [Config.load_train]
|
||||
_load_or_init_impl(session, methods, allow_drop_layers=True)
|
||||
_load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
|
||||
|
||||
|
||||
def load_graph_for_evaluation(session):
|
||||
def load_graph_for_evaluation(session, silent: bool = False):
|
||||
"""
|
||||
Load variables from checkpoint. Initialization is not allowed. By default
|
||||
this will try to load the best validating checkpoint, then try the last
|
||||
@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
|
||||
methods = ["best", "last"]
|
||||
else:
|
||||
methods = [Config.load_evaluate]
|
||||
_load_or_init_impl(session, methods, allow_drop_layers=False)
|
||||
_load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)
|
||||
|
@ -217,10 +217,12 @@ class BaseSttConfig(Coqpit):
|
||||
if not is_remote_path(self.save_checkpoint_dir):
|
||||
os.makedirs(self.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
|
||||
if not os.path.exists(flags_file):
|
||||
with open_remote(flags_file, "w") as fout:
|
||||
json.dump(self.serialize(), fout, indent=2)
|
||||
|
||||
# Serialize alphabet alongside checkpoint
|
||||
if not os.path.exists(saved_checkpoint_alphabet_file):
|
||||
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
|
||||
fout.write(self.alphabet.SerializeText())
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user