From 42e5d78e9af7db5772682041f0e8cde82b102ac2 Mon Sep 17 00:00:00 2001 From: Tilman Kamp <5991088+tilmankamp@users.noreply.github.com> Date: Mon, 8 Apr 2019 16:35:36 +0200 Subject: [PATCH] Fix #2020 - Testing best-dev checkpoint --- DeepSpeech.py | 4 +--- evaluate.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index 924902fa..b412fd1d 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -360,8 +360,6 @@ def try_loading(session, saver, checkpoint_filename, caption): ' between train runs using the same checkpoint dir? Try moving' ' or removing the contents of {0}.'.format(checkpoint_path)) sys.exit(1) - except: - return False def train(): @@ -549,7 +547,7 @@ def train(): def test(): - evaluate.evaluate(FLAGS.test_files.split(','), create_model) + evaluate.evaluate(FLAGS.test_files.split(','), create_model, try_loading) def create_inference_graph(batch_size=1, n_steps=16, tflite=False): diff --git a/evaluate.py b/evaluate.py index edb1f593..940d985b 100755 --- a/evaluate.py +++ b/evaluate.py @@ -38,7 +38,7 @@ def sparse_tuple_to_texts(tuple, alphabet): return results -def evaluate(test_csvs, create_model): +def evaluate(test_csvs, create_model, try_loading): scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet) @@ -63,19 +63,20 @@ def evaluate(test_csvs, create_model): inputs=logits, sequence_length=batch_x_len) + global_step = tf.train.create_global_step() + with tf.Session(config=Config.session_config) as session: # Create a saver using variables from the above newly created graph saver = tf.train.Saver() # Restore variables from training checkpoint - checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) - if not checkpoint: + loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation') + if not loaded: + loaded = try_loading(session, saver, 'checkpoint', 'most recent') + if not loaded: log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir)) exit(1) - checkpoint_path = checkpoint.model_checkpoint_path - saver.restore(session, checkpoint_path) - logitses = [] losses = [] seq_lengths = [] @@ -150,8 +151,8 @@ def main(_): 'the --test_files flag.') exit(1) - from DeepSpeech import create_model - samples = evaluate(FLAGS.test_files.split(','), create_model) + from DeepSpeech import create_model, try_loading + samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading) if FLAGS.test_output_file: # Save decoded tuples as JSON, converting NumPy floats to Python floats @@ -160,6 +161,5 @@ def main(_): if __name__ == '__main__': create_flags() - tf.app.flags.DEFINE_string('hdf5_test_set', '', 'path to hdf5 file to cache test set features') tf.app.flags.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples') tf.app.run(main)