Reload graph with extra function.

This commit is contained in:
Daniel 2020-08-19 16:22:02 +02:00 committed by Reuben Morais
parent 4cf7a012a3
commit 420ba808c8
2 changed files with 8 additions and 4 deletions

View File

@ -29,7 +29,7 @@ from mvs_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate from .evaluate import evaluate
from six.moves import zip, range from six.moves import zip, range
from .util.config import Config, initialize_globals from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
from .util.evaluate_tools import save_samples_json from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
from .util.flags import create_flags, FLAGS from .util.flags import create_flags, FLAGS
@ -646,7 +646,7 @@ def train():
current_learning_rate)) current_learning_rate))
# Reload checkpoint that we use the best_dev weights again # Reload checkpoint that we use the best_dev weights again
load_or_init_graph_for_training(session, allow_drop_layers=False) reload_best_checkpoint(session)
if FLAGS.metrics_files: if FLAGS.metrics_files:
# Read only metrics, not affecting best validation loss tracking # Read only metrics, not affecting best validation loss tracking

View File

@ -118,7 +118,11 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
sys.exit(1) sys.exit(1)
def load_or_init_graph_for_training(session, allow_drop_layers=True): def reload_best_checkpoint(session):
_load_or_init_impl(session, ['best'], allow_drop_layers=False)
def load_or_init_graph_for_training(session):
''' '''
Load variables from checkpoint or initialize variables. By default this will Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint, try to load the best validating checkpoint, then try the last checkpoint,
@ -129,7 +133,7 @@ def load_or_init_graph_for_training(session, allow_drop_layers=True):
methods = ['best', 'last', 'init'] methods = ['best', 'last', 'init']
else: else:
methods = [FLAGS.load_train] methods = [FLAGS.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=allow_drop_layers) _load_or_init_impl(session, methods, allow_drop_layers=True)
def load_graph_for_evaluation(session): def load_graph_for_evaluation(session):