diff --git a/training/deepspeech_training/evaluate.py b/training/deepspeech_training/evaluate.py index 3a03aa7e..8848b7a1 100755 --- a/training/deepspeech_training/evaluate.py +++ b/training/deepspeech_training/evaluate.py @@ -16,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer from six.moves import zip from .util.config import Config, initialize_globals -from .util.checkpoints import load_or_init_graph +from .util.checkpoints import load_graph from .util.evaluate_tools import calculate_and_print_report from .util.feeding import create_dataset from .util.flags import create_flags, FLAGS @@ -86,7 +86,7 @@ def evaluate(test_csvs, create_model): method_order = ['best', 'last'] else: method_order = [FLAGS.load] - load_or_init_graph(session, method_order) + load_graph(session, method_order) def run_test(init_op, dataset): wav_filenames = [] diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index bbf6b7b3..45671e1f 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -30,7 +30,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer from .evaluate import evaluate from six.moves import zip, range from .util.config import Config, initialize_globals -from .util.checkpoints import load_or_init_graph +from .util.checkpoints import load_or_init_graph_for_training, load_graph from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from .util.flags import create_flags, FLAGS from .util.helpers import check_ctcdecoder_version, ExceptionBox @@ -512,7 +512,7 @@ def train(): method_order = ['best', 'last', 'init'] else: method_order = [FLAGS.load] - load_or_init_graph(session, method_order) + load_or_init_graph_for_training(session, method_order) def run_set(set_name, epoch, init_op, dataset=None): is_train = set_name == 'train' @@ -777,7 +777,7 @@ def export(): method_order = ['best', 'last'] else: method_order = [FLAGS.load] - load_or_init_graph(session, method_order) + load_graph(session, method_order) output_filename = FLAGS.export_file_name + '.pb' if FLAGS.remove_export: @@ -861,7 +861,7 @@ def do_single_file_inference(input_file_path): method_order = ['best', 'last'] else: method_order = [FLAGS.load] - load_or_init_graph(session, method_order) + load_graph(session, method_order) features, features_len = audiofile_to_features(input_file_path) previous_state_c = np.zeros([1, Config.n_cell_dim]) diff --git a/training/deepspeech_training/util/checkpoints.py b/training/deepspeech_training/util/checkpoints.py index ac77dc29..80ffe48b 100644 --- a/training/deepspeech_training/util/checkpoints.py +++ b/training/deepspeech_training/util/checkpoints.py @@ -6,7 +6,7 @@ from .flags import FLAGS from .logging import log_info, log_error, log_warn -def _load_checkpoint(session, checkpoint_path): +def _load_checkpoint(session, checkpoint_path, allow_drop_layers): # Load the checkpoint and put all variables into loading list # we will exclude variables we do not wish to load and then # we will initialize them instead @@ -45,7 +45,7 @@ def _load_checkpoint(session, checkpoint_path): 'tensors. Missing variables: {}'.format(missing_var_names)) sys.exit(1) - if FLAGS.drop_source_layers > 0: + if allow_drop_layers and FLAGS.drop_source_layers > 0: # This transfer learning approach requires supplying # the layers which we exclude from the source model. # Say we want to exclude all layers except for the first one, @@ -87,7 +87,7 @@ def _initialize_all_variables(session): session.run(v.initializer) -def load_or_init_graph(session, method_order): +def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True): ''' Load variables from checkpoint or initialize variables following the method order specified in the method_order parameter. @@ -100,7 +100,7 @@ def load_or_init_graph(session, method_order): ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint') if ckpt_path: log_info('Loading best validating checkpoint from {}'.format(ckpt_path)) - return _load_checkpoint(session, ckpt_path) + return _load_checkpoint(session, ckpt_path, allow_drop_layers) log_info('Could not find best validating checkpoint.') # Load most recent checkpoint, saved in checkpoint file 'checkpoint' @@ -108,7 +108,7 @@ def load_or_init_graph(session, method_order): ckpt_path = _checkpoint_path_or_none('checkpoint') if ckpt_path: log_info('Loading most recent checkpoint from {}'.format(ckpt_path)) - return _load_checkpoint(session, ckpt_path) + return _load_checkpoint(session, ckpt_path, allow_drop_layers) log_info('Could not find most recent checkpoint.') # Initialize all variables @@ -122,3 +122,14 @@ def load_or_init_graph(session, method_order): log_error('All initialization methods failed ({}).'.format(method_order)) sys.exit(1) + + +def load_graph(session, method_order): + ''' + Load variables from checkpoint. Initialization is not allowed. Follows the + method order specified in the method_order parameter. + + Valid methods are 'best' and 'last'. + ''' + assert('init' not in method_order) + load_or_init_graph_for_training(session, method_order, allow_drop_layers=False) diff --git a/transcribe.py b/transcribe.py index d9bc2f13..aad4fd73 100755 --- a/transcribe.py +++ b/transcribe.py @@ -29,7 +29,7 @@ def fail(message, code=1): def transcribe_file(audio_path, tlog_path): from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel - from deepspeech_training.util.checkpoints import load_or_init_graph + from deepspeech_training.util.checkpoints import load_graph initialize_globals() scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) try: @@ -54,7 +54,7 @@ def transcribe_file(audio_path, tlog_path): method_order = ['best', 'last'] else: method_order = [FLAGS.load] - load_or_init_graph(session, method_order) + load_graph(session, method_order) session.run(iterator.make_initializer(data_set)) transcripts = [] while True: