Support batch-level Model.stop_training check in Model.fit
PiperOrigin-RevId: 333371130 Change-Id: I88382860c565277e8b61ac5bd389502f1122cb9e
This commit is contained in:
parent
938cc7bf9c
commit
181aba715d
@ -1468,7 +1468,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
|||||||
epochs=20)
|
epochs=20)
|
||||||
loss = history.history['loss']
|
loss = history.history['loss']
|
||||||
self.assertEqual(len(loss), 1)
|
self.assertEqual(len(loss), 1)
|
||||||
self.assertTrue(np.isnan(loss[0]))
|
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
os.name == 'nt',
|
os.name == 'nt',
|
||||||
@ -1769,6 +1769,28 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'New function '):
|
with self.assertRaisesRegexp(ValueError, 'New function '):
|
||||||
model.predict(x, batch_size=2)
|
model.predict(x, batch_size=2)
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
|
def test_stop_training_batch_level(self):
|
||||||
|
|
||||||
|
class MyCallback(keras.callbacks.Callback):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(MyCallback, self).__init__()
|
||||||
|
self.batch_counter = 0
|
||||||
|
|
||||||
|
def on_train_batch_end(self, batch, logs=None):
|
||||||
|
self.batch_counter += 1
|
||||||
|
if batch == 2:
|
||||||
|
self.model.stop_training = True
|
||||||
|
|
||||||
|
model = keras.Sequential([keras.layers.Dense(1)])
|
||||||
|
model.compile('sgd', 'mse')
|
||||||
|
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||||
|
my_cb = MyCallback()
|
||||||
|
# Will run 5 batches if `stop_training` doesn't work.
|
||||||
|
model.fit(x, y, batch_size=2, callbacks=[my_cb])
|
||||||
|
self.assertEqual(my_cb.batch_counter, 3)
|
||||||
|
|
||||||
|
|
||||||
# A summary that was emitted during a test. Fields:
|
# A summary that was emitted during a test. Fields:
|
||||||
# logdir: str. The logdir of the FileWriter to which the summary was
|
# logdir: str. The logdir of the FileWriter to which the summary was
|
||||||
|
@ -1078,6 +1078,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||||||
logs = tmp_logs # No error, now safe to assign to logs.
|
logs = tmp_logs # No error, now safe to assign to logs.
|
||||||
end_step = step + data_handler.step_increment
|
end_step = step + data_handler.step_increment
|
||||||
callbacks.on_train_batch_end(end_step, logs)
|
callbacks.on_train_batch_end(end_step, logs)
|
||||||
|
if self.stop_training:
|
||||||
|
break
|
||||||
|
|
||||||
if logs is None:
|
if logs is None:
|
||||||
raise ValueError('Expect x to be a non-empty array or dataset.')
|
raise ValueError('Expect x to be a non-empty array or dataset.')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user