diff --git a/transcribe.py b/transcribe.py index c66bbe61..0b3e4aa7 100755 --- a/transcribe.py +++ b/transcribe.py @@ -27,7 +27,8 @@ def fail(message, code=1): def transcribe_file(audio_path, tlog_path): - from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel + from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel + from util.checkpoints import load_or_init_graph initialize_globals() scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) try: @@ -47,14 +48,12 @@ def transcribe_file(audio_path, tlog_path): logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout) transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) tf.train.get_or_create_global_step() - saver = tf.train.Saver() with tf.Session(config=Config.session_config) as session: - loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False) - if not loaded: - loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False) - if not loaded: - fail('Checkpoint directory ({}) does not contain a valid checkpoint state.' - .format(FLAGS.checkpoint_dir)) + if FLAGS.load == 'auto': + method_order = ['best', 'last'] + else: + method_order = [FLAGS.load] + load_or_init_graph(session, method_order) session.run(iterator.make_initializer(data_set)) transcripts = [] while True: