Reload graph with extra function.
This commit is contained in:
parent
4cf7a012a3
commit
420ba808c8
@ -29,7 +29,7 @@ from mvs_ctcdecoder import ctc_beam_search_decoder, Scorer
|
|||||||
from .evaluate import evaluate
|
from .evaluate import evaluate
|
||||||
from six.moves import zip, range
|
from six.moves import zip, range
|
||||||
from .util.config import Config, initialize_globals
|
from .util.config import Config, initialize_globals
|
||||||
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
|
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
|
||||||
from .util.evaluate_tools import save_samples_json
|
from .util.evaluate_tools import save_samples_json
|
||||||
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
|
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
|
||||||
from .util.flags import create_flags, FLAGS
|
from .util.flags import create_flags, FLAGS
|
||||||
@ -646,7 +646,7 @@ def train():
|
|||||||
current_learning_rate))
|
current_learning_rate))
|
||||||
|
|
||||||
# Reload checkpoint that we use the best_dev weights again
|
# Reload checkpoint that we use the best_dev weights again
|
||||||
load_or_init_graph_for_training(session, allow_drop_layers=False)
|
reload_best_checkpoint(session)
|
||||||
|
|
||||||
if FLAGS.metrics_files:
|
if FLAGS.metrics_files:
|
||||||
# Read only metrics, not affecting best validation loss tracking
|
# Read only metrics, not affecting best validation loss tracking
|
||||||
|
|||||||
@ -118,7 +118,11 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def load_or_init_graph_for_training(session, allow_drop_layers=True):
|
def reload_best_checkpoint(session):
|
||||||
|
_load_or_init_impl(session, ['best'], allow_drop_layers=False)
|
||||||
|
|
||||||
|
|
||||||
|
def load_or_init_graph_for_training(session):
|
||||||
'''
|
'''
|
||||||
Load variables from checkpoint or initialize variables. By default this will
|
Load variables from checkpoint or initialize variables. By default this will
|
||||||
try to load the best validating checkpoint, then try the last checkpoint,
|
try to load the best validating checkpoint, then try the last checkpoint,
|
||||||
@ -129,7 +133,7 @@ def load_or_init_graph_for_training(session, allow_drop_layers=True):
|
|||||||
methods = ['best', 'last', 'init']
|
methods = ['best', 'last', 'init']
|
||||||
else:
|
else:
|
||||||
methods = [FLAGS.load_train]
|
methods = [FLAGS.load_train]
|
||||||
_load_or_init_impl(session, methods, allow_drop_layers=allow_drop_layers)
|
_load_or_init_impl(session, methods, allow_drop_layers=True)
|
||||||
|
|
||||||
|
|
||||||
def load_graph_for_evaluation(session):
|
def load_graph_for_evaluation(session):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user