Merge pull request #2826 from TeHikuMedia/add_trial_pruning

Add trial pruning to lm_optimizer.py
This commit is contained in:
Kelly Davis 2020-04-01 14:56:31 +02:00 committed by GitHub
commit 0cc815f1f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,11 +27,23 @@ def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
tfv1.reset_default_graph()
samples = evaluate(FLAGS.test_files.split(','), create_model)
is_character_based = trial.study.user_attrs['is_character_based']
samples = []
for step, test_file in enumerate(FLAGS.test_files.split(',')):
tfv1.reset_default_graph()
current_samples = evaluate([test_file], create_model)
samples += current_samples
# Report intermediate objective value.
wer, cer = wer_cer_batch(current_samples)
trial.report(cer if is_character_based else wer, step)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
wer, cer = wer_cer_batch(samples)
return cer if is_character_based else wer