Fix EarlyStopping callback when no progress is ever made.

The error was:
File "tensorflow/python/keras/callbacks.py", line 1782, in on_epoch_end
    self.model.set_weights(self.best_weights)
File "tensorflow/python/keras/engine/base_layer.py", line 1850, in set_weights
    if expected_num_weights != len(weights):
TypeError: object of type 'NoneType' has no len()

We can't put the initialization inside "on_begin" functions because not all weights are initialized.

PiperOrigin-RevId: 351416566
Change-Id: I800eb91f37afd25ec3272b9405415d42375f3204
This commit is contained in:
A. Unique TensorFlower 2021-01-12 11:49:17 -08:00 committed by TensorFlower Gardener
parent 5483ab817e
commit abcabe1266

View File

@ -1767,6 +1767,9 @@ class EarlyStopping(Callback):
current = self.get_monitor_value(logs)
if current is None:
return
if self.restore_best_weights and self.best_weights is None:
# Restore the weights after first epoch if no progress is ever made.
self.best_weights = self.model.get_weights()
self.wait += 1
if self._is_improvement(current, self.best):