Support TF logs in TerminateOnNaN callback

This commit is contained in:
Lukas Geiger 2020-07-21 10:21:26 +02:00
parent 18445b0e39
commit c7eba425b9

View File

@ -884,10 +884,15 @@ class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered.
"""
def __init__(self):
super(TerminateOnNaN, self).__init__()
self._supports_tf_logs = True
def on_batch_end(self, batch, logs=None):
logs = logs or {}
loss = logs.get('loss')
if loss is not None:
loss = tf_utils.to_numpy_or_python_type(loss)
if np.isnan(loss) or np.isinf(loss):
print('Batch %d: Invalid loss, terminating training' % (batch))
self.model.stop_training = True