Merge pull request #2826 from TeHikuMedia/add_trial_pruning
Add trial pruning to lm_optimizer.py
This commit is contained in:
commit
0cc815f1f0
@ -27,11 +27,23 @@ def objective(trial):
|
||||
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
|
||||
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
|
||||
|
||||
tfv1.reset_default_graph()
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
is_character_based = trial.study.user_attrs['is_character_based']
|
||||
|
||||
samples = []
|
||||
for step, test_file in enumerate(FLAGS.test_files.split(',')):
|
||||
tfv1.reset_default_graph()
|
||||
|
||||
current_samples = evaluate([test_file], create_model)
|
||||
samples += current_samples
|
||||
|
||||
# Report intermediate objective value.
|
||||
wer, cer = wer_cer_batch(current_samples)
|
||||
trial.report(cer if is_character_based else wer, step)
|
||||
|
||||
# Handle pruning based on the intermediate value.
|
||||
if trial.should_prune():
|
||||
raise optuna.exceptions.TrialPruned()
|
||||
|
||||
wer, cer = wer_cer_batch(samples)
|
||||
return cer if is_character_based else wer
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user