Deprecate ModelCheckpoint.__init__'s load_weights_on_restart argument and provide a warning message if used.

PiperOrigin-RevId: 251711208
This commit is contained in:
Rick Chao 2019-06-05 13:51:38 -07:00 committed by TensorFlower Gardener
parent 725725250f
commit 6811026756
3 changed files with 13 additions and 14 deletions

View File

@ -819,14 +819,6 @@ class ModelCheckpoint(Callback):
monitored metric may potentially be less reliable (it could reflect as
little as 1 batch, since the metrics get reset every epoch). Defaults to
`'epoch'`
load_weights_on_restart: Whether the training should restore the model. If
True, the model will attempt to load the checkpoint file from `filepath`
at the start of `model.fit()`. This saves the need of manually calling
`model.load_weights()` before `model.fit(). In multi-worker distributed
training, this provides fault-tolerance and loads the model
automatically upon recovery of workers. The callback gives up loading if
the filepath does not exist, and raises ValueError if format does not
match. Defaults to False.
**kwargs: Additional arguments for backwards compatibility. Possible key
is `period`.
"""
@ -839,7 +831,6 @@ class ModelCheckpoint(Callback):
save_weights_only=False,
mode='auto',
save_freq='epoch',
load_weights_on_restart=False,
**kwargs):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
@ -848,10 +839,20 @@ class ModelCheckpoint(Callback):
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.save_freq = save_freq
self.load_weights_on_restart = load_weights_on_restart
self.epochs_since_last_save = 0
self._samples_seen_since_last_saving = 0
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
# file from `filepath` at the start of `model.fit()`
# TODO(rchao): Remove the arg during next breaking release.
if 'load_weights_on_restart' in kwargs:
self.load_weights_on_restart = kwargs['load_weights_on_restart']
logging.warning('`load_weights_on_restart` argument is deprecated. '
'Please use `model.load_weights()` for loading weights '
'before the start of `model.fit()`.')
else:
self.load_weights_on_restart = False
# Deprecated field `period` is for the number of epochs between which
# the model is saved.
if 'period' in kwargs:
@ -912,8 +913,6 @@ class ModelCheckpoint(Callback):
# If this is not multi worker training, restoring is not needed, or
# restoring failed, check if it should load weights on restart.
# TODO(rchao): Also restore the epoch in single-worker training when
# `self.load_weights_on_restart=True`.
if self.load_weights_on_restart:
# In multi worker training, it only should if `experimental_should_init`
# is True.

View File

@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
}
member_method {
name: "on_batch_begin"

View File

@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
}
member_method {
name: "on_batch_begin"