Merge pull request #2032 from coqui-ai/transcription-scripts-docs

[transcribe] Fix multiprocessing hangs, clean-up target collection, write docs
This commit is contained in:
Reuben Morais 2021-12-03 16:46:48 +01:00 committed by GitHub
commit dbd38c3a89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 359 additions and 253 deletions

View File

@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
--n_hidden 100 \ --n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer" --scorer_path "data/smoke_test/pruned_lm.scorer"
#TODO: investigate why this is hanging in CI mkdir /tmp/transcribe_dir
#mkdir /tmp/transcribe_dir cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir time python -m coqui_stt_training.transcribe \
#time python -m coqui_stt_training.transcribe \ --src "/tmp/transcribe_dir/" \
# --src "/tmp/transcribe_dir/" \ --n_hidden 100 \
# --n_hidden 100 \ --scorer_path "data/smoke_test/pruned_lm.scorer"
# --scorer_path "data/smoke_test/pruned_lm.scorer"
# for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done
#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done

View File

@ -0,0 +1,79 @@
.. _checkpoint-inference:
Inference tools in the training package
=======================================
The standard deployment options for 🐸STT use highly optimized packages for deployment in real time, single-stream, low latency use cases. They take as input exported models which are also optimized, leading to further space and runtime gains. On the other hand, for the development of new features, it might be easier to use the training code for prototyping, which will allow you to test your changes without needing to recompile source code.
The training package contains options for performing inference directly from a checkpoint (and optionally a scorer), without needing to export a model. They are documented below, and all require a working :ref:`training environment <intro-training-docs>` before they can be used. Additionally, they require the Python ``webrtcvad`` package to be installed. This can either be done by specifying the "transcribe" extra when installing the training package, or by installing it manually in your training environment:
.. code-block:: bash
$ python -m pip install webrtcvad
Note that if your goal is to evaluate a trained model and obtain accuracy metrics, you should use the evaluation module: ``python -m coqui_stt_training.evaluate``, which calculates character and word error rates, from a properly formatted CSV file (specified with the ``--test_files`` flag. See the :ref:`training docs <intro-training-docs>` for more information).
Single file (aka one-shot) inference
------------------------------------
This is the simplest way to perform inference from a checkpoint. It takes a single WAV file as input with the ``--one_shot_infer`` flag, and outputs the predicted transcription for that file.
.. code-block:: bash
$ python -m coqui_stt_training.training_graph_inference --checkpoint_dir coqui-stt-1.0.0-checkpoint --scorer_path huge-vocabulary.scorer --n_hidden 2048 --one_shot_infer audio/2830-3980-0043.wav
I --alphabet_config_path not specified, but found an alphabet file alongside specified checkpoint (coqui-stt-1.0.0-checkpoint/alphabet.txt). Will use this alphabet file for this run.
I Loading best validating checkpoint from coqui-stt-1.0.0-checkpoint/best_dev-3663881
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel
I Loading variable from checkpoint: layer_1/bias
I Loading variable from checkpoint: layer_1/weights
I Loading variable from checkpoint: layer_2/bias
I Loading variable from checkpoint: layer_2/weights
I Loading variable from checkpoint: layer_3/bias
I Loading variable from checkpoint: layer_3/weights
I Loading variable from checkpoint: layer_5/bias
I Loading variable from checkpoint: layer_5/weights
I Loading variable from checkpoint: layer_6/bias
I Loading variable from checkpoint: layer_6/weights
experience proves this
Transcription of longer audio files
-----------------------------------
If you have longer audio files to transcribe, we offer a script which uses Voice Activity Detection (VAD) to split audio files in chunks and perform batched inference on said files. This can speed-up the transcription time significantly. The transcription script will also output the results in JSON format, allowing for easier programmatic usage of the outputs.
There are two main usage modes: transcribing a single file, or scanning a directory for audio files and transcribing all of them.
Transcribing a single file
^^^^^^^^^^^^^^^^^^^^^^^^^^
For a single audio file, you can specify it directly in the ``--src`` flag of the ``python -m coqui_stt_training.transcribe`` script:
.. code-block:: bash
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/2830-3980-0043.wav
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
Transcribing files: 100%|███████████████████████████████████| 1/1 [00:05<00:00, 5.40s/it]
$ cat audio/2830-3980-0043.tlog
[{"start": 150, "end": 1950, "transcript": "experience proves this"}]
Note the use of the ``--vad_aggressiveness`` flag above to control the behavior of the VAD process used to find silent sections of the audio file for splitting into chunks. You can run ``python -m coqui_stt_training.transcribe --help`` to see the full listing of options, the last ones are specific to the transcribe module.
By default the transcription results are put in a ``.tlog`` file next to the audio file that was transcribed, but you can specify a different location with the ``--dst path/to/some/file.tlog`` flag. This only works when trancribing a single file.
Scanning a directory for audio files
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Alternatively you can also specify a directory in the ``--src`` flag, in which case the directory will be scanned for any WAV files to be transcribed. If you specify ``--recursive true``, it'll scan the directory recursively, going into any subdirectories as well. Transcription results will be placed in a ``.tlog`` file alongside every audio file that was found by the process.
Multiple processes will be used to distribute the transcription work among available CPUs.
.. code-block:: bash
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/ --recursive true
Transcribing all files in --src directory audio
Transcribing files: 0%| | 0/3 [00:00<?, ?it/s]
[3]: "audio/8455-210777-0068.wav" -> "audio/8455-210777-0068.tlog"
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
[2]: "audio/4507-16021-0012.wav" -> "audio/4507-16021-0012.tlog"
Transcribing files: 100%|███████████████████████████████████| 3/3 [00:07<00:00, 2.50s/it]

