Merge pull request #2032 from coqui-ai/transcription-scripts-docs

[transcribe] Fix multiprocessing hangs, clean-up target collection, write docs
This commit is contained in:
Reuben Morais 2021-12-03 16:46:48 +01:00 committed by GitHub
commit dbd38c3a89
9 changed files with 359 additions and 253 deletions

View File

@@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
#TODO: investigate why this is hanging in CI
#mkdir /tmp/transcribe_dir
#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
#time python -m coqui_stt_training.transcribe \
# --src "/tmp/transcribe_dir/" \
# --n_hidden 100 \
# --scorer_path "data/smoke_test/pruned_lm.scorer"
#
#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
mkdir /tmp/transcribe_dir
cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
time python -m coqui_stt_training.transcribe \
--src "/tmp/transcribe_dir/" \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done

View File

@@ -0,0 +1,79 @@
.. _checkpoint-inference:
Inference tools in the training package
=======================================
The standard deployment options for 🐸STT use highly optimized packages targeting real-time, single-stream, low-latency use cases. They take as input exported models, which are themselves optimized, leading to further space and runtime gains. On the other hand, when developing new features it can be easier to prototype with the training code, which lets you test your changes without having to recompile any source code.
The training package contains options for performing inference directly from a checkpoint (and optionally a scorer), without needing to export a model. They are documented below, and all of them require a working :ref:`training environment <intro-training-docs>` before they can be used. Additionally, they require the Python ``webrtcvad`` package to be installed. This can be done either by specifying the ``transcribe`` extra when installing the training package, or by installing the package manually in your training environment:
.. code-block:: bash
$ python -m pip install webrtcvad
Note that if your goal is to evaluate a trained model and obtain accuracy metrics, you should use the evaluation module instead: ``python -m coqui_stt_training.evaluate``. It calculates character and word error rates from a properly formatted CSV file (specified with the ``--test_files`` flag; see the :ref:`training docs <intro-training-docs>` for more information).
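For reference, an evaluation run might look like the following (the checkpoint directory, scorer and CSV file are illustrative placeholders):
.. code-block:: bash
$ python -m coqui_stt_training.evaluate --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --test_files data/test.csv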
Single file (aka one-shot) inference
------------------------------------
This is the simplest way to perform inference from a checkpoint. It takes a single WAV file as input with the ``--one_shot_infer`` flag, and outputs the predicted transcription for that file.
.. code-block:: bash
$ python -m coqui_stt_training.training_graph_inference --checkpoint_dir coqui-stt-1.0.0-checkpoint --scorer_path huge-vocabulary.scorer --n_hidden 2048 --one_shot_infer audio/2830-3980-0043.wav
I --alphabet_config_path not specified, but found an alphabet file alongside specified checkpoint (coqui-stt-1.0.0-checkpoint/alphabet.txt). Will use this alphabet file for this run.
I Loading best validating checkpoint from coqui-stt-1.0.0-checkpoint/best_dev-3663881
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel
I Loading variable from checkpoint: layer_1/bias
I Loading variable from checkpoint: layer_1/weights
I Loading variable from checkpoint: layer_2/bias
I Loading variable from checkpoint: layer_2/weights
I Loading variable from checkpoint: layer_3/bias
I Loading variable from checkpoint: layer_3/weights
I Loading variable from checkpoint: layer_5/bias
I Loading variable from checkpoint: layer_5/weights
I Loading variable from checkpoint: layer_6/bias
I Loading variable from checkpoint: layer_6/weights
experience proves this
Transcription of longer audio files
-----------------------------------
If you have longer audio files to transcribe, we offer a script which uses Voice Activity Detection (VAD) to split audio files into chunks and perform batched inference on them. This can speed up transcription significantly. The transcription script will also output the results in JSON format, allowing for easier programmatic use of the outputs.
There are two main usage modes: transcribing a single file, or scanning a directory for audio files and transcribing all of them.
Transcribing a single file
^^^^^^^^^^^^^^^^^^^^^^^^^^
For a single audio file, you can specify it directly in the ``--src`` flag of the ``python -m coqui_stt_training.transcribe`` script:
.. code-block:: bash
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/2830-3980-0043.wav
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
Transcribing files: 100%|███████████████████████████████████| 1/1 [00:05<00:00, 5.40s/it]
$ cat audio/2830-3980-0043.tlog
[{"start": 150, "end": 1950, "transcript": "experience proves this"}]
Note the use of the ``--vad_aggressiveness`` flag above to control the behavior of the VAD process used to find silent sections of the audio file for splitting into chunks. You can run ``python -m coqui_stt_training.transcribe --help`` to see the full listing of options; the ones listed last are specific to the transcribe module.
By default the transcription results are placed in a ``.tlog`` file next to the audio file that was transcribed, but you can specify a different location with the ``--dst path/to/some/file.tlog`` flag. This only works when transcribing a single file.
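For example, a hypothetical invocation redirecting the output (same checkpoint and scorer as above) could look like this:
.. code-block:: bash
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --src audio/2830-3980-0043.wav --dst /tmp/out.tlog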
Scanning a directory for audio files
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Alternatively, you can specify a directory with the ``--src`` flag, in which case the directory will be scanned for WAV files to transcribe. If you specify ``--recursive true``, the directory will be scanned recursively, going into any subdirectories as well. Transcription results will be placed in a ``.tlog`` file alongside every audio file found by the process.
Multiple processes will be used to distribute the transcription work among available CPUs.
.. code-block:: bash
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/ --recursive true
Transcribing all files in --src directory audio
Transcribing files: 0%| | 0/3 [00:00<?, ?it/s]
[3]: "audio/8455-210777-0068.wav" -> "audio/8455-210777-0068.tlog"
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
[2]: "audio/4507-16021-0012.wav" -> "audio/4507-16021-0012.tlog"
Transcribing files: 100%|███████████████████████████████████| 3/3 [00:07<00:00, 2.50s/it]

View File

@@ -18,6 +18,8 @@ You can deploy 🐸STT models either via a command-line client or a language bin
* :ref:`The command-line client <cli-usage>`
* :ref:`The C API <c-usage>`
In some use cases, you might want to use the inference facilities built into the training code, for example for faster prototyping of new features. They are not production-ready, but because it's all Python code, you can test code changes without recompiling anything, which makes iteration much faster. See :ref:`checkpoint-inference` for more details.
.. _download-models:
Download trained Coqui STT models
@@ -103,14 +105,6 @@ The following command assumes you :ref:`downloaded the pre-trained models <downl
See :ref:`the Python client <py-api-example>` for an example of how to use the package programmatically.
*GPUs will soon be supported:* If you have a supported NVIDIA GPU on Linux, you can install the GPU specific package as follows:
.. code-block::
(coqui-stt-venv)$ python -m pip install -U pip && python -m pip install stt-gpu
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
.. _nodejs-usage:
Using the Node.JS / Electron.JS package
@@ -132,14 +126,6 @@ Please note that as of now, we support:
TypeScript support is also provided.
If you're using Linux and have a supported NVIDIA GPU, you can install the GPU specific package as follows:
.. code-block:: bash
npm install stt-gpu
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
See the :ref:`TypeScript client <js-api-example>` for an example of how to use the bindings programmatically.
.. _android-usage:
@@ -232,11 +218,6 @@ Running ``stt`` may require runtime dependencies. Please refer to your system's
* ``libpthread`` - Reported dependency on Linux. On Ubuntu, ``libpthread`` is part of the ``libpthread-stubs0-dev`` package
* ``Redistributable Visual C++ 2015 Update 3 (64-bits)`` - Reported dependency on Windows. Please `download from Microsoft <https://www.microsoft.com/download/details.aspx?id=53587>`_
CUDA Dependency
^^^^^^^^^^^^^^^
The GPU-capable builds (Python, NodeJS, C++, etc.) depend on CUDA 10.1 and cuDNN v7.6.
.. toctree::
:maxdepth: 1

View File

@@ -27,3 +27,5 @@ This document contains more advanced topics with regard to training models with
PARALLLEL_OPTIMIZATION
DATASET_IMPORTERS
Checkpoint-Inference

View File

@@ -1,3 +1,3 @@
#flags pre {
#flags pre, #inference-tools-in-the-training-package pre {
white-space: pre-wrap;
}

View File

@@ -66,7 +66,7 @@ def main():
],
package_dir={"": "training"},
packages=find_packages(where="training"),
python_requires=">=3.5, <4",
python_requires=">=3.5, <3.8",
install_requires=install_requires,
include_package_data=True,
extras_require={

View File

@@ -1,48 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This script is structured in a weird way, with delayed imports. This is due
# to the use of multiprocessing. TensorFlow cannot handle forking, and even
# with the start method set to "spawn" it still leads to weird problems, so we
# restructure the code so that TensorFlow is only imported inside the child
# processes.
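# An illustrative sketch (not part of this file) of the pattern described
# above: the parent sets the "spawn" start method and defers the TensorFlow
# import so that only the child processes ever load it.
#
#     import multiprocessing
#
#     def worker(wav_path):
#         import tensorflow as tf  # deferred: imported only in the child
#         ...  # build the graph and run inference here
#
#     if __name__ == "__main__":
#         multiprocessing.set_start_method("spawn")
#         with multiprocessing.Pool(processes=2) as pool:
#             pool.map(worker, ["a.wav", "b.wav"])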
import os
import sys
import glob
import itertools
import json
import multiprocessing
from multiprocessing import Pool, cpu_count
import os
import sys
from dataclasses import dataclass, field
from multiprocessing import Pool, Lock, cpu_count
from pathlib import Path
from typing import Optional, List, Tuple
LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
# Hide GPUs to prevent issues with child processes trying to use the same GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import (
BaseSttConfig,
Config,
initialize_globals_from_instance,
)
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.helpers import check_ctcdecoder_version
from tqdm import tqdm
def fail(message, code=1):
print(f"E {message}")
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
log_level_index = (
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
)
desired_log_level = (
sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import Config
from coqui_stt_training.util.feeding import split_audio_file
def transcribe_file(audio_path: Path, tlog_path: Path):
initialize_transcribe_config()
scorer = None
@@ -56,7 +49,7 @@ def transcribe_file(audio_path, tlog_path):
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
with AudioFile(str(audio_path), as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=Config.batch_size,
@@ -73,7 +66,9 @@ def transcribe_file(audio_path, tlog_path):
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
# Load checkpoint in a mutex way to avoid hangs in TensorFlow code
with lock:
load_graph_for_evaluation(session, silent=True)
transcripts = []
while True:
try:
@@ -101,203 +96,226 @@ def transcribe_file(audio_path, tlog_path):
json.dump(transcripts, tlog_file, default=float)
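# Pool worker initializer: stores the shared Lock in a module-level global so
# that transcribe_file can serialize checkpoint loading across processes.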
def init_fn(l):
global lock
lock = l
def step_function(job):
""" Wrap transcribe_file to unpack arguments from a single tuple """
"""Wrap transcribe_file to unpack arguments from a single tuple"""
idx, src, dst = job
transcribe_file(src, dst)
return idx, src, dst
def transcribe_many(src_paths, dst_paths):
from coqui_stt_training.util.config import Config, log_progress
pool = Pool(processes=min(cpu_count(), len(src_paths)))
# Create list of items to be processed: [(i, src_paths[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths)
process_iterable = tqdm(
pool.imap_unordered(step_function, jobs),
desc="Transcribing files",
total=len(src_paths),
disable=not Config.show_progressbar,
)
for result in process_iterable:
idx, src, dst = result
log_progress(
f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
lock = Lock()
with Pool(
processes=min(cpu_count(), len(src_paths)),
initializer=init_fn,
initargs=(lock,),
) as pool:
process_iterable = tqdm(
pool.imap_unordered(step_function, jobs),
desc="Transcribing files",
total=len(src_paths),
disable=not Config.show_progressbar,
)
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
print(f'I Transcribed file "{src_path}" to "{dst_path}"')
cwd = Path.cwd()
for result in process_iterable:
idx, src, dst = result
# Revert to relative if possible to make logs more concise
# if path is not relative to cwd, use the absolute path
# (Path.is_relative_to is only available in Python >=3.9)
try:
src = src.relative_to(cwd)
except ValueError:
pass
try:
dst = dst.relative_to(cwd)
except ValueError:
pass
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not os.path.isabs(spec_path):
spec_path = os.path.join(base_path, spec_path)
return spec_path
def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
"""Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
extract transcription tasks, i.e. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results. For .catalog file inputs, these are taken from the
"audio" and "tlog" properties of the entries in the catalog, with any relative
paths being absolutized relative to the directory containing the .catalog file.
"""
assert catalog_file_path.suffix == ".catalog"
catalog_dir = catalog_file_path.parent
with open(catalog_file_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
def resolve(spec_path: Optional[Path]):
if spec_path is None:
return None
if not spec_path.is_absolute():
spec_path = catalog_dir / spec_path
return spec_path
catalog_entries = [
(resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
]
for src, dst in catalog_entries:
if not Config.force and dst.is_file():
raise RuntimeError(
f"Destination file already exists: {dst}. Use --force for overwriting."
)
if not dst.parent.is_dir():
dst.parent.mkdir(parents=True)
src_paths, dst_paths = zip(*catalog_entries)
return src_paths, dst_paths
def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
"""Given a directory `src_dir` containing audio files, scan it for audio files
and return transcription tasks, i.e. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results.
"""
glob_method = src_dir.rglob if recursive else src_dir.glob
src_paths = list(glob_method("*.wav"))
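# e.g. "clips/one.wav" -> "clips/one.tlog", placed next to the source file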
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
return src_paths, dst_paths
def transcribe():
from coqui_stt_training.util.config import Config
initialize_transcribe_config()
if not Config.src or not os.path.exists(Config.src):
src_path = Path(Config.src).resolve()
if not Config.src or not src_path.exists():
# path not given or non-existent
fail(
"You have to specify which file or catalog to transcribe via the --src flag."
raise RuntimeError(
"You have to specify which audio file, catalog file or directory to "
"transcribe with the --src flag."
)
else:
# path given and exists
src_path = os.path.abspath(Config.src)
if os.path.isfile(src_path):
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not Config.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
if src_path.is_file():
if src_path.suffix != ".catalog":
# Transcribe one file
dst_path = (
os.path.abspath(Config.dst)
Path(Config.dst).resolve()
if Config.dst
else os.path.splitext(src_path)[0] + ".tlog"
else src_path.with_suffix(".tlog")
)
if os.path.isfile(dst_path):
if Config.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if Config.recursive:
wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
if dst_path.is_file() and not Config.force:
raise RuntimeError(
f'Destination file "{dst_path}" already exists - use '
"--force for overwriting."
)
if not dst_path.parent.is_dir():
raise RuntimeError("Missing destination directory")
transcribe_many([src_path], [dst_path])
else:
wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
# Transcribe from .catalog input
src_paths, dst_paths = get_tasks_from_catalog(src_path)
transcribe_many(src_paths, dst_paths)
elif src_path.is_dir():
# Transcribe from dir input
print(f"Transcribing all files in --src directory {src_path}")
src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
transcribe_many(src_paths, dst_paths)
@dataclass
class TranscribeConfig(BaseSttConfig):
src: str = field(
default="",
metadata=dict(
help="Source path to an audio file or directory or catalog file. "
"Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of the original.'
),
)
dst: str = field(
default="",
metadata=dict(
help="path for writing the transcription log or logs (.tlog). "
"If --src is a directory, this one also has to be a directory "
"and the required sub-dir tree of --src will get replicated."
),
)
recursive: bool = field(
default=False,
metadata=dict(help="scan source directory recursively for audio"),
)
force: bool = field(
default=False,
metadata=dict(
help="Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)"
),
)
vad_aggressiveness: int = field(
default=3,
metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
)
batch_size: int = field(
default=40,
metadata=dict(help="Default batch size"),
)
outlier_duration_ms: int = field(
default=10000,
metadata=dict(
help="Duration in ms after which samples are considered outliers"
),
)
outlier_batch_size: int = field(
default=1,
metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
)
def __post_init__(self):
if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
raise RuntimeError(
"Parameter --dst not supported if --src points to a catalog"
)
if os.path.isdir(self.src):
if self.dst:
raise RuntimeError(
"Destination path not supported for batch decoding jobs."
)
super().__post_init__()
def initialize_transcribe_config():
from coqui_stt_training.util.config import (
BaseSttConfig,
initialize_globals_from_instance,
)
@dataclass
class TranscribeConfig(BaseSttConfig):
src: str = field(
default="",
metadata=dict(
help="Source path to an audio file or directory or catalog file. "
"Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of ".wav".'
),
)
dst: str = field(
default="",
metadata=dict(
help="path for writing the transcription log or logs (.tlog). "
"If --src is a directory, this one also has to be a directory "
"and the required sub-dir tree of --src will get replicated."
),
)
recursive: bool = field(
default=False,
metadata=dict(help="scan source directory recursively for audio"),
)
force: bool = field(
default=False,
metadata=dict(
help="Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)"
),
)
vad_aggressiveness: int = field(
default=3,
metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
)
batch_size: int = field(
default=40,
metadata=dict(help="Default batch size"),
)
outlier_duration_ms: int = field(
default=10000,
metadata=dict(
help="Duration in ms after which samples are considered outliers"
),
)
outlier_batch_size: int = field(
default=1,
metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
)
def __post_init__(self):
if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
raise RuntimeError(
"Parameter --dst not supported if --src points to a catalog"
)
if os.path.isdir(self.src):
if self.dst:
raise RuntimeError(
"Destination path not supported for batch decoding jobs."
)
super().__post_init__()
config = TranscribeConfig.init_from_argparse(arg_prefix="")
initialize_globals_from_instance(config)
def main():
from coqui_stt_training.util.helpers import check_ctcdecoder_version
assert not tf.test.is_gpu_available()
# Set start method to spawn on all platforms to avoid issues with TensorFlow
multiprocessing.set_start_method("spawn")
try:
import webrtcvad

View File

@@ -7,7 +7,13 @@ import tensorflow as tf
from .config import Config, log_error, log_info, log_warn
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
def _load_checkpoint(
session,
checkpoint_path,
allow_drop_layers,
allow_lr_init=True,
silent: bool = False,
):
# Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then
# we will initialize them instead
@@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
init_vars.add(v)
load_vars -= init_vars
log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for v in sorted(load_vars, key=lambda v: v.op.name):
log_info(f"Getting tensor from variable: {v.op.name}")
tensor = ckpt.get_tensor(v.op.name)
log_info(f"Loading tensor from checkpoint: {v.op.name}")
v.load(tensor, session=session)
maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
v.load(ckpt.get_tensor(v.op.name), session=session)
for v in sorted(init_vars, key=lambda v: v.op.name):
log_info("Initializing variable: %s" % (v.op.name))
maybe_log_info("Initializing variable: %s" % (v.op.name))
session.run(v.initializer)
@@ -102,31 +109,49 @@ def _initialize_all_variables(session):
session.run(v.initializer)
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
def _load_or_init_impl(
session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
):
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == "best":
ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
if ckpt_path:
log_info("Loading best validating checkpoint from {}".format(ckpt_path))
return _load_checkpoint(
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
maybe_log_info(
"Loading best validating checkpoint from {}".format(ckpt_path)
)
log_info("Could not find best validating checkpoint.")
return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find best validating checkpoint.")
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
elif method == "last":
ckpt_path = _checkpoint_path_or_none("checkpoint")
if ckpt_path:
log_info("Loading most recent checkpoint from {}".format(ckpt_path))
return _load_checkpoint(
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
maybe_log_info(
"Loading most recent checkpoint from {}".format(ckpt_path)
)
log_info("Could not find most recent checkpoint.")
return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find most recent checkpoint.")
# Initialize all variables
elif method == "init":
log_info("Initializing all variables.")
maybe_log_info("Initializing all variables.")
return _initialize_all_variables(session)
else:
@@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
_load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
def load_or_init_graph_for_training(session):
def load_or_init_graph_for_training(session, silent: bool = False):
"""
Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint,
@@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
methods = ["best", "last", "init"]
else:
methods = [Config.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=True)
_load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
def load_graph_for_evaluation(session):
def load_graph_for_evaluation(session, silent: bool = False):
"""
Load variables from checkpoint. Initialization is not allowed. By default
this will try to load the best validating checkpoint, then try the last
@@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
methods = ["best", "last"]
else:
methods = [Config.load_evaluate]
_load_or_init_impl(session, methods, allow_drop_layers=False)
_load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)

View File

@@ -217,12 +217,14 @@ class BaseSttConfig(Coqpit):
if not is_remote_path(self.save_checkpoint_dir):
os.makedirs(self.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2)
if not os.path.exists(flags_file):
with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText())
if not os.path.exists(saved_checkpoint_alphabet_file):
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText())
# Geometric Constants
# ===================