For EarlyStopping, restore the best epoch even if no epoch beats the baseline.
When both `baseline` and `restore_best_weights` are specified, restore the best epoch according to the monitored value even if no epoch improves on the baseline. Before this change, attempting to restore the best weights in that situation raised an error, because no best weights had ever been recorded.

PiperOrigin-RevId: 351401209
Change-Id: I2e143b1cedefedc6e2728fa822ec0654650359d7
commit c5f8604fed
parent c213508ff9
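For context, the configuration affected by this change looks like the following. This is a minimal sketch with a hypothetical toy model and random data; only the EarlyStopping arguments relate to this commit, everything else is illustrative.

import numpy as np
import tensorflow as tf

# Hypothetical toy model and data, for illustration only.
x = np.random.rand(128, 8).astype('float32')
y = np.random.rand(128, 1).astype('float32')
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer='sgd', loss='mse')

# With this change, even if no epoch ever improves on `baseline`, training
# stops after `patience` epochs without improvement and the weights from the
# best epoch seen so far are restored; previously this raised an error
# because `best_weights` had never been recorded.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, baseline=0.5, restore_best_weights=True)
model.fit(x, y, validation_split=0.2, epochs=20, callbacks=[early_stop])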

@@ -1697,7 +1697,10 @@ class EarlyStopping(Callback):
       restore_best_weights: Whether to restore model weights from
           the epoch with the best value of the monitored quantity.
           If False, the model weights obtained at the last step of
-          training are used.
+          training are used. An epoch will be restored regardless
+          of the performance relative to the `baseline`. If no epoch
+          improves on `baseline`, training will run for `patience`
+          epochs and restore weights from the best epoch in that set.
 
   Example:
 
@@ -1757,30 +1760,30 @@ class EarlyStopping(Callback):
     # Allow instances to be re-used
     self.wait = 0
     self.stopped_epoch = 0
-    if self.baseline is not None:
-      self.best = self.baseline
-    else:
-      self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+    self.best = np.Inf if self.monitor_op == np.less else -np.Inf
     self.best_weights = None
 
   def on_epoch_end(self, epoch, logs=None):
     current = self.get_monitor_value(logs)
     if current is None:
       return
-    if self.monitor_op(current - self.min_delta, self.best):
+
+    self.wait += 1
+    if self._is_improvement(current, self.best):
       self.best = current
-      self.wait = 0
       if self.restore_best_weights:
         self.best_weights = self.model.get_weights()
-    else:
-      self.wait += 1
-      if self.wait >= self.patience:
-        self.stopped_epoch = epoch
-        self.model.stop_training = True
-        if self.restore_best_weights:
-          if self.verbose > 0:
-            print('Restoring model weights from the end of the best epoch.')
-          self.model.set_weights(self.best_weights)
+      # Only restart wait if we beat both the baseline and our previous best.
+      if self.baseline is None or self._is_improvement(current, self.baseline):
+        self.wait = 0
+
+    if self.wait >= self.patience:
+      self.stopped_epoch = epoch
+      self.model.stop_training = True
+      if self.restore_best_weights and self.best_weights is not None:
+        if self.verbose > 0:
+          print('Restoring model weights from the end of the best epoch.')
+        self.model.set_weights(self.best_weights)
 
   def on_train_end(self, logs=None):
     if self.stopped_epoch > 0 and self.verbose > 0:
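The behavioral core of the hunk above: `wait` now increments on every epoch and is reset only when an epoch beats both the previous best and the `baseline`, while the best weights are tracked regardless of the baseline. A condensed, standalone sketch of that bookkeeping for a 'min'-mode metric follows; the helper function and its names are illustrative, not part of the callback.

def simulate_early_stopping(losses, patience, baseline=None, min_delta=0.0):
  """Sketch of the updated EarlyStopping bookkeeping for a loss-like metric."""
  best = float('inf')
  best_epoch = None
  wait = 0
  for epoch, current in enumerate(losses):
    wait += 1
    if current + min_delta < best:        # improvement over the previous best
      best, best_epoch = current, epoch   # the best epoch is always remembered
      if baseline is None or current + min_delta < baseline:
        wait = 0                          # patience resets only if the baseline is also beaten
    if wait >= patience:
      break
  # (index of the epoch training stopped on, epoch whose weights would be restored)
  return epoch, best_epoch

# No epoch beats baseline=0.5, so training runs for exactly `patience` epochs,
# but the best epoch (index 2, loss 0.7) is still the one restored -> (4, 2).
print(simulate_early_stopping([0.9, 0.8, 0.7, 0.71, 0.72, 0.73], patience=5, baseline=0.5))

With baseline=None the reset condition collapses to the previous behavior of resetting `wait` on any improvement.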
@@ -1795,6 +1798,9 @@ class EarlyStopping(Callback):
                       self.monitor, ','.join(list(logs.keys())))
     return monitor_value
 
+  def _is_improvement(self, monitor_value, reference_value):
+    return self.monitor_op(monitor_value - self.min_delta, reference_value)
+
 
 @keras_export('keras.callbacks.RemoteMonitor')
 class RemoteMonitor(Callback):
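The new `_is_improvement` helper centralizes the min_delta-adjusted comparison used both against the previous best and against the baseline. A standalone sketch of the same check is below; the function and argument names are illustrative, and it assumes the sign-adjusted `min_delta` the callback stores at construction (negative for 'min'-mode metrics such as losses).

import numpy as np

def is_improvement(monitor_value, reference_value, monitor_op, signed_min_delta):
  # Same comparison as the helper: the candidate must clear the reference
  # by at least |min_delta| in the direction given by monitor_op.
  return monitor_op(monitor_value - signed_min_delta, reference_value)

# 'min' mode (np.less) with min_delta=0.005, i.e. signed_min_delta=-0.005:
print(is_improvement(0.690, 0.70, np.less, -0.005))  # True: loss dropped by >= 0.005
print(is_improvement(0.699, 0.70, np.less, -0.005))  # False: drop is smaller than min_delta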
@@ -1123,6 +1123,25 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
     # so we end up at the epoch with the best weights, i.e. epoch 2
     self.assertEqual(early_stop.model.get_weights(), 2)
 
+    # Check early stopping when no model beats the baseline.
+    early_stop = keras.callbacks.EarlyStopping(
+        monitor='val_loss', patience=5, baseline=0.5, restore_best_weights=True)
+    early_stop.model = DummyModel()
+    losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]
+    # The best configuration is in the epoch 2 (loss = 0.7000).
+    epochs_trained = 0
+    early_stop.on_train_begin()
+    for epoch in range(len(losses)):
+      epochs_trained += 1
+      early_stop.model.set_weight_to_epoch(epoch=epoch)
+      early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
+      if early_stop.model.stop_training:
+        break
+    # No epoch improves on the baseline, so we should train for only 5 epochs,
+    # and restore the second model.
+    self.assertEqual(epochs_trained, 5)
+    self.assertEqual(early_stop.model.get_weights(), 2)
+
   def test_RemoteMonitor(self):
     if requests is None:
       self.skipTest('`requests` required to run this test')
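DummyModel in the test above is a helper defined elsewhere in the test file and not shown in this diff. A hypothetical stand-in with the same surface is enough to replay the loop interactively; only the attributes used by the test and by EarlyStopping are implemented.

class FakeDummyModel(object):
  """Hypothetical stand-in for the test's DummyModel helper."""

  def __init__(self):
    self.stop_training = False   # EarlyStopping sets this flag to stop training
    self._weights = -1

  def set_weight_to_epoch(self, epoch):
    # Tag the "weights" with the epoch index so restores are observable.
    self._weights = epoch

  def get_weights(self):
    return self._weights

  def set_weights(self, weights):
    self._weights = weights

Substituting FakeDummyModel for DummyModel in the loop above (with EarlyStopping imported from tf.keras.callbacks outside the test harness) and printing get_weights() afterwards should show 2, matching the assertions.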