View File

@ -18,6 +18,8 @@ You can deploy 🐸STT models either via a command-line client or a language bin
* :ref:`The command-line client <cli-usage>` * :ref:`The command-line client <cli-usage>`
* :ref:`The C API <c-usage>` * :ref:`The C API <c-usage>`
In some use cases, you might want to use the inference facilities built into the training code, for example for faster prototyping of new features. They are not production-ready, but because it's all Python code you won't need to recompile in order to test code changes, which can be much faster. See :ref:`checkpoint-inference` for more details.
.. _download-models: .. _download-models:
Download trained Coqui STT models Download trained Coqui STT models
@ -103,14 +105,6 @@ The following command assumes you :ref:`downloaded the pre-trained models <downl
See :ref:`the Python client <py-api-example>` for an example of how to use the package programatically. See :ref:`the Python client <py-api-example>` for an example of how to use the package programatically.
*GPUs will soon be supported:* If you have a supported NVIDIA GPU on Linux, you can install the GPU specific package as follows:
.. code-block::
(coqui-stt-venv)$ python -m pip install -U pip && python -m pip install stt-gpu
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
.. _nodejs-usage: .. _nodejs-usage:
Using the Node.JS / Electron.JS package Using the Node.JS / Electron.JS package
@ -132,14 +126,6 @@ Please note that as of now, we support:
TypeScript support is also provided. TypeScript support is also provided.
If you're using Linux and have a supported NVIDIA GPU, you can install the GPU specific package as follows:
.. code-block:: bash
npm install stt-gpu
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
See the :ref:`TypeScript client <js-api-example>` for an example of how to use the bindings programatically. See the :ref:`TypeScript client <js-api-example>` for an example of how to use the bindings programatically.
.. _android-usage: .. _android-usage:
@ -232,11 +218,6 @@ Running ``stt`` may require runtime dependencies. Please refer to your system's
* ``libpthread`` - Reported dependency on Linux. On Ubuntu, ``libpthread`` is part of the ``libpthread-stubs0-dev`` package * ``libpthread`` - Reported dependency on Linux. On Ubuntu, ``libpthread`` is part of the ``libpthread-stubs0-dev`` package
* ``Redistribuable Visual C++ 2015 Update 3 (64-bits)`` - Reported dependency on Windows. Please `download from Microsoft <https://www.microsoft.com/download/details.aspx?id=53587>`_ * ``Redistribuable Visual C++ 2015 Update 3 (64-bits)`` - Reported dependency on Windows. Please `download from Microsoft <https://www.microsoft.com/download/details.aspx?id=53587>`_
CUDA Dependency
^^^^^^^^^^^^^^^
The GPU capable builds (Python, NodeJS, C++, etc) depend on CUDA 10.1 and CuDNN v7.6.
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1

View File

@ -27,3 +27,5 @@ This document contains more advanced topics with regard to training models with
PARALLLEL_OPTIMIZATION PARALLLEL_OPTIMIZATION
DATASET_IMPORTERS DATASET_IMPORTERS
Checkpoint-Inference

