Only allow graph/layer initialization at start of training

This commit is contained in:
Reuben Morais 2020-04-04 17:53:03 +02:00
parent b5a805056f
commit cc7a0ada46
4 changed files with 24 additions and 13 deletions

View File

@ -16,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.checkpoints import load_graph
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
@ -86,7 +86,7 @@ def evaluate(test_csvs, create_model):
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
load_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []

View File

@ -30,7 +30,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.checkpoints import load_or_init_graph_for_training, load_graph
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
@ -512,7 +512,7 @@ def train():
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
load_or_init_graph_for_training(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
@ -777,7 +777,7 @@ def export():
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
load_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
@ -861,7 +861,7 @@ def do_single_file_inference(input_file_path):
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
load_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])

View File

@ -6,7 +6,7 @@ from .flags import FLAGS
from .logging import log_info, log_error, log_warn
def _load_checkpoint(session, checkpoint_path):
def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
# Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then
# we will initialize them instead
@ -45,7 +45,7 @@ def _load_checkpoint(session, checkpoint_path):
'tensors. Missing variables: {}'.format(missing_var_names))
sys.exit(1)
if FLAGS.drop_source_layers > 0:
if allow_drop_layers and FLAGS.drop_source_layers > 0:
# This transfer learning approach requires supplying
# the layers which we exclude from the source model.
# Say we want to exclude all layers except for the first one,
@ -87,7 +87,7 @@ def _initialize_all_variables(session):
session.run(v.initializer)
def load_or_init_graph(session, method_order):
def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
'''
Load variables from checkpoint or initialize variables following the method
order specified in the method_order parameter.
@ -100,7 +100,7 @@ def load_or_init_graph(session, method_order):
ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
if ckpt_path:
log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path)
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
log_info('Could not find best validating checkpoint.')
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
@ -108,7 +108,7 @@ def load_or_init_graph(session, method_order):
ckpt_path = _checkpoint_path_or_none('checkpoint')
if ckpt_path:
log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path)
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
log_info('Could not find most recent checkpoint.')
# Initialize all variables
@ -122,3 +122,14 @@ def load_or_init_graph(session, method_order):
log_error('All initialization methods failed ({}).'.format(method_order))
sys.exit(1)
def load_graph(session, method_order):
    '''
    Load variables from checkpoint. Initialization is not allowed. Follows the
    method order specified in the method_order parameter.
    Valid methods are 'best' and 'last'.
    '''
    # Guard: this loader must never fall back to fresh initialization, so the
    # 'init' method is rejected up front.
    # NOTE(review): `assert` is stripped under `python -O`; a ValueError raise
    # would survive optimized runs -- confirm whether that matters here.
    assert('init' not in method_order)
    # Delegate to the training loader with allow_drop_layers=False so the
    # checkpoint must match the graph exactly (no transfer-learning layer
    # dropping during inference/evaluation/export).
    load_or_init_graph_for_training(session, method_order, allow_drop_layers=False)

View File

@ -29,7 +29,7 @@ def fail(message, code=1):
def transcribe_file(audio_path, tlog_path):
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
from deepspeech_training.util.checkpoints import load_or_init_graph
from deepspeech_training.util.checkpoints import load_graph
initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try:
@ -54,7 +54,7 @@ def transcribe_file(audio_path, tlog_path):
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
load_graph(session, method_order)
session.run(iterator.make_initializer(data_set))
transcripts = []
while True: