Add trial pruning to lm_optimizer.py
This commit is contained in:
parent
29a2ac37f0
commit
c9e6cbc958
|
@ -75,7 +75,7 @@ def create_overlapping_windows(batch_x):
|
||||||
|
|
||||||
|
|
||||||
def dense(name, x, units, dropout_rate=None, relu=True):
|
def dense(name, x, units, dropout_rate=None, relu=True):
|
||||||
with tfv1.variable_scope(name):
|
with tfv1.variable_scope(name, reuse=tf.AUTO_REUSE):
|
||||||
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
||||||
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ def dense(name, x, units, dropout_rate=None, relu=True):
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0', reuse=tf.AUTO_REUSE):
|
||||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
||||||
forget_bias=0,
|
forget_bias=0,
|
||||||
reuse=reuse,
|
reuse=reuse,
|
||||||
|
@ -133,7 +133,7 @@ rnn_impl_cudnn_rnn.cell = None
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
||||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
|
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell', reuse=tf.AUTO_REUSE):
|
||||||
# Forward direction cell:
|
# Forward direction cell:
|
||||||
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
||||||
forget_bias=0,
|
forget_bias=0,
|
||||||
|
|
|
@ -29,10 +29,22 @@ def objective(trial):
|
||||||
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
|
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
|
||||||
|
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
|
||||||
|
|
||||||
is_character_based = trial.study.user_attrs['is_character_based']
|
is_character_based = trial.study.user_attrs['is_character_based']
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for step, test_file in enumerate(FLAGS.test_files.split(',')):
|
||||||
|
current_samples = evaluate([test_file], create_model, try_loading)
|
||||||
|
samples += current_samples
|
||||||
|
|
||||||
|
# Report intermediate objective value.
|
||||||
|
wer, cer = wer_cer_batch(current_samples)
|
||||||
|
trial.report(cer if is_character_based else wer, step)
|
||||||
|
|
||||||
|
# Handle pruning based on the intermediate value.
|
||||||
|
if trial.should_prune():
|
||||||
|
raise optuna.exceptions.TrialPruned()
|
||||||
|
|
||||||
wer, cer = wer_cer_batch(samples)
|
wer, cer = wer_cer_batch(samples)
|
||||||
return cer if is_character_based else wer
|
return cer if is_character_based else wer
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue