Fix EarlyStopping callback when no progress is ever made.
The error was: File "tensorflow/python/keras/callbacks.py", line 1782, in on_epoch_end self.model.set_weights(self.best_weights) File "tensorflow/python/keras/engine/base_layer.py", line 1850, in set_weights if expected_num_weights != len(weights): TypeError: object of type 'NoneType' has no len() We can't put the initialization inside "on_begin" functions because not all weights are initialized. PiperOrigin-RevId: 351416566 Change-Id: I800eb91f37afd25ec3272b9405415d42375f3204
This commit is contained in:
parent
5483ab817e
commit
abcabe1266
@ -1767,6 +1767,9 @@ class EarlyStopping(Callback):
|
||||
current = self.get_monitor_value(logs)
|
||||
if current is None:
|
||||
return
|
||||
if self.restore_best_weights and self.best_weights is None:
|
||||
# Restore the weights after first epoch if no progress is ever made.
|
||||
self.best_weights = self.model.get_weights()
|
||||
|
||||
self.wait += 1
|
||||
if self._is_improvement(current, self.best):
|
||||
|
Loading…
Reference in New Issue
Block a user