Respect FLAGS.load in evaluate.py
This commit is contained in:
parent
e0d8ef75e8
commit
739841d731
12
evaluate.py
12
evaluate.py
@ -85,12 +85,14 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
# Restore variables from training checkpoint
|
||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
|
||||
if not loaded:
|
||||
loaded = False
|
||||
if not loaded and FLAGS.load in ['auto', 'best']:
|
||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
|
||||
if not loaded and FLAGS.load in ['auto', 'last']:
|
||||
loaded = try_loading(session, saver, 'checkpoint', 'most recent')
|
||||
if not loaded:
|
||||
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
|
||||
exit(1)
|
||||
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
|
||||
sys.exit(1)
|
||||
|
||||
def run_test(init_op, dataset):
|
||||
wav_filenames = []
|
||||
@ -159,7 +161,7 @@ def main(_):
|
||||
if not FLAGS.test_files:
|
||||
log_error('You need to specify what files to use for evaluation via '
|
||||
'the --test_files flag.')
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
|
||||
|
Loading…
x
Reference in New Issue
Block a user