Merge pull request #41586 from lgeiger:nan-callback-tf-logs

PiperOrigin-RevId: 328676757
Change-Id: Id2d251ec5f864b9926807721dfeefcd515772797
This commit is contained in:
TensorFlower Gardener 2020-08-26 22:38:23 -07:00
commit e2db60c9d9

View File

@ -896,10 +896,15 @@ class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered.
"""
def __init__(self):
super(TerminateOnNaN, self).__init__()
self._supports_tf_logs = True
def on_batch_end(self, batch, logs=None):
logs = logs or {}
loss = logs.get('loss')
if loss is not None:
loss = tf_utils.to_numpy_or_python_type(loss)
if np.isnan(loss) or np.isinf(loss):
print('Batch %d: Invalid loss, terminating training' % (batch))
self.model.stop_training = True