Only allow graph/layer initialization at start of training
This commit is contained in:
parent
b5a805056f
commit
cc7a0ada46
@ -16,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from six.moves import zip
|
||||
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph
|
||||
from .util.checkpoints import load_graph
|
||||
from .util.evaluate_tools import calculate_and_print_report
|
||||
from .util.feeding import create_dataset
|
||||
from .util.flags import create_flags, FLAGS
|
||||
@ -86,7 +86,7 @@ def evaluate(test_csvs, create_model):
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
load_graph(session, method_order)
|
||||
|
||||
def run_test(init_op, dataset):
|
||||
wav_filenames = []
|
||||
|
@ -30,7 +30,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from .evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph
|
||||
from .util.checkpoints import load_or_init_graph_for_training, load_graph
|
||||
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
@ -512,7 +512,7 @@ def train():
|
||||
method_order = ['best', 'last', 'init']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
load_or_init_graph_for_training(session, method_order)
|
||||
|
||||
def run_set(set_name, epoch, init_op, dataset=None):
|
||||
is_train = set_name == 'train'
|
||||
@ -777,7 +777,7 @@ def export():
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
load_graph(session, method_order)
|
||||
|
||||
output_filename = FLAGS.export_file_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
@ -861,7 +861,7 @@ def do_single_file_inference(input_file_path):
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
load_graph(session, method_order)
|
||||
|
||||
features, features_len = audiofile_to_features(input_file_path)
|
||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||
|
@ -6,7 +6,7 @@ from .flags import FLAGS
|
||||
from .logging import log_info, log_error, log_warn
|
||||
|
||||
|
||||
def _load_checkpoint(session, checkpoint_path):
|
||||
def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
|
||||
# Load the checkpoint and put all variables into loading list
|
||||
# we will exclude variables we do not wish to load and then
|
||||
# we will initialize them instead
|
||||
@ -45,7 +45,7 @@ def _load_checkpoint(session, checkpoint_path):
|
||||
'tensors. Missing variables: {}'.format(missing_var_names))
|
||||
sys.exit(1)
|
||||
|
||||
if FLAGS.drop_source_layers > 0:
|
||||
if allow_drop_layers and FLAGS.drop_source_layers > 0:
|
||||
# This transfer learning approach requires supplying
|
||||
# the layers which we exclude from the source model.
|
||||
# Say we want to exclude all layers except for the first one,
|
||||
@ -87,7 +87,7 @@ def _initialize_all_variables(session):
|
||||
session.run(v.initializer)
|
||||
|
||||
|
||||
def load_or_init_graph(session, method_order):
|
||||
def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
|
||||
'''
|
||||
Load variables from checkpoint or initialize variables following the method
|
||||
order specified in the method_order parameter.
|
||||
@ -100,7 +100,7 @@ def load_or_init_graph(session, method_order):
|
||||
ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
|
||||
if ckpt_path:
|
||||
log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
|
||||
return _load_checkpoint(session, ckpt_path)
|
||||
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
|
||||
log_info('Could not find best validating checkpoint.')
|
||||
|
||||
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
|
||||
@ -108,7 +108,7 @@ def load_or_init_graph(session, method_order):
|
||||
ckpt_path = _checkpoint_path_or_none('checkpoint')
|
||||
if ckpt_path:
|
||||
log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
|
||||
return _load_checkpoint(session, ckpt_path)
|
||||
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
|
||||
log_info('Could not find most recent checkpoint.')
|
||||
|
||||
# Initialize all variables
|
||||
@ -122,3 +122,14 @@ def load_or_init_graph(session, method_order):
|
||||
|
||||
log_error('All initialization methods failed ({}).'.format(method_order))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_graph(session, method_order):
|
||||
'''
|
||||
Load variables from checkpoint. Initialization is not allowed. Follows the
|
||||
method order specified in the method_order parameter.
|
||||
|
||||
Valid methods are 'best' and 'last'.
|
||||
'''
|
||||
assert('init' not in method_order)
|
||||
load_or_init_graph_for_training(session, method_order, allow_drop_layers=False)
|
||||
|
@ -29,7 +29,7 @@ def fail(message, code=1):
|
||||
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from deepspeech_training.util.checkpoints import load_or_init_graph
|
||||
from deepspeech_training.util.checkpoints import load_graph
|
||||
initialize_globals()
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||
try:
|
||||
@ -54,7 +54,7 @@ def transcribe_file(audio_path, tlog_path):
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
load_graph(session, method_order)
|
||||
session.run(iterator.make_initializer(data_set))
|
||||
transcripts = []
|
||||
while True:
|
||||
|
Loading…
x
Reference in New Issue
Block a user