Merge pull request #2780 from rhamnett/patch-1
Fix transcribe.py - use new checkpoint load method
This commit is contained in:
commit
aff310d73a
|
@ -27,7 +27,8 @@ def fail(message, code=1):
|
||||||
|
|
||||||
|
|
||||||
def transcribe_file(audio_path, tlog_path):
|
def transcribe_file(audio_path, tlog_path):
|
||||||
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel
|
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||||
|
from util.checkpoints import load_or_init_graph
|
||||||
initialize_globals()
|
initialize_globals()
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||||
try:
|
try:
|
||||||
|
@ -47,14 +48,12 @@ def transcribe_file(audio_path, tlog_path):
|
||||||
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
||||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||||
tf.train.get_or_create_global_step()
|
tf.train.get_or_create_global_step()
|
||||||
saver = tf.train.Saver()
|
|
||||||
with tf.Session(config=Config.session_config) as session:
|
with tf.Session(config=Config.session_config) as session:
|
||||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
|
if FLAGS.load == 'auto':
|
||||||
if not loaded:
|
method_order = ['best', 'last']
|
||||||
loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
|
else:
|
||||||
if not loaded:
|
method_order = [FLAGS.load]
|
||||||
fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
|
load_or_init_graph(session, method_order)
|
||||||
.format(FLAGS.checkpoint_dir))
|
|
||||||
session.run(iterator.make_initializer(data_set))
|
session.run(iterator.make_initializer(data_set))
|
||||||
transcripts = []
|
transcripts = []
|
||||||
while True:
|
while True:
|
||||||
|
|
Loading…
Reference in New Issue