Fix transcribe.py - use new checkpoint load method
Replaced non existing try_loading() method with the saver method and respect load flag Removed tf.train.Saver()
This commit is contained in:
		
							parent
							
								
									4291db7309
								
							
						
					
					
						commit
						e101cb8cc5
					
				@ -27,7 +27,8 @@ def fail(message, code=1):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def transcribe_file(audio_path, tlog_path):
 | 
			
		||||
    from DeepSpeech import create_model, try_loading  # pylint: disable=cyclic-import,import-outside-toplevel
 | 
			
		||||
    from DeepSpeech import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
 | 
			
		||||
    from util.checkpoints import load_or_init_graph
 | 
			
		||||
    initialize_globals()
 | 
			
		||||
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
 | 
			
		||||
    try:
 | 
			
		||||
@ -47,14 +48,12 @@ def transcribe_file(audio_path, tlog_path):
 | 
			
		||||
        logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
 | 
			
		||||
        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
 | 
			
		||||
        tf.train.get_or_create_global_step()
 | 
			
		||||
        saver = tf.train.Saver()
 | 
			
		||||
        with tf.Session(config=Config.session_config) as session:
 | 
			
		||||
            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
 | 
			
		||||
            if not loaded:
 | 
			
		||||
                loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
 | 
			
		||||
            if not loaded:
 | 
			
		||||
                fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
 | 
			
		||||
                     .format(FLAGS.checkpoint_dir))
 | 
			
		||||
            if FLAGS.load == 'auto':
 | 
			
		||||
                method_order = ['best', 'last']
 | 
			
		||||
            else:
 | 
			
		||||
                method_order = [FLAGS.load]
 | 
			
		||||
            load_or_init_graph(session, method_order)
 | 
			
		||||
            session.run(iterator.make_initializer(data_set))
 | 
			
		||||
            transcripts = []
 | 
			
		||||
            while True:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user