Fix lm_optimizer.py to use new Config/flags/logging setup
This commit is contained in:
parent
f9556d2236
commit
4c3537952a
|
@ -4,36 +4,36 @@ from __future__ import absolute_import, print_function
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import absl.app
|
|
||||||
import optuna
|
import optuna
|
||||||
import tensorflow.compat.v1 as tfv1
|
import tensorflow.compat.v1 as tfv1
|
||||||
from coqui_stt_ctcdecoder import Scorer
|
from coqui_stt_ctcdecoder import Scorer
|
||||||
from coqui_stt_training.evaluate import evaluate
|
from coqui_stt_training.evaluate import evaluate
|
||||||
from coqui_stt_training.train import create_model
|
from coqui_stt_training.train import create_model, early_training_checks
|
||||||
from coqui_stt_training.util.config import Config, initialize_globals_from_cli
|
from coqui_stt_training.util.config import (
|
||||||
|
Config,
|
||||||
|
initialize_globals_from_cli,
|
||||||
|
log_error,
|
||||||
|
)
|
||||||
from coqui_stt_training.util.evaluate_tools import wer_cer_batch
|
from coqui_stt_training.util.evaluate_tools import wer_cer_batch
|
||||||
from coqui_stt_training.util.flags import FLAGS, create_flags
|
|
||||||
from coqui_stt_training.util.logging import log_error
|
|
||||||
|
|
||||||
|
|
||||||
def character_based():
|
def character_based():
|
||||||
is_character_based = False
|
is_character_based = False
|
||||||
if FLAGS.scorer_path:
|
scorer = Scorer(
|
||||||
scorer = Scorer(
|
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
|
||||||
FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
|
)
|
||||||
)
|
is_character_based = scorer.is_utf8_mode()
|
||||||
is_character_based = scorer.is_utf8_mode()
|
|
||||||
return is_character_based
|
return is_character_based
|
||||||
|
|
||||||
|
|
||||||
def objective(trial):
|
def objective(trial):
|
||||||
FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max)
|
Config.lm_alpha = trial.suggest_uniform("lm_alpha", 0, Config.lm_alpha_max)
|
||||||
FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max)
|
Config.lm_beta = trial.suggest_uniform("lm_beta", 0, Config.lm_beta_max)
|
||||||
|
|
||||||
is_character_based = trial.study.user_attrs["is_character_based"]
|
is_character_based = trial.study.user_attrs["is_character_based"]
|
||||||
|
|
||||||
samples = []
|
samples = []
|
||||||
for step, test_file in enumerate(FLAGS.test_files.split(",")):
|
for step, test_file in enumerate(Config.test_files):
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
|
|
||||||
current_samples = evaluate([test_file], create_model)
|
current_samples = evaluate([test_file], create_model)
|
||||||
|
@ -51,10 +51,18 @@ def objective(trial):
|
||||||
return cer if is_character_based else wer
|
return cer if is_character_based else wer
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main():
|
||||||
initialize_globals_from_cli()
|
initialize_globals_from_cli()
|
||||||
|
early_training_checks()
|
||||||
|
|
||||||
if not FLAGS.test_files:
|
if not Config.scorer_path:
|
||||||
|
log_error(
|
||||||
|
"Missing --scorer_path: can't optimize scorer alpha and beta "
|
||||||
|
"parameters without a scorer!"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not Config.test_files:
|
||||||
log_error(
|
log_error(
|
||||||
"You need to specify what files to use for evaluation via "
|
"You need to specify what files to use for evaluation via "
|
||||||
"the --test_files flag."
|
"the --test_files flag."
|
||||||
|
@ -65,7 +73,7 @@ def main(_):
|
||||||
|
|
||||||
study = optuna.create_study()
|
study = optuna.create_study()
|
||||||
study.set_user_attr("is_character_based", is_character_based)
|
study.set_user_attr("is_character_based", is_character_based)
|
||||||
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
|
study.optimize(objective, n_jobs=1, n_trials=Config.n_trials)
|
||||||
print(
|
print(
|
||||||
"Best params: lm_alpha={} and lm_beta={} with WER={}".format(
|
"Best params: lm_alpha={} and lm_beta={} with WER={}".format(
|
||||||
study.best_params["lm_alpha"],
|
study.best_params["lm_alpha"],
|
||||||
|
@ -76,5 +84,4 @@ def main(_):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
create_flags()
|
main()
|
||||||
absl.app.run(main)
|
|
||||||
|
|
Loading…
Reference in New Issue