Fix transcribe.py - use new checkpoint load method
Replaced non existing try_loading() method with the saver method and respect load flag Removed tf.train.Saver()
This commit is contained in:
parent
4291db7309
commit
e101cb8cc5
|
@ -27,7 +27,8 @@ def fail(message, code=1):
|
|||
|
||||
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from util.checkpoints import load_or_init_graph
|
||||
initialize_globals()
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||
try:
|
||||
|
@ -47,14 +48,12 @@ def transcribe_file(audio_path, tlog_path):
|
|||
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||
tf.train.get_or_create_global_step()
|
||||
saver = tf.train.Saver()
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
|
||||
if not loaded:
|
||||
loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
|
||||
if not loaded:
|
||||
fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
|
||||
.format(FLAGS.checkpoint_dir))
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
session.run(iterator.make_initializer(data_set))
|
||||
transcripts = []
|
||||
while True:
|
||||
|
|
Loading…
Reference in New Issue