Merge pull request #41586 from lgeiger:nan-callback-tf-logs
PiperOrigin-RevId: 328676757 Change-Id: Id2d251ec5f864b9926807721dfeefcd515772797
This commit is contained in:
commit
e2db60c9d9
@ -896,10 +896,15 @@ class TerminateOnNaN(Callback):
|
||||
"""Callback that terminates training when a NaN loss is encountered.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(TerminateOnNaN, self).__init__()
|
||||
self._supports_tf_logs = True
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
logs = logs or {}
|
||||
loss = logs.get('loss')
|
||||
if loss is not None:
|
||||
loss = tf_utils.to_numpy_or_python_type(loss)
|
||||
if np.isnan(loss) or np.isinf(loss):
|
||||
print('Batch %d: Invalid loss, terminating training' % (batch))
|
||||
self.model.stop_training = True
|
||||
|
Loading…
Reference in New Issue
Block a user