For EarlyStopping, restore the best epoch even if no epoch beats the baseline.

When both baseline and restore_best_weights are specified, we should restore the best epoch
according to the monitoring value even if no epoch improves on the baseline. Before this change,
we would error out in this case when attempting to restore the best weights.

PiperOrigin-RevId: 351401209
Change-Id: I2e143b1cedefedc6e2728fa822ec0654650359d7
This commit is contained in:
Matt Watson 2021-01-12 10:39:44 -08:00 committed by TensorFlower Gardener
parent c213508ff9
commit c5f8604fed
2 changed files with 41 additions and 16 deletions

View File

@@ -1697,7 +1697,10 @@ class EarlyStopping(Callback):
restore_best_weights: Whether to restore model weights from
the epoch with the best value of the monitored quantity.
If False, the model weights obtained at the last step of
training are used.
training are used. An epoch will be restored regardless
of the performance relative to the `baseline`. If no epoch
improves on `baseline`, training will run for `patience`
epochs and restore weights from the best epoch in that set.
Example:
@@ -1757,30 +1760,30 @@ class EarlyStopping(Callback):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
if self.baseline is not None:
self.best = self.baseline
else:
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
self.best_weights = None
def on_epoch_end(self, epoch, logs=None):
current = self.get_monitor_value(logs)
if current is None:
return
if self.monitor_op(current - self.min_delta, self.best):
self.wait += 1
if self._is_improvement(current, self.best):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights:
if self.verbose > 0:
print('Restoring model weights from the end of the best epoch.')
self.model.set_weights(self.best_weights)
# Only restart wait if we beat both the baseline and our previous best.
if self.baseline is None or self._is_improvement(current, self.baseline):
self.wait = 0
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights and self.best_weights is not None:
if self.verbose > 0:
print('Restoring model weights from the end of the best epoch.')
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0 and self.verbose > 0:
@@ -1795,6 +1798,9 @@ class EarlyStopping(Callback):
self.monitor, ','.join(list(logs.keys())))
return monitor_value
def _is_improvement(self, monitor_value, reference_value):
return self.monitor_op(monitor_value - self.min_delta, reference_value)
@keras_export('keras.callbacks.RemoteMonitor')
class RemoteMonitor(Callback):

View File

@@ -1123,6 +1123,25 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
# so we end up at the epoch with the best weights, i.e. epoch 2
self.assertEqual(early_stop.model.get_weights(), 2)
# Check early stopping when no model beats the baseline.
early_stop = keras.callbacks.EarlyStopping(
monitor='val_loss', patience=5, baseline=0.5, restore_best_weights=True)
early_stop.model = DummyModel()
losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]
# The best configuration is in the epoch 2 (loss = 0.7000).
epochs_trained = 0
early_stop.on_train_begin()
for epoch in range(len(losses)):
epochs_trained += 1
early_stop.model.set_weight_to_epoch(epoch=epoch)
early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
if early_stop.model.stop_training:
break
# No epoch improves on the baseline, so we should train for only 5 epochs,
# and restore the second model.
self.assertEqual(epochs_trained, 5)
self.assertEqual(early_stop.model.get_weights(), 2)
def test_RemoteMonitor(self):
if requests is None:
self.skipTest('`requests` required to run this test')