From 4c3537952adc643232c83b976865e78315743041 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 19 Aug 2021 18:42:07 +0200 Subject: [PATCH] Fix lm_optimizer.py to use new Config/flags/logging setup --- lm_optimizer.py | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/lm_optimizer.py b/lm_optimizer.py index 85ca1fd5..ae919640 100644 --- a/lm_optimizer.py +++ b/lm_optimizer.py @@ -4,36 +4,36 @@ from __future__ import absolute_import, print_function import sys -import absl.app import optuna import tensorflow.compat.v1 as tfv1 from coqui_stt_ctcdecoder import Scorer from coqui_stt_training.evaluate import evaluate -from coqui_stt_training.train import create_model -from coqui_stt_training.util.config import Config, initialize_globals_from_cli +from coqui_stt_training.train import create_model, early_training_checks +from coqui_stt_training.util.config import ( + Config, + initialize_globals_from_cli, + log_error, +) from coqui_stt_training.util.evaluate_tools import wer_cer_batch -from coqui_stt_training.util.flags import FLAGS, create_flags -from coqui_stt_training.util.logging import log_error def character_based(): is_character_based = False - if FLAGS.scorer_path: - scorer = Scorer( - FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet - ) - is_character_based = scorer.is_utf8_mode() + scorer = Scorer( + Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet + ) + is_character_based = scorer.is_utf8_mode() return is_character_based def objective(trial): - FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max) - FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max) + Config.lm_alpha = trial.suggest_uniform("lm_alpha", 0, Config.lm_alpha_max) + Config.lm_beta = trial.suggest_uniform("lm_beta", 0, Config.lm_beta_max) is_character_based = trial.study.user_attrs["is_character_based"] samples = [] - for step, test_file in enumerate(FLAGS.test_files.split(",")): + for step, test_file in enumerate(Config.test_files): tfv1.reset_default_graph() current_samples = evaluate([test_file], create_model) @@ -51,10 +51,18 @@ def objective(trial): return cer if is_character_based else wer -def main(_): +def main(): initialize_globals_from_cli() + early_training_checks() - if not FLAGS.test_files: + if not Config.scorer_path: + log_error( + "Missing --scorer_path: can't optimize scorer alpha and beta " + "parameters without a scorer!" + ) + sys.exit(1) + + if not Config.test_files: log_error( "You need to specify what files to use for evaluation via " "the --test_files flag." @@ -65,7 +73,7 @@ def main(_): study = optuna.create_study() study.set_user_attr("is_character_based", is_character_based) - study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials) + study.optimize(objective, n_jobs=1, n_trials=Config.n_trials) print( "Best params: lm_alpha={} and lm_beta={} with WER={}".format( study.best_params["lm_alpha"], @@ -76,5 +84,4 @@ def main(_): if __name__ == "__main__": - create_flags() - absl.app.run(main) + main()