[transcribe] Fix multiprocessing hangs, clean-up target collection
commit d90bb60506
parent 5cefd7069c
@@ -80,12 +80,11 @@ time python -m coqui_stt_training.transcribe \
     --n_hidden 100 \
     --scorer_path "data/smoke_test/pruned_lm.scorer"
 
-#TODO: investigate why this is hanging in CI
-#mkdir /tmp/transcribe_dir
-#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
-#time python -m coqui_stt_training.transcribe \
-# --src "/tmp/transcribe_dir/" \
-# --n_hidden 100 \
-# --scorer_path "data/smoke_test/pruned_lm.scorer"
-#
-#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done
+mkdir /tmp/transcribe_dir
+cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
+time python -m coqui_stt_training.transcribe \
+    --src "/tmp/transcribe_dir/" \
+    --n_hidden 100 \
+    --scorer_path "data/smoke_test/pruned_lm.scorer"
+
+for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done
@@ -7,25 +7,22 @@
 # restructure the code so that TensorFlow is only imported inside the child
 # processes.
 
-import os
-import sys
 import glob
 import itertools
 import json
 import multiprocessing
-from multiprocessing import Pool, cpu_count
+import os
+import sys
 from dataclasses import dataclass, field
+from multiprocessing import Pool, Lock, cpu_count
+from pathlib import Path
+from typing import Optional, List, Tuple
 
 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from tqdm import tqdm
 
 
-def fail(message, code=1):
-    print(f"E {message}")
-    sys.exit(code)
-
-
-def transcribe_file(audio_path, tlog_path):
+def transcribe_file(audio_path: Path, tlog_path: Path):
     log_level_index = (
         sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
     )
@@ -56,7 +53,7 @@ def transcribe_file(audio_path, tlog_path):
     except NotImplementedError:
         num_processes = 1
 
-    with AudioFile(audio_path, as_path=True) as wav_path:
+    with AudioFile(str(audio_path), as_path=True) as wav_path:
         data_set = split_audio_file(
             wav_path,
             batch_size=Config.batch_size,
@@ -73,7 +70,9 @@ def transcribe_file(audio_path, tlog_path):
         transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
         tf.train.get_or_create_global_step()
         with tf.Session(config=Config.session_config) as session:
-            load_graph_for_evaluation(session)
+            # Load checkpoint in a mutex way to avoid hangs in TensorFlow code
+            with lock:
+                load_graph_for_evaluation(session, silent=True)
             transcripts = []
             while True:
                 try:
@@ -101,6 +100,11 @@ def transcribe_file(audio_path, tlog_path):
         json.dump(transcripts, tlog_file, default=float)
 
 
+def init_fn(l):
+    global lock
+    lock = l
+
+
 def step_function(job):
     """Wrap transcribe_file to unpack arguments from a single tuple"""
     idx, src, dst = job
@@ -111,11 +115,15 @@ def step_function(job):
 def transcribe_many(src_paths, dst_paths):
     from coqui_stt_training.util.config import Config, log_progress
 
-    pool = Pool(processes=min(cpu_count(), len(src_paths)))
-
     # Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
     jobs = zip(itertools.count(), src_paths, dst_paths)
 
+    lock = Lock()
+    with Pool(
+        processes=min(cpu_count(), len(src_paths)),
+        initializer=init_fn,
+        initargs=(lock,),
+    ) as pool:
         process_iterable = tqdm(
             pool.imap_unordered(step_function, jobs),
             desc="Transcribing files",
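The hunk above is the core of the hang fix: a single Lock is created in the parent and handed to every worker through the Pool initializer, because a multiprocessing.Lock cannot be shipped to workers as an ordinary task argument. Each child stores it in a module-level global (init_fn) and later wraps checkpoint loading in "with lock:". A minimal, self-contained sketch of the same pattern, with illustrative names rather than the module's own:

import multiprocessing
from multiprocessing import Pool, Lock, cpu_count

def init_worker(l):
    # Store the lock created by the parent in a module-level global so that
    # every task executed in this worker process can see it.
    global lock
    lock = l

def do_task(job):
    # Hypothetical task body: serialize the critical section across workers.
    idx, payload = job
    with lock:
        # e.g. load a TensorFlow checkpoint, touch a shared file, ...
        pass
    return idx

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    jobs = list(enumerate(["a", "b", "c"]))
    lock = Lock()
    with Pool(
        processes=min(cpu_count(), len(jobs)),
        initializer=init_worker,
        initargs=(lock,),
    ) as pool:
        for idx in pool.imap_unordered(do_task, jobs):
            print(f"job {idx} done")

Passing the lock inside the job tuples instead would typically fail with a RuntimeError saying that lock objects should only be shared between processes through inheritance, which is why the initializer route is used here.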
@@ -123,96 +131,109 @@ def transcribe_many(src_paths, dst_paths):
             disable=not Config.show_progressbar,
         )
 
+        cwd = Path.cwd()
         for result in process_iterable:
             idx, src, dst = result
-            log_progress(
-                f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
-            )
+            # Revert to relative to avoid spamming logs
+            if not src.is_absolute():
+                src = src.relative_to(cwd)
+            if not dst.is_absolute():
+                dst = dst.relative_to(cwd)
+            tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
 
 
-def transcribe_one(src_path, dst_path):
-    transcribe_file(src_path, dst_path)
-    print(f'I Transcribed file "{src_path}" to "{dst_path}"')
-
-
-def resolve(base_path, spec_path):
-    if spec_path is None:
-        return None
-    if not os.path.isabs(spec_path):
-        spec_path = os.path.join(base_path, spec_path)
-    return spec_path
+def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
+    """Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
+    extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
+    a path to an audio file to be transcribed, and a path to a JSON file to place
+    transcription results. For .catalog file inputs, these are taken from the
+    "audio" and "tlog" properties of the entries in the catalog, with any relative
+    paths being absolutized relative to the directory containing the .catalog file.
+    """
+    assert catalog_file_path.suffix == ".catalog"
+
+    catalog_dir = catalog_file_path.parent
+    with open(catalog_file_path, "r") as catalog_file:
+        catalog_entries = json.load(catalog_file)
+
+    def resolve(spec_path: Optional[Path]):
+        if spec_path is None:
+            return None
+        if not spec_path.is_absolute():
+            spec_path = catalog_dir / spec_path
+        return spec_path
+
+    catalog_entries = [
+        (resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
+    ]
+
+    for src, dst in catalog_entries:
+        if not Config.force and dst.is_file():
+            raise RuntimeError(
+                f"Destination file already exists: {dst}. Use --force for overwriting."
+            )
+
+        if not dst.parent.is_dir():
+            dst.parent.mkdir(parents=True)
+
+    src_paths, dst_paths = zip(*catalog_entries)
+    return src_paths, dst_paths
+
+
+def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
+    """Given a directory `src_dir` containing audio files, scan it for audio files
+    and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
+    a path to an audio file to be transcribed, and a path to a JSON file to place
+    transcription results.
+    """
+    glob_method = src_dir.rglob if recursive else src_dir.glob
+    src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))
+    dst_paths = [path.with_suffix(".tlog") for path in src_paths]
+    return src_paths, dst_paths
 
 
 def transcribe():
     from coqui_stt_training.util.config import Config
 
     initialize_transcribe_config()
 
-    if not Config.src or not os.path.exists(Config.src):
+    src_path = Path(Config.src).resolve()
+    if not Config.src or not src_path.exists():
         # path not given or non-existant
-        fail(
-            "You have to specify which file or catalog to transcribe via the --src flag."
+        raise RuntimeError(
+            "You have to specify which audio file, catalog file or directory to "
+            "transcribe with the --src flag."
         )
     else:
         # path given and exists
-        src_path = os.path.abspath(Config.src)
-        if os.path.isfile(src_path):
-            if src_path.endswith(".catalog"):
-                # Transcribe batch of files via ".catalog" file (from DSAlign)
-                catalog_dir = os.path.dirname(src_path)
-                with open(src_path, "r") as catalog_file:
-                    catalog_entries = json.load(catalog_file)
-                catalog_entries = [
-                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
-                    for e in catalog_entries
-                ]
-                if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
-                    fail("Missing source file(s) in catalog")
-                if not Config.force and any(
-                    map(lambda e: os.path.isfile(e[1]), catalog_entries)
-                ):
-                    fail(
-                        "Destination file(s) from catalog already existing, use --force for overwriting"
-                    )
-                if any(
-                    map(
-                        lambda e: not os.path.isdir(os.path.dirname(e[1])),
-                        catalog_entries,
-                    )
-                ):
-                    fail("Missing destination directory for at least one catalog entry")
-                src_paths, dst_paths = zip(*paths)
-                transcribe_many(src_paths, dst_paths)
-            else:
+        if src_path.is_file():
+            if src_path.suffix != ".catalog":
                 # Transcribe one file
                 dst_path = (
-                    os.path.abspath(Config.dst)
+                    Path(Config.dst).resolve()
                     if Config.dst
-                    else os.path.splitext(src_path)[0] + ".tlog"
+                    else src_path.with_suffix(".tlog")
                 )
-                if os.path.isfile(dst_path):
-                    if Config.force:
-                        transcribe_one(src_path, dst_path)
-                    else:
-                        fail(
-                            'Destination file "{}" already existing - use --force for overwriting'.format(
-                                dst_path
-                            ),
-                            code=0,
-                        )
-                elif os.path.isdir(os.path.dirname(dst_path)):
-                    transcribe_one(src_path, dst_path)
-                else:
-                    fail("Missing destination directory")
-        elif os.path.isdir(src_path):
-            # Transcribe all files in dir
-            print("Transcribing all WAV files in --src")
-            if Config.recursive:
-                wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
-            else:
-                wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
-            dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
-            transcribe_many(wav_paths, dst_paths)
+                if dst_path.is_file() and not Config.force:
+                    raise RuntimeError(
+                        f'Destination file "{dst_path}" already exists - use '
+                        "--force for overwriting."
+                    )
+
+                if not dst_path.parent.is_dir():
+                    raise RuntimeError("Missing destination directory")
+
+                transcribe_many([src_path], [dst_path])
+            else:
+                # Transcribe from .catalog input
+                src_paths, dst_paths = get_tasks_from_catalog(src_path)
+                transcribe_many(src_paths, dst_paths)
+        elif src_path.is_dir():
+            # Transcribe from dir input
+            print(f"Transcribing all files in --src directory {src_path}")
+            src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
+            transcribe_many(src_paths, dst_paths)
 
 
 def initialize_transcribe_config():
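The new get_tasks_from_dir helper replaces the old glob/os.path string handling with pathlib. A short illustration of the two building blocks it relies on (the directory below is made up):

from pathlib import Path
import itertools

src_dir = Path("/data/corpus")  # hypothetical input directory
recursive = True

# rglob("*.wav") walks subdirectories, glob("*.wav") stays at the top level;
# Config.recursive picks between the two in the real function.
glob_method = src_dir.rglob if recursive else src_dir.glob
src_paths = list(itertools.chain(glob_method("*.wav"), glob_method("*.opus")))

# with_suffix swaps the extension, so "clip.opus" becomes "clip.tlog" next to the audio.
dst_paths = [p.with_suffix(".tlog") for p in src_paths]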
@@ -230,7 +251,7 @@ def initialize_transcribe_config():
             "Catalog files should be formatted from DSAlign. A directory "
             "will be recursively searched for audio. If --dst not set, "
             "transcription logs (.tlog) will be written in-place using the "
-            'source filenames with suffix ".tlog" instead of ".wav".'
+            'source filenames with suffix ".tlog" instead of the original.'
         ),
     )
 
@@ -299,6 +320,9 @@ def initialize_transcribe_config():
 def main():
     from coqui_stt_training.util.helpers import check_ctcdecoder_version
 
+    # Set start method to spawn on all platforms to avoid issues with TensorFlow
+    multiprocessing.set_start_method("spawn")
+
     try:
         import webrtcvad
     except ImportError:
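main() now forces the spawn start method, so every worker begins in a fresh interpreter and re-imports the module; this is also why the file-level comment asks for TensorFlow to be imported only inside the child processes, since forking a process that has already initialized TensorFlow's thread pools is a classic source of deadlocks. A minimal sketch of that shape, with illustrative names:

import multiprocessing

def worker(path):
    # Heavy frameworks (e.g. TensorFlow) would be imported here, inside the
    # child, so the parent never initializes them before workers are created.
    return path.upper()

def main():
    # Must be called once, before any Pool or Process is created; calling it
    # again after the start method is fixed raises a RuntimeError.
    multiprocessing.set_start_method("spawn")
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(worker, ["a.wav", "b.wav"]))

if __name__ == "__main__":
    # Required under spawn: children re-import this module, and the guard
    # keeps them from recursively spawning more workers.
    main()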
@@ -7,7 +7,13 @@ import tensorflow as tf
 from .config import Config, log_error, log_info, log_warn
 
 
-def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
+def _load_checkpoint(
+    session,
+    checkpoint_path,
+    allow_drop_layers,
+    allow_lr_init=True,
+    silent: bool = False,
+):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -75,15 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
             init_vars.add(v)
     load_vars -= init_vars
 
-    log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for v in sorted(load_vars, key=lambda v: v.op.name):
-        log_info(f"Getting tensor from variable: {v.op.name}")
-        tensor = ckpt.get_tensor(v.op.name)
-        log_info(f"Loading tensor from checkpoint: {v.op.name}")
-        v.load(tensor, session=session)
+        maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
+        v.load(ckpt.get_tensor(v.op.name), session=session)
 
     for v in sorted(init_vars, key=lambda v: v.op.name):
-        log_info("Initializing variable: %s" % (v.op.name))
+        maybe_log_info("Initializing variable: %s" % (v.op.name))
         session.run(v.initializer)
 
 
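maybe_log_info is just a closure over the new silent flag: with several workers loading the same checkpoint at once, the per-variable log lines would otherwise be repeated once per process. The pattern in isolation, with print standing in for the real logger:

def make_maybe_log_info(silent: bool):
    def maybe_log_info(*args, **kwargs):
        # Forward only when not silenced; worker processes pass silent=True.
        if not silent:
            print(*args, **kwargs)
    return maybe_log_info

quiet = make_maybe_log_info(silent=True)
quiet("suppressed")                          # prints nothing
make_maybe_log_info(silent=False)("shown")   # prints "shown"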
@@ -102,31 +109,49 @@ def _initialize_all_variables(session):
         session.run(v.initializer)
 
 
-def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
+def _load_or_init_impl(
+    session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
+):
+    def maybe_log_info(*args, **kwargs):
+        if not silent:
+            log_info(*args, **kwargs)
+
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == "best":
             ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
             if ckpt_path:
-                log_info("Loading best validating checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
+                maybe_log_info(
+                    "Loading best validating checkpoint from {}".format(ckpt_path)
                 )
-            log_info("Could not find best validating checkpoint.")
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find best validating checkpoint.")
 
         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
         elif method == "last":
             ckpt_path = _checkpoint_path_or_none("checkpoint")
             if ckpt_path:
-                log_info("Loading most recent checkpoint from {}".format(ckpt_path))
-                return _load_checkpoint(
-                    session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
+                maybe_log_info(
+                    "Loading most recent checkpoint from {}".format(ckpt_path)
                 )
-            log_info("Could not find most recent checkpoint.")
+                return _load_checkpoint(
+                    session,
+                    ckpt_path,
+                    allow_drop_layers,
+                    allow_lr_init=allow_lr_init,
+                    silent=silent,
+                )
+            maybe_log_info("Could not find most recent checkpoint.")
 
         # Initialize all variables
         elif method == "init":
-            log_info("Initializing all variables.")
+            maybe_log_info("Initializing all variables.")
             return _initialize_all_variables(session)
 
         else:
@@ -141,7 +166,7 @@ def reload_best_checkpoint(session):
     _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
 
 
-def load_or_init_graph_for_training(session):
+def load_or_init_graph_for_training(session, silent: bool = False):
     """
     Load variables from checkpoint or initialize variables. By default this will
     try to load the best validating checkpoint, then try the last checkpoint,
@@ -152,10 +177,10 @@ def load_or_init_graph_for_training(session):
         methods = ["best", "last", "init"]
     else:
         methods = [Config.load_train]
-    _load_or_init_impl(session, methods, allow_drop_layers=True)
+    _load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
 
 
-def load_graph_for_evaluation(session):
+def load_graph_for_evaluation(session, silent: bool = False):
     """
     Load variables from checkpoint. Initialization is not allowed. By default
     this will try to load the best validating checkpoint, then try the last
@@ -166,4 +191,4 @@ def load_graph_for_evaluation(session):
         methods = ["best", "last"]
     else:
         methods = [Config.load_evaluate]
-    _load_or_init_impl(session, methods, allow_drop_layers=False)
+    _load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)
@@ -217,10 +217,12 @@ class BaseSttConfig(Coqpit):
         if not is_remote_path(self.save_checkpoint_dir):
             os.makedirs(self.save_checkpoint_dir, exist_ok=True)
         flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
-        with open_remote(flags_file, "w") as fout:
-            json.dump(self.serialize(), fout, indent=2)
+        if not os.path.exists(flags_file):
+            with open_remote(flags_file, "w") as fout:
+                json.dump(self.serialize(), fout, indent=2)
 
         # Serialize alphabet alongside checkpoint
-        with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
-            fout.write(self.alphabet.SerializeText())
+        if not os.path.exists(saved_checkpoint_alphabet_file):
+            with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
+                fout.write(self.alphabet.SerializeText())
 