Fix transcribe.py - use new checkpoint load method

Replaced non existing try_loading() method with the saver method and respect load flag

Removed tf.train.Saver()
This commit is contained in:
Richard Hamnett 2020-02-21 15:56:58 +00:00 committed by GitHub
parent 4291db7309
commit e101cb8cc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 8 deletions

View File

@ -27,7 +27,8 @@ def fail(message, code=1):
def transcribe_file(audio_path, tlog_path): def transcribe_file(audio_path, tlog_path):
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
from util.checkpoints import load_or_init_graph
initialize_globals() initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try: try:
@ -47,14 +48,12 @@ def transcribe_file(audio_path, tlog_path):
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout) logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step() tf.train.get_or_create_global_step()
saver = tf.train.Saver()
with tf.Session(config=Config.session_config) as session: with tf.Session(config=Config.session_config) as session:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False) if FLAGS.load == 'auto':
if not loaded: method_order = ['best', 'last']
loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False) else:
if not loaded: method_order = [FLAGS.load]
fail('Checkpoint directory ({}) does not contain a valid checkpoint state.' load_or_init_graph(session, method_order)
.format(FLAGS.checkpoint_dir))
session.run(iterator.make_initializer(data_set)) session.run(iterator.make_initializer(data_set))
transcripts = [] transcripts = []
while True: while True: