Merge pull request #2826 from TeHikuMedia/add_trial_pruning
Add trial pruning to lm_optimizer.py
Commit 0cc815f1f0
@@ -27,11 +27,23 @@ def objective(trial):
     FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
     FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)

-    tfv1.reset_default_graph()
-    samples = evaluate(FLAGS.test_files.split(','), create_model)
-
     is_character_based = trial.study.user_attrs['is_character_based']

+    samples = []
+    for step, test_file in enumerate(FLAGS.test_files.split(',')):
+        tfv1.reset_default_graph()
+
+        current_samples = evaluate([test_file], create_model)
+        samples += current_samples
+
+        # Report intermediate objective value.
+        wer, cer = wer_cer_batch(current_samples)
+        trial.report(cer if is_character_based else wer, step)
+
+        # Handle pruning based on the intermediate value.
+        if trial.should_prune():
+            raise optuna.exceptions.TrialPruned()
+
     wer, cer = wer_cer_batch(samples)
     return cer if is_character_based else wer
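The patch changes the Optuna objective from a single evaluation over all test files to a per-file loop: after each file it reports the intermediate WER/CER with trial.report() and aborts unpromising lm_alpha/lm_beta trials via trial.should_prune() and optuna.exceptions.TrialPruned. Pruning only takes effect when the study running the objective has a pruner attached. The sketch below is a minimal, self-contained illustration of that pattern, not part of this commit: the dummy per-step error, the MedianPruner choice, and the trial count are assumptions for demonstration (a plain optuna.create_study() falls back to MedianPruner by default).

import optuna

# Minimal sketch of the report/prune pattern added by this commit. The real
# objective() evaluates DeepSpeech on one test file per step; a dummy error
# stands in here so the example runs without the DeepSpeech code base.
def objective(trial):
    lm_alpha = trial.suggest_float('lm_alpha', 0, 5)   # the patch uses suggest_uniform
    lm_beta = trial.suggest_float('lm_beta', 0, 5)

    errors = []
    for step in range(3):  # one step per "test file"
        # Stand-in for evaluate([test_file], create_model) + wer_cer_batch(...)
        error = abs(lm_alpha - 0.75) + abs(lm_beta - 1.85) + 0.1 * step
        errors.append(error)

        # Report intermediate objective value.
        trial.report(error, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return sum(errors) / len(errors)

# Study-side setup: a pruner must be in place for should_prune() to ever fire;
# MedianPruner is also what optuna.create_study() uses when none is given.
study = optuna.create_study(direction='minimize',
                            pruner=optuna.pruners.MedianPruner())
study.set_user_attr('is_character_based', False)  # read by the real objective()
study.optimize(objective, n_trials=20)
print(study.best_params)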