Added optimizer for lm_alpha + lm_beta
This commit is contained in:
parent
84ac39769c
commit
561131a05c
59
optimizer.py
Normal file
59
optimizer.py
Normal file
@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, print_function
|
||||
|
||||
import sys
|
||||
|
||||
import optuna
|
||||
import absl.app
|
||||
from ds_ctcdecoder import Scorer
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from DeepSpeech import create_model
|
||||
from evaluate import evaluate
|
||||
from util.config import Config, initialize_globals
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error
|
||||
from util.evaluate_tools import wer_cer_batch
|
||||
|
||||
|
||||
def character_based():
|
||||
is_character_based = False
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||
is_character_based = scorer.is_utf8_mode()
|
||||
return is_character_based
|
||||
|
||||
def objective(trial):
|
||||
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
|
||||
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
|
||||
|
||||
tfv1.reset_default_graph()
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
is_character_based = trial.study.user_attrs['is_character_based']
|
||||
|
||||
wer, cer = wer_cer_batch(samples)
|
||||
return cer if is_character_based else wer
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if not FLAGS.test_files:
|
||||
log_error('You need to specify what files to use for evaluation via '
|
||||
'the --test_files flag.')
|
||||
sys.exit(1)
|
||||
|
||||
is_character_based = character_based()
|
||||
|
||||
study = optuna.create_study()
|
||||
study.set_user_attr("is_character_based", is_character_based)
|
||||
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
|
||||
print('Best params: lm_alpha={} and lm_beta={} with WER={}'.format(study.best_params['lm_alpha'],
|
||||
study.best_params['lm_beta'],
|
||||
study.best_value))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
absl.app.run(main)
|
@ -18,3 +18,6 @@ bs4
|
||||
requests
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
# Requirements for optimizer
|
||||
optuna
|
||||
|
@ -166,6 +166,12 @@ def create_flags():
|
||||
|
||||
f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')
|
||||
|
||||
# Optimizer mode
|
||||
|
||||
f.DEFINE_float('lm_alpha_max', 5, 'the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.')
|
||||
f.DEFINE_float('lm_beta_max', 5, 'the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.')
|
||||
f.DEFINE_integer('n_trials', 2400, 'the number of trials to run during hyperparameter optimization.')
|
||||
|
||||
# Register validators for paths which require a file to be specified
|
||||
|
||||
f.register_validator('alphabet_config_path',
|
||||
|
Loading…
Reference in New Issue
Block a user