From c7eba425b96cf5f71b3f6511c171cebc63a36ab9 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 21 Jul 2020 10:21:26 +0200 Subject: [PATCH] Support TF logs in TerminateOnNaN callback --- tensorflow/python/keras/callbacks.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 1b8c9b085ab..3569cbd37a5 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -884,10 +884,15 @@ class TerminateOnNaN(Callback): """Callback that terminates training when a NaN loss is encountered. """ + def __init__(self): + super(TerminateOnNaN, self).__init__() + self._supports_tf_logs = True + def on_batch_end(self, batch, logs=None): logs = logs or {} loss = logs.get('loss') if loss is not None: + loss = tf_utils.to_numpy_or_python_type(loss) if np.isnan(loss) or np.isinf(loss): print('Batch %d: Invalid loss, terminating training' % (batch)) self.model.stop_training = True