Fix #2020 - Testing best-dev checkpoint

commit 42e5d78e9a
parent 4b7c00fc36
DeepSpeech.py

@@ -360,8 +360,6 @@ def try_loading(session, saver, checkpoint_filename, caption):
                   ' between train runs using the same checkpoint dir? Try moving'
                   ' or removing the contents of {0}.'.format(checkpoint_path))
         sys.exit(1)
-    except:
-        return False
 
 
 def train():

@@ -549,7 +547,7 @@ def train():
 
 
 def test():
-    evaluate.evaluate(FLAGS.test_files.split(','), create_model)
+    evaluate.evaluate(FLAGS.test_files.split(','), create_model, try_loading)
 
 
 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
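The behavioural change in this file is that test() now hands the try_loading helper through to evaluate.evaluate(), and try_loading itself drops its bare except/return False fallback. As a reading aid only, here is a minimal sketch of the contract that callable is assumed to satisfy; the signature comes from the hunk header above, but the body is an illustration (using the module's tf and FLAGS), not the code in this commit:

    def try_loading(session, saver, checkpoint_filename, caption):
        # Look up the checkpoint state recorded under checkpoint_filename in FLAGS.checkpoint_dir.
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
        if not checkpoint:
            return False
        # Restore the current graph's variables and report which checkpoint was used.
        saver.restore(session, checkpoint.model_checkpoint_path)
        print('Restored variables from %s checkpoint at %s' % (caption, checkpoint.model_checkpoint_path))
        return True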
evaluate.py (18 changed lines)
@@ -38,7 +38,7 @@ def sparse_tuple_to_texts(tuple, alphabet):
     return results
 
 
-def evaluate(test_csvs, create_model):
+def evaluate(test_csvs, create_model, try_loading):
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                     FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                     Config.alphabet)
@@ -63,19 +63,20 @@ def evaluate(test_csvs, create_model):
         inputs=logits,
         sequence_length=batch_x_len)
 
+    global_step = tf.train.create_global_step()
+
     with tf.Session(config=Config.session_config) as session:
         # Create a saver using variables from the above newly created graph
         saver = tf.train.Saver()
 
         # Restore variables from training checkpoint
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-        if not checkpoint:
+        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
+        if not loaded:
+            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
+        if not loaded:
             log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
             exit(1)
 
-        checkpoint_path = checkpoint.model_checkpoint_path
-        saver.restore(session, checkpoint_path)
-
         logitses = []
         losses = []
         seq_lengths = []
@@ -150,8 +151,8 @@ def main(_):
                   'the --test_files flag.')
         exit(1)
 
-    from DeepSpeech import create_model
-    samples = evaluate(FLAGS.test_files.split(','), create_model)
+    from DeepSpeech import create_model, try_loading
+    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
 
     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
@@ -160,6 +161,5 @@ def main(_):
 
 if __name__ == '__main__':
     create_flags()
-    tf.app.flags.DEFINE_string('hdf5_test_set', '', 'path to hdf5 file to cache test set features')
     tf.app.flags.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples')
     tf.app.run(main)
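Net effect: the standalone test path (DeepSpeech.py's test() and evaluate.py run directly) now tries the best-validation checkpoint first, falls back to the most recent training checkpoint, and only aborts if neither can be restored. A hypothetical invocation that exercises this path; the paths are placeholders, and only the test_files, checkpoint_dir and test_output_file flags are referenced in the diff above:

    python evaluate.py --test_files /data/test.csv --checkpoint_dir /data/checkpoints --test_output_file /tmp/test_output.json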