Merge pull request #3261 from mozilla/reload-weights-plateau-tests

Tests #3245 Reload weights after plateau
This commit is contained in:
Reuben Morais 2020-08-20 09:48:02 +02:00 committed by GitHub
commit d14c2b2e2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 1 deletions

1
.gitignore vendored
View File

@ -37,3 +37,4 @@ Dockerfile.train
doc/xml-c
doc/xml-java
doc/xml-dotnet
convert_graphdef_memmapped_format

View File

@ -29,7 +29,7 @@ from mvs_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
from .util.flags import create_flags, FLAGS
@ -645,6 +645,9 @@ def train():
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
# Reload checkpoint that we use the best_dev weights again
reload_best_checkpoint(session)
if FLAGS.metrics_files:
# Read only metrics, not affecting best validation loss tracking
for source, init_op in zip(metrics_sources, metrics_init_ops):

View File

@ -118,6 +118,10 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
sys.exit(1)
def reload_best_checkpoint(session):
_load_or_init_impl(session, ['best'], allow_drop_layers=False)
def load_or_init_graph_for_training(session):
'''
Load variables from checkpoint or initialize variables. By default this will