diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 09acb5f7957..38ec1ffe7fd 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -819,14 +819,6 @@ class ModelCheckpoint(Callback): monitored metric may potentially be less reliable (it could reflect as little as 1 batch, since the metrics get reset every epoch). Defaults to `'epoch'` - load_weights_on_restart: Whether the training should restore the model. If - True, the model will attempt to load the checkpoint file from `filepath` - at the start of `model.fit()`. This saves the need of manually calling - `model.load_weights()` before `model.fit(). In multi-worker distributed - training, this provides fault-tolerance and loads the model - automatically upon recovery of workers. The callback gives up loading if - the filepath does not exist, and raises ValueError if format does not - match. Defaults to False. **kwargs: Additional arguments for backwards compatibility. Possible key is `period`. """ @@ -839,7 +831,6 @@ class ModelCheckpoint(Callback): save_weights_only=False, mode='auto', save_freq='epoch', - load_weights_on_restart=False, **kwargs): super(ModelCheckpoint, self).__init__() self.monitor = monitor @@ -848,10 +839,20 @@ class ModelCheckpoint(Callback): self.save_best_only = save_best_only self.save_weights_only = save_weights_only self.save_freq = save_freq - self.load_weights_on_restart = load_weights_on_restart self.epochs_since_last_save = 0 self._samples_seen_since_last_saving = 0 + # Deprecated field `load_weights_on_restart` is for loading the checkpoint + # file from `filepath` at the start of `model.fit()` + # TODO(rchao): Remove the arg during next breaking release. + if 'load_weights_on_restart' in kwargs: + self.load_weights_on_restart = kwargs['load_weights_on_restart'] + logging.warning('`load_weights_on_restart` argument is deprecated. ' + 'Please use `model.load_weights()` for loading weights ' + 'before the start of `model.fit()`.') + else: + self.load_weights_on_restart = False + # Deprecated field `period` is for the number of epochs between which # the model is saved. if 'period' in kwargs: @@ -912,8 +913,6 @@ class ModelCheckpoint(Callback): # If this is not multi worker training, restoring is not needed, or # restoring failed, check if it should load weights on restart. - # TODO(rchao): Also restore the epoch in single-worker training when - # `self.load_weights_on_restart=True`. if self.load_weights_on_restart: # In multi worker training, it only should if `experimental_should_init` # is True. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt index 57ce0cb3e85..5fb646e1c63 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], " + argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt index 57ce0cb3e85..5fb646e1c63 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], " + argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], " } member_method { name: "on_batch_begin"