Set model.history after on_epoch_end
in History
class.
PiperOrigin-RevId: 297267664 Change-Id: I0cf97412913bdfe3599d74f162d172f02e7ceac9
This commit is contained in:
parent
001037c469
commit
8c97290ba3
tensorflow/python/keras
@ -894,9 +894,12 @@ class History(Callback):
|
||||
gets returned by the `fit` method of models.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(History, self).__init__()
|
||||
self.history = {}
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
self.epoch = []
|
||||
self.history = {}
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
logs = logs or {}
|
||||
@ -904,6 +907,10 @@ class History(Callback):
|
||||
for k, v in logs.items():
|
||||
self.history.setdefault(k, []).append(v)
|
||||
|
||||
# Set the history attribute on the model after the epoch ends. This will
|
||||
# make sure that the state which is set is the latest one.
|
||||
self.model.history = self
|
||||
|
||||
|
||||
@keras_export('keras.callbacks.ModelCheckpoint')
|
||||
class ModelCheckpoint(Callback):
|
||||
|
@ -174,6 +174,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
|
||||
# Fault-tolerance handler. Set in `ModelCheckpoint`.
|
||||
self._training_state = None
|
||||
self.history = None
|
||||
|
||||
def get_weights(self):
|
||||
"""Retrieves the weights of the model.
|
||||
|
Loading…
Reference in New Issue
Block a user