Set model.history after on_epoch_end in History class.

PiperOrigin-RevId: 297267664
Change-Id: I0cf97412913bdfe3599d74f162d172f02e7ceac9
This commit is contained in:
Yash Katariya 2020-02-25 20:14:49 -08:00 committed by TensorFlower Gardener
parent 001037c469
commit 8c97290ba3
2 changed files with 9 additions and 1 deletions
tensorflow/python/keras

View File

@ -894,9 +894,12 @@ class History(Callback):
gets returned by the `fit` method of models.
"""
def __init__(self):
super(History, self).__init__()
self.history = {}
def on_train_begin(self, logs=None):
self.epoch = []
self.history = {}
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
@ -904,6 +907,10 @@ class History(Callback):
for k, v in logs.items():
self.history.setdefault(k, []).append(v)
# Set the history attribute on the model after the epoch ends. This will
# make sure that the state which is set is the latest one.
self.model.history = self
@keras_export('keras.callbacks.ModelCheckpoint')
class ModelCheckpoint(Callback):

View File

@ -174,6 +174,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
# Fault-tolerance handler. Set in `ModelCheckpoint`.
self._training_state = None
self.history = None
def get_weights(self):
"""Retrieves the weights of the model.