Respect FLAGS.load in evaluate.py

This commit is contained in:
Reuben Morais 2019-10-14 11:38:09 +02:00
parent e0d8ef75e8
commit 739841d731

View File

@ -85,12 +85,14 @@ def evaluate(test_csvs, create_model, try_loading):
with tfv1.Session(config=Config.session_config) as session:
# Restore variables from training checkpoint
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
if not loaded:
loaded = False
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
if not loaded and FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, saver, 'checkpoint', 'most recent')
if not loaded:
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
exit(1)
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
sys.exit(1)
def run_test(init_op, dataset):
wav_filenames = []
@ -159,7 +161,7 @@ def main(_):
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
exit(1)
sys.exit(1)
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)