View File

@ -1,3 +1,3 @@
#flags pre { #flags pre, #inference-tools-in-the-training-package pre {
white-space: pre-wrap; white-space: pre-wrap;
} }

View File

@ -66,7 +66,7 @@ def main():
], ],
package_dir={"": "training"}, package_dir={"": "training"},
packages=find_packages(where="training"), packages=find_packages(where="training"),
python_requires=">=3.5, <4", python_requires=">=3.5, <3.8",
install_requires=install_requires, install_requires=install_requires,
include_package_data=True, include_package_data=True,
extras_require={ extras_require={

View File

@ -1,48 +1,41 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# This script is structured in a weird way, with delayed imports. This is due
# to the use of multiprocessing. TensorFlow cannot handle forking, and even with
# the spawn strategy set to "spawn" it still leads to weird problems, so we
# restructure the code so that TensorFlow is only imported inside the child
# processes.
import os
import sys
import glob import glob
import itertools import itertools
import json import json
import multiprocessing import multiprocessing
from multiprocessing import Pool, cpu_count import os
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from multiprocessing import Pool, Lock, cpu_count
from pathlib import Path
from typing import Optional, List, Tuple
LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
# Hide GPUs to prevent issues with child processes trying to use the same GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import (
BaseSttConfig,
Config,
initialize_globals_from_instance,
)
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.helpers import check_ctcdecoder_version
from tqdm import tqdm from tqdm import tqdm
def fail(message, code=1): def transcribe_file(audio_path: Path, tlog_path: Path):
print(f"E {message}")
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
log_level_index = (
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
)
desired_log_level = (
sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import Config
from coqui_stt_training.util.feeding import split_audio_file
initialize_transcribe_config() initialize_transcribe_config()
scorer = None scorer = None
@ -56,7 +49,7 @@ def transcribe_file(audio_path, tlog_path):
except NotImplementedError: except NotImplementedError:
num_processes = 1 num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path: with AudioFile(str(audio_path), as_path=True) as wav_path:
data_set = split_audio_file( data_set = split_audio_file(
wav_path, wav_path,
batch_size=Config.batch_size, batch_size=Config.batch_size,
@ -73,7 +66,9 @@ def transcribe_file(audio_path, tlog_path):
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step() tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session: with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session) # Load checkpoint in a mutex way to avoid hangs in TensorFlow code
with lock:
load_graph_for_evaluation(session, silent=True)
transcripts = [] transcripts = []
while True: while True:
try: try:
@ -101,21 +96,28 @@ def transcribe_file(audio_path, tlog_path):
json.dump(transcripts, tlog_file, default=float) json.dump(transcripts, tlog_file, default=float)
def init_fn(l):
global lock
lock = l
def step_function(job): def step_function(job):
""" Wrap transcribe_file to unpack arguments from a single tuple """ """Wrap transcribe_file to unpack arguments from a single tuple"""
idx, src, dst = job idx, src, dst = job
transcribe_file(src, dst) transcribe_file(src, dst)
return idx, src, dst return idx, src, dst
def transcribe_many(src_paths, dst_paths): def transcribe_many(src_paths, dst_paths):
from coqui_stt_training.util.config import Config, log_progress
pool = Pool(processes=min(cpu_count(), len(src_paths)))
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])] # Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths) jobs = zip(itertools.count(), src_paths, dst_paths)
lock = Lock()
with Pool(
processes=min(cpu_count(), len(src_paths)),
initializer=init_fn,
initargs=(lock,),
) as pool:
process_iterable = tqdm( process_iterable = tqdm(
pool.imap_unordered(step_function, jobs), pool.imap_unordered(step_function, jobs),
desc="Transcribing files", desc="Transcribing files",
@ -123,106 +125,117 @@ def transcribe_many(src_paths, dst_paths):
disable=not Config.show_progressbar, disable=not Config.show_progressbar,
) )
cwd = Path.cwd()
for result in process_iterable: for result in process_iterable:
idx, src, dst = result idx, src, dst = result
log_progress( # Revert to relative if possible to make logs more concise
f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"' # if path is not relative to cwd, use the absolute path
) # (Path.is_relative_to is only available in Python >=3.9)
try:
src = src.relative_to(cwd)
except ValueError:
pass
try:
dst = dst.relative_to(cwd)
except ValueError:
pass
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
def transcribe_one(src_path, dst_path): def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
transcribe_file(src_path, dst_path) """Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
print(f'I Transcribed file "{src_path}" to "{dst_path}"') extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results. For .catalog file inputs, these are taken from the
"audio" and "tlog" properties of the entries in the catalog, with any relative
paths being absolutized relative to the directory containing the .catalog file.
"""
assert catalog_file_path.suffix == ".catalog"
catalog_dir = catalog_file_path.parent
with open(catalog_file_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
def resolve(base_path, spec_path): def resolve(spec_path: Optional[Path]):
if spec_path is None: if spec_path is None:
return None return None
if not os.path.isabs(spec_path): if not spec_path.is_absolute():
spec_path = os.path.join(base_path, spec_path) spec_path = catalog_dir / spec_path
return spec_path return spec_path
catalog_entries = [
(resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
]
for src, dst in catalog_entries:
if not Config.force and dst.is_file():
raise RuntimeError(
f"Destination file already exists: {dst}. Use --force for overwriting."
)
if not dst.parent.is_dir():
dst.parent.mkdir(parents=True)
src_paths, dst_paths = zip(*catalog_entries)
return src_paths, dst_paths
def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
"""Given a directory `src_dir` containing audio files, scan it for audio files
and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results.
"""
glob_method = src_dir.rglob if recursive else src_dir.glob
src_paths = list(glob_method("*.wav"))
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
return src_paths, dst_paths
def transcribe(): def transcribe():
from coqui_stt_training.util.config import Config
initialize_transcribe_config() initialize_transcribe_config()
if not Config.src or not os.path.exists(Config.src): src_path = Path(Config.src).resolve()
if not Config.src or not src_path.exists():
# path not given or non-existant # path not given or non-existant
fail( raise RuntimeError(
"You have to specify which file or catalog to transcribe via the --src flag." "You have to specify which audio file, catalog file or directory to "
"transcribe with the --src flag."
) )
else: else:
# path given and exists # path given and exists
src_path = os.path.abspath(Config.src) if src_path.is_file():
if os.path.isfile(src_path): if src_path.suffix != ".catalog":
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not Config.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
# Transcribe one file # Transcribe one file
dst_path = ( dst_path = (
os.path.abspath(Config.dst) Path(Config.dst).resolve()
if Config.dst if Config.dst
else os.path.splitext(src_path)[0] + ".tlog" else src_path.with_suffix(".tlog")
)
if os.path.isfile(dst_path):
if Config.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if Config.recursive:
wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
else:
wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
def initialize_transcribe_config():
from coqui_stt_training.util.config import (
BaseSttConfig,
initialize_globals_from_instance,
) )
@dataclass if dst_path.is_file() and not Config.force:
class TranscribeConfig(BaseSttConfig): raise RuntimeError(
f'Destination file "{dst_path}" already exists - use '
"--force for overwriting."
)
if not dst_path.parent.is_dir():
raise RuntimeError("Missing destination directory")
transcribe_many([src_path], [dst_path])
else:
# Transcribe from .catalog input
src_paths, dst_paths = get_tasks_from_catalog(src_path)
transcribe_many(src_paths, dst_paths)
elif src_path.is_dir():
# Transcribe from dir input
print(f"Transcribing all files in --src directory {src_path}")
src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
transcribe_many(src_paths, dst_paths)
@dataclass
class TranscribeConfig(BaseSttConfig):
src: str = field( src: str = field(
default="", default="",
metadata=dict( metadata=dict(
@ -230,7 +243,7 @@ def initialize_transcribe_config():
"Catalog files should be formatted from DSAlign. A directory " "Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, " "will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the " "transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of ".wav".' 'source filenames with suffix ".tlog" instead of the original.'
), ),
) )
@ -292,12 +305,17 @@ def initialize_transcribe_config():
super().__post_init__() super().__post_init__()
def initialize_transcribe_config():
config = TranscribeConfig.init_from_argparse(arg_prefix="") config = TranscribeConfig.init_from_argparse(arg_prefix="")
initialize_globals_from_instance(config) initialize_globals_from_instance(config)
def main(): def main():
from coqui_stt_training.util.helpers import check_ctcdecoder_version assert not tf.test.is_gpu_available()
# Set start method to spawn on all platforms to avoid issues with TensorFlow
multiprocessing.set_start_method("spawn")
try: try:
import webrtcvad import webrtcvad

View File

@ -7,7 +7,13 @@ import tensorflow as tf
from .config import Config, log_error, log_info, log_warn from .config import Config, log_error, log_info, log_warn
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True): def _load_checkpoint(
session,
checkpoint_path,
allow_drop_layers,
allow_lr_init=True,
silent: bool = False,
):
# Load the checkpoint and put all variables into loading list # Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then # we will exclude variables we do not wish to load and then
# we will initialize them instead # we will initialize them instead
@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
init_vars.add(v) init_vars.add(v)
load_vars -= init_vars load_vars -= init_vars
log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}") def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for v in sorted(load_vars, key=lambda v: v.op.name): for v in sorted(load_vars, key=lambda v: v.op.name):
log_info(f"Getting tensor from variable: {v.op.name}") maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
tensor = ckpt.get_tensor(v.op.name) v.load(ckpt.get_tensor(v.op.name), session=session)
log_info(f"Loading tensor from checkpoint: {v.op.name}")
v.load(tensor, session=session)
for v in sorted(init_vars, key=lambda v: v.op.name): for v in sorted(init_vars, key=lambda v: v.op.name):
log_info("Initializing variable: %s" % (v.op.name)) maybe_log_info("Initializing variable: %s" % (v.op.name))
session.run(v.initializer) session.run(v.initializer)
@ -102,31 +109,49 @@ def _initialize_all_variables(session):
session.run(v.initializer) session.run(v.initializer)
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True): def _load_or_init_impl(
session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
):
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for method in method_order: for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint' # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == "best": if method == "best":
ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint") ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
if ckpt_path: if ckpt_path:
log_info("Loading best validating checkpoint from {}".format(ckpt_path)) maybe_log_info(
return _load_checkpoint( "Loading best validating checkpoint from {}".format(ckpt_path)
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
) )
log_info("Could not find best validating checkpoint.") return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find best validating checkpoint.")
# Load most recent checkpoint, saved in checkpoint file 'checkpoint' # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
elif method == "last": elif method == "last":
ckpt_path = _checkpoint_path_or_none("checkpoint") ckpt_path = _checkpoint_path_or_none("checkpoint")
if ckpt_path: if ckpt_path:
log_info("Loading most recent checkpoint from {}".format(ckpt_path)) maybe_log_info(
return _load_checkpoint( "Loading most recent checkpoint from {}".format(ckpt_path)
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
) )
log_info("Could not find most recent checkpoint.") return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find most recent checkpoint.")
# Initialize all variables # Initialize all variables
elif method == "init": elif method == "init":
log_info("Initializing all variables.") maybe_log_info("Initializing all variables.")
return _initialize_all_variables(session) return _initialize_all_variables(session)
else: else:
@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
_load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False) _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
def load_or_init_graph_for_training(session): def load_or_init_graph_for_training(session, silent: bool = False):
""" """
Load variables from checkpoint or initialize variables. By default this will Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint, try to load the best validating checkpoint, then try the last checkpoint,
@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
methods = ["best", "last", "init"] methods = ["best", "last", "init"]
else: else:
methods = [Config.load_train] methods = [Config.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=True) _load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
def load_graph_for_evaluation(session): def load_graph_for_evaluation(session, silent: bool = False):
""" """
Load variables from checkpoint. Initialization is not allowed. By default Load variables from checkpoint. Initialization is not allowed. By default
this will try to load the best validating checkpoint, then try the last this will try to load the best validating checkpoint, then try the last
@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
methods = ["best", "last"] methods = ["best", "last"]
else: else:
methods = [Config.load_evaluate] methods = [Config.load_evaluate]
_load_or_init_impl(session, methods, allow_drop_layers=False) _load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)

View File

@ -217,10 +217,12 @@ class BaseSttConfig(Coqpit):
if not is_remote_path(self.save_checkpoint_dir): if not is_remote_path(self.save_checkpoint_dir):
os.makedirs(self.save_checkpoint_dir, exist_ok=True) os.makedirs(self.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt") flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
if not os.path.exists(flags_file):
with open_remote(flags_file, "w") as fout: with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2) json.dump(self.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint # Serialize alphabet alongside checkpoint
if not os.path.exists(saved_checkpoint_alphabet_file):
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout: with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText()) fout.write(self.alphabet.SerializeText())