diff --git a/DeepSpeech.py b/DeepSpeech.py index 92404e07..f3015aff 100644 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -75,7 +75,7 @@ def create_overlapping_windows(batch_x): def dense(name, x, units, dropout_rate=None, relu=True): - with tfv1.variable_scope(name): + with tfv1.variable_scope(name, reuse=tf.AUTO_REUSE): bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer()) weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")) @@ -91,7 +91,7 @@ def dense(name, x, units, dropout_rate=None, relu=True): def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse): - with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'): + with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0', reuse=tf.AUTO_REUSE): fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, forget_bias=0, reuse=reuse, @@ -133,7 +133,7 @@ rnn_impl_cudnn_rnn.cell = None def rnn_impl_static_rnn(x, seq_length, previous_state, reuse): - with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'): + with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell', reuse=tf.AUTO_REUSE): # Forward direction cell: fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim, forget_bias=0, diff --git a/lm_optimizer.py b/lm_optimizer.py index 9e01ab96..70d2c9db 100644 --- a/lm_optimizer.py +++ b/lm_optimizer.py @@ -29,10 +29,22 @@ def objective(trial): FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max) tfv1.reset_default_graph() - samples = evaluate(FLAGS.test_files.split(','), create_model) is_character_based = trial.study.user_attrs['is_character_based'] + samples = [] + for step, test_file in enumerate(FLAGS.test_files.split(',')): + current_samples = evaluate([test_file], create_model, try_loading) + samples += current_samples + + # Report intermediate objective value. + wer, cer = wer_cer_batch(current_samples) + trial.report(cer if is_character_based else wer, step) + + # Handle pruning based on the intermediate value. + if trial.should_prune(): + raise optuna.exceptions.TrialPruned() + wer, cer = wer_cer_batch(samples) return cer if is_character_based else wer