Fix lm_optimizer.py to use new Config/flags/logging setup

This commit is contained in:
Reuben Morais 2021-08-19 18:42:07 +02:00
parent f9556d2236
commit 4c3537952a
1 changed files with 25 additions and 18 deletions

View File

@ -4,36 +4,36 @@ from __future__ import absolute_import, print_function
import sys import sys
import absl.app
import optuna import optuna
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer from coqui_stt_ctcdecoder import Scorer
from coqui_stt_training.evaluate import evaluate from coqui_stt_training.evaluate import evaluate
from coqui_stt_training.train import create_model from coqui_stt_training.train import create_model, early_training_checks
from coqui_stt_training.util.config import Config, initialize_globals_from_cli from coqui_stt_training.util.config import (
Config,
initialize_globals_from_cli,
log_error,
)
from coqui_stt_training.util.evaluate_tools import wer_cer_batch from coqui_stt_training.util.evaluate_tools import wer_cer_batch
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import log_error
def character_based(): def character_based():
is_character_based = False is_character_based = False
if FLAGS.scorer_path: scorer = Scorer(
scorer = Scorer( Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet )
) is_character_based = scorer.is_utf8_mode()
is_character_based = scorer.is_utf8_mode()
return is_character_based return is_character_based
def objective(trial): def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max) Config.lm_alpha = trial.suggest_uniform("lm_alpha", 0, Config.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max) Config.lm_beta = trial.suggest_uniform("lm_beta", 0, Config.lm_beta_max)
is_character_based = trial.study.user_attrs["is_character_based"] is_character_based = trial.study.user_attrs["is_character_based"]
samples = [] samples = []
for step, test_file in enumerate(FLAGS.test_files.split(",")): for step, test_file in enumerate(Config.test_files):
tfv1.reset_default_graph() tfv1.reset_default_graph()
current_samples = evaluate([test_file], create_model) current_samples = evaluate([test_file], create_model)
@ -51,10 +51,18 @@ def objective(trial):
return cer if is_character_based else wer return cer if is_character_based else wer
def main(_): def main():
initialize_globals_from_cli() initialize_globals_from_cli()
early_training_checks()
if not FLAGS.test_files: if not Config.scorer_path:
log_error(
"Missing --scorer_path: can't optimize scorer alpha and beta "
"parameters without a scorer!"
)
sys.exit(1)
if not Config.test_files:
log_error( log_error(
"You need to specify what files to use for evaluation via " "You need to specify what files to use for evaluation via "
"the --test_files flag." "the --test_files flag."
@ -65,7 +73,7 @@ def main(_):
study = optuna.create_study() study = optuna.create_study()
study.set_user_attr("is_character_based", is_character_based) study.set_user_attr("is_character_based", is_character_based)
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials) study.optimize(objective, n_jobs=1, n_trials=Config.n_trials)
print( print(
"Best params: lm_alpha={} and lm_beta={} with WER={}".format( "Best params: lm_alpha={} and lm_beta={} with WER={}".format(
study.best_params["lm_alpha"], study.best_params["lm_alpha"],
@ -76,5 +84,4 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
create_flags() main()
absl.app.run(main)