Run reset_default_graph before every evaluate

This commit is contained in:
Caleb Moses 2020-03-18 10:38:51 +13:00
parent c9e6cbc958
commit 8e37a5cfb4
2 changed files with 5 additions and 5 deletions

View File

@ -75,7 +75,7 @@ def create_overlapping_windows(batch_x):
def dense(name, x, units, dropout_rate=None, relu=True): def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name, reuse=tf.AUTO_REUSE): with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer()) bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")) weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
@ -91,7 +91,7 @@ def dense(name, x, units, dropout_rate=None, relu=True):
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse): def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0', reuse=tf.AUTO_REUSE): with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0, forget_bias=0,
reuse=reuse, reuse=reuse,
@ -133,7 +133,7 @@ rnn_impl_cudnn_rnn.cell = None
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse): def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell', reuse=tf.AUTO_REUSE): with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell: # Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim, fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0, forget_bias=0,

View File

@ -28,12 +28,12 @@ def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max) FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max) FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
tfv1.reset_default_graph()
is_character_based = trial.study.user_attrs['is_character_based'] is_character_based = trial.study.user_attrs['is_character_based']
samples = [] samples = []
for step, test_file in enumerate(FLAGS.test_files.split(',')): for step, test_file in enumerate(FLAGS.test_files.split(',')):
tfv1.reset_default_graph()
current_samples = evaluate([test_file], create_model, try_loading) current_samples = evaluate([test_file], create_model, try_loading)
samples += current_samples samples += current_samples