diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 839243a5fb7..a78dfa78cee 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -896,10 +896,15 @@ class TerminateOnNaN(Callback): """Callback that terminates training when a NaN loss is encountered. """ + def __init__(self): + super(TerminateOnNaN, self).__init__() + self._supports_tf_logs = True + def on_batch_end(self, batch, logs=None): logs = logs or {} loss = logs.get('loss') if loss is not None: + loss = tf_utils.to_numpy_or_python_type(loss) if np.isnan(loss) or np.isinf(loss): print('Batch %d: Invalid loss, terminating training' % (batch)) self.model.stop_training = True