Deprecate ModelCheckpoint.__init__
's load_weights_on_restart
argument and provide a warning message if used.
PiperOrigin-RevId: 251711208
This commit is contained in:
parent
725725250f
commit
6811026756
@ -819,14 +819,6 @@ class ModelCheckpoint(Callback):
|
|||||||
monitored metric may potentially be less reliable (it could reflect as
|
monitored metric may potentially be less reliable (it could reflect as
|
||||||
little as 1 batch, since the metrics get reset every epoch). Defaults to
|
little as 1 batch, since the metrics get reset every epoch). Defaults to
|
||||||
`'epoch'`
|
`'epoch'`
|
||||||
load_weights_on_restart: Whether the training should restore the model. If
|
|
||||||
True, the model will attempt to load the checkpoint file from `filepath`
|
|
||||||
at the start of `model.fit()`. This saves the need of manually calling
|
|
||||||
`model.load_weights()` before `model.fit(). In multi-worker distributed
|
|
||||||
training, this provides fault-tolerance and loads the model
|
|
||||||
automatically upon recovery of workers. The callback gives up loading if
|
|
||||||
the filepath does not exist, and raises ValueError if format does not
|
|
||||||
match. Defaults to False.
|
|
||||||
**kwargs: Additional arguments for backwards compatibility. Possible key
|
**kwargs: Additional arguments for backwards compatibility. Possible key
|
||||||
is `period`.
|
is `period`.
|
||||||
"""
|
"""
|
||||||
@ -839,7 +831,6 @@ class ModelCheckpoint(Callback):
|
|||||||
save_weights_only=False,
|
save_weights_only=False,
|
||||||
mode='auto',
|
mode='auto',
|
||||||
save_freq='epoch',
|
save_freq='epoch',
|
||||||
load_weights_on_restart=False,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(ModelCheckpoint, self).__init__()
|
super(ModelCheckpoint, self).__init__()
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
@ -848,10 +839,20 @@ class ModelCheckpoint(Callback):
|
|||||||
self.save_best_only = save_best_only
|
self.save_best_only = save_best_only
|
||||||
self.save_weights_only = save_weights_only
|
self.save_weights_only = save_weights_only
|
||||||
self.save_freq = save_freq
|
self.save_freq = save_freq
|
||||||
self.load_weights_on_restart = load_weights_on_restart
|
|
||||||
self.epochs_since_last_save = 0
|
self.epochs_since_last_save = 0
|
||||||
self._samples_seen_since_last_saving = 0
|
self._samples_seen_since_last_saving = 0
|
||||||
|
|
||||||
|
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
|
||||||
|
# file from `filepath` at the start of `model.fit()`
|
||||||
|
# TODO(rchao): Remove the arg during next breaking release.
|
||||||
|
if 'load_weights_on_restart' in kwargs:
|
||||||
|
self.load_weights_on_restart = kwargs['load_weights_on_restart']
|
||||||
|
logging.warning('`load_weights_on_restart` argument is deprecated. '
|
||||||
|
'Please use `model.load_weights()` for loading weights '
|
||||||
|
'before the start of `model.fit()`.')
|
||||||
|
else:
|
||||||
|
self.load_weights_on_restart = False
|
||||||
|
|
||||||
# Deprecated field `period` is for the number of epochs between which
|
# Deprecated field `period` is for the number of epochs between which
|
||||||
# the model is saved.
|
# the model is saved.
|
||||||
if 'period' in kwargs:
|
if 'period' in kwargs:
|
||||||
@ -912,8 +913,6 @@ class ModelCheckpoint(Callback):
|
|||||||
|
|
||||||
# If this is not multi worker training, restoring is not needed, or
|
# If this is not multi worker training, restoring is not needed, or
|
||||||
# restoring failed, check if it should load weights on restart.
|
# restoring failed, check if it should load weights on restart.
|
||||||
# TODO(rchao): Also restore the epoch in single-worker training when
|
|
||||||
# `self.load_weights_on_restart=True`.
|
|
||||||
if self.load_weights_on_restart:
|
if self.load_weights_on_restart:
|
||||||
# In multi worker training, it only should if `experimental_should_init`
|
# In multi worker training, it only should if `experimental_should_init`
|
||||||
# is True.
|
# is True.
|
||||||
|
@ -5,7 +5,7 @@ tf_class {
|
|||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
|
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "on_batch_begin"
|
name: "on_batch_begin"
|
||||||
|
@ -5,7 +5,7 @@ tf_class {
|
|||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
|
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "on_batch_begin"
|
name: "on_batch_begin"
|
||||||
|
Loading…
Reference in New Issue
Block a user