Support TF logs in TerminateOnNaN callback
This commit is contained in:
parent
18445b0e39
commit
c7eba425b9
@ -884,10 +884,15 @@ class TerminateOnNaN(Callback):
|
|||||||
"""Callback that terminates training when a NaN loss is encountered.
|
"""Callback that terminates training when a NaN loss is encountered.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(TerminateOnNaN, self).__init__()
|
||||||
|
self._supports_tf_logs = True
|
||||||
|
|
||||||
def on_batch_end(self, batch, logs=None):
|
def on_batch_end(self, batch, logs=None):
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
loss = logs.get('loss')
|
loss = logs.get('loss')
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
|
loss = tf_utils.to_numpy_or_python_type(loss)
|
||||||
if np.isnan(loss) or np.isinf(loss):
|
if np.isnan(loss) or np.isinf(loss):
|
||||||
print('Batch %d: Invalid loss, terminating training' % (batch))
|
print('Batch %d: Invalid loss, terminating training' % (batch))
|
||||||
self.model.stop_training = True
|
self.model.stop_training = True
|
||||||
|
Loading…
Reference in New Issue
Block a user