diff --git a/lm_optimizer.py b/lm_optimizer.py index 78215682..25d8a05e 100644 --- a/lm_optimizer.py +++ b/lm_optimizer.py @@ -27,11 +27,23 @@ def objective(trial): FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max) FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max) - tfv1.reset_default_graph() - samples = evaluate(FLAGS.test_files.split(','), create_model) - is_character_based = trial.study.user_attrs['is_character_based'] + samples = [] + for step, test_file in enumerate(FLAGS.test_files.split(',')): + tfv1.reset_default_graph() + + current_samples = evaluate([test_file], create_model) + samples += current_samples + + # Report intermediate objective value. + wer, cer = wer_cer_batch(current_samples) + trial.report(cer if is_character_based else wer, step) + + # Handle pruning based on the intermediate value. + if trial.should_prune(): + raise optuna.exceptions.TrialPruned() + wer, cer = wer_cer_batch(samples) return cer if is_character_based else wer