import sys

import tensorflow.compat.v1 as tfv1
import tensorflow as tf

from .config import Config, log_error, log_info, log_warn


def _load_checkpoint(
    session,
    checkpoint_path,
    allow_drop_layers,
    allow_lr_init=True,
    silent: bool = False,
):
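    """Restore variables from the checkpoint at `checkpoint_path` into `session`.

    Variables that cannot or should not be loaded (a missing learning rate,
    variables absent from a CuDNN RNN checkpoint, or layers dropped for
    transfer learning) are initialized from scratch instead.
    """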
    # Load the checkpoint and put all variables into loading list
    # we will exclude variables we do not wish to load and then
    # we will initialize them instead
    ckpt = tfv1.train.load_checkpoint(checkpoint_path)
    vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())
    load_vars = set(tfv1.global_variables())
    init_vars = set()

    # We explicitly allow the learning rate variable to be missing for backwards
    # compatibility with older checkpoints.
    lr_var = set(v for v in load_vars if v.op.name == "learning_rate")
    if lr_var and (
        "learning_rate" not in vars_in_ckpt
        or (Config.force_initialize_learning_rate and allow_lr_init)
    ):
        assert len(lr_var) <= 1
        load_vars -= lr_var
        init_vars |= lr_var

    if Config.load_cudnn:
        # Initialize training from a CuDNN RNN checkpoint
        # Identify the variables which we cannot load, and set them
        # for initialization
        missing_vars = set()
        for v in load_vars:
            if v.op.name not in vars_in_ckpt:
                log_warn("CUDNN variable not found: %s" % (v.op.name))
                missing_vars.add(v)
                init_vars.add(v)

        load_vars -= init_vars

        # Check that the only missing variables (i.e. those to be initialised)
        # are the Adam moment tensors, if they aren't then we have an issue
        missing_var_names = [v.op.name for v in missing_vars]
        if any("Adam" not in v for v in missing_var_names):
            log_error(
                "Tried to load a CuDNN RNN checkpoint but there were "
                "more missing variables than just the Adam moment "
                "tensors. Missing variables: {}".format(missing_var_names)
            )
            sys.exit(1)

    if allow_drop_layers and Config.drop_source_layers > 0:
        # This transfer learning approach requires supplying
        # the layers which we exclude from the source model.
        # Say we want to exclude all layers except for the first one,
        # then we are dropping five layers total, so: drop_source_layers=5
        # If we want to use all layers from the source model except
        # the last one, we use this: drop_source_layers=1
        if Config.drop_source_layers >= 6:
            log_warn(
                "The checkpoint only has 6 layers, but you are trying to drop "
                "all of them or more than all of them. Continuing and "
                "dropping only 5 layers."
            )
            Config.drop_source_layers = 5
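
        # Name substrings of the five droppable layers, in model order; the
        # first layer is never dropped. Variables whose names match a dropped
        # layer are re-initialized instead of loaded from the checkpoint.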
        dropped_layers = ["2", "3", "lstm", "5", "6"][
            -1 * int(Config.drop_source_layers) :
        ]
        # Initialize all variables needed for DS, but not loaded from ckpt
        for v in load_vars:
            if any(layer in v.op.name for layer in dropped_layers):
                init_vars.add(v)
        load_vars -= init_vars

    def maybe_log_info(*args, **kwargs):
        if not silent:
            log_info(*args, **kwargs)

    for v in sorted(load_vars, key=lambda v: v.op.name):
        maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
        v.load(ckpt.get_tensor(v.op.name), session=session)

    for v in sorted(init_vars, key=lambda v: v.op.name):
        maybe_log_info("Initializing variable: %s" % (v.op.name))
        session.run(v.initializer)


def _checkpoint_path_or_none(checkpoint_filename):
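    """Return the model checkpoint path recorded in `checkpoint_filename` inside
    Config.load_checkpoint_dir, or None if no checkpoint state is found."""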
    checkpoint = tfv1.train.get_checkpoint_state(
        Config.load_checkpoint_dir, checkpoint_filename
    )
    if not checkpoint:
        return None
    return checkpoint.model_checkpoint_path


def _initialize_all_variables(session):
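    """Initialize every global variable from scratch, ignoring any checkpoint."""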
    init_vars = tfv1.global_variables()
    for v in init_vars:
        session.run(v.initializer)


def _load_or_init_impl(
    session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
):
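    """Attempt each method in `method_order` ("best", "last" or "init") in turn,
    returning after the first one that succeeds; exit with an error if none do.
    """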
    def maybe_log_info(*args, **kwargs):
        if not silent:
            log_info(*args, **kwargs)

    for method in method_order:
        # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
        if method == "best":
            ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
            if ckpt_path:
                maybe_log_info(
                    "Loading best validating checkpoint from {}".format(ckpt_path)
                )
                return _load_checkpoint(
                    session,
                    ckpt_path,
                    allow_drop_layers,
                    allow_lr_init=allow_lr_init,
                    silent=silent,
                )
            maybe_log_info("Could not find best validating checkpoint.")

        # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
        elif method == "last":
            ckpt_path = _checkpoint_path_or_none("checkpoint")
            if ckpt_path:
                maybe_log_info(
                    "Loading most recent checkpoint from {}".format(ckpt_path)
                )
                return _load_checkpoint(
                    session,
                    ckpt_path,
                    allow_drop_layers,
                    allow_lr_init=allow_lr_init,
                    silent=silent,
                )
            maybe_log_info("Could not find most recent checkpoint.")

        # Initialize all variables
        elif method == "init":
            maybe_log_info("Initializing all variables.")
            return _initialize_all_variables(session)

        else:
            log_error("Unknown initialization method: {}".format(method))
            sys.exit(1)

    log_error("All initialization methods failed ({}).".format(method_order))
    sys.exit(1)


def reload_best_checkpoint(session):
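    """Reload the best validating checkpoint into `session`, without dropping
    layers or re-initializing the learning rate."""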
    _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)


def load_or_init_graph_for_training(session, silent: bool = False):
    """
    Load variables from checkpoint or initialize variables. By default this will
    try to load the best validating checkpoint, then try the last checkpoint,
    and finally initialize the weights from scratch. This can be overridden with
    the `--load_train` flag. See its documentation for more info.
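
    Typical call site (a sketch, assuming the training graph has already been
    built in the default graph):

        with tfv1.Session() as session:
            load_or_init_graph_for_training(session)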
    """
    if Config.load_train == "auto":
        methods = ["best", "last", "init"]
    else:
        methods = [Config.load_train]
    _load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)


def load_graph_for_evaluation(session, silent: bool = False):
    """
    Load variables from checkpoint. Initialization is not allowed. By default
    this will try to load the best validating checkpoint, then try the last
    checkpoint. This can be overridden with the `--load_evaluate` flag. See its
    documentation for more info.
    """
    if Config.load_evaluate == "auto":
        methods = ["best", "last"]
    else:
        methods = [Config.load_evaluate]
    _load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)