diff --git a/optimizer.py b/optimizer.py new file mode 100644 index 00000000..9e01ab96 --- /dev/null +++ b/optimizer.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function + +import sys + +import optuna +import absl.app +from ds_ctcdecoder import Scorer +import tensorflow.compat.v1 as tfv1 + +from DeepSpeech import create_model +from evaluate import evaluate +from util.config import Config, initialize_globals +from util.flags import create_flags, FLAGS +from util.logging import log_error +from util.evaluate_tools import wer_cer_batch + + +def character_based(): + is_character_based = False + if FLAGS.scorer_path: + scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) + is_character_based = scorer.is_utf8_mode() + return is_character_based + +def objective(trial): + FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max) + FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max) + + tfv1.reset_default_graph() + samples = evaluate(FLAGS.test_files.split(','), create_model) + + is_character_based = trial.study.user_attrs['is_character_based'] + + wer, cer = wer_cer_batch(samples) + return cer if is_character_based else wer + +def main(_): + initialize_globals() + + if not FLAGS.test_files: + log_error('You need to specify what files to use for evaluation via ' + 'the --test_files flag.') + sys.exit(1) + + is_character_based = character_based() + + study = optuna.create_study() + study.set_user_attr("is_character_based", is_character_based) + study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials) + print('Best params: lm_alpha={} and lm_beta={} with WER={}'.format(study.best_params['lm_alpha'], + study.best_params['lm_beta'], + study.best_value)) + + +if __name__ == '__main__': + create_flags() + absl.app.run(main) diff --git a/requirements.txt b/requirements.txt index 742b8244..54159a37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,6 @@ bs4 requests librosa soundfile + +# Requirements for optimizer +optuna diff --git a/util/flags.py b/util/flags.py index 5057d76c..195e62f7 100644 --- a/util/flags.py +++ b/util/flags.py @@ -166,6 +166,12 @@ def create_flags(): f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.') + # Optimizer mode + + f.DEFINE_float('lm_alpha_max', 5, 'the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.') + f.DEFINE_float('lm_beta_max', 5, 'the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.') + f.DEFINE_integer('n_trials', 2400, 'the number of trials to run during hyperparameter optimization.') + # Register validators for paths which require a file to be specified f.register_validator('alphabet_config_path',