Support batch-level Model.stop_training check in Model.fit
PiperOrigin-RevId: 333371130 Change-Id: I88382860c565277e8b61ac5bd389502f1122cb9e
This commit is contained in:
parent
938cc7bf9c
commit
181aba715d
@ -1468,7 +1468,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
epochs=20)
|
||||
loss = history.history['loss']
|
||||
self.assertEqual(len(loss), 1)
|
||||
self.assertTrue(np.isnan(loss[0]))
|
||||
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
|
||||
|
||||
@unittest.skipIf(
|
||||
os.name == 'nt',
|
||||
@ -1769,6 +1769,28 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'New function '):
|
||||
model.predict(x, batch_size=2)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_stop_training_batch_level(self):
|
||||
|
||||
class MyCallback(keras.callbacks.Callback):
|
||||
|
||||
def __init__(self):
|
||||
super(MyCallback, self).__init__()
|
||||
self.batch_counter = 0
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
self.batch_counter += 1
|
||||
if batch == 2:
|
||||
self.model.stop_training = True
|
||||
|
||||
model = keras.Sequential([keras.layers.Dense(1)])
|
||||
model.compile('sgd', 'mse')
|
||||
x, y = np.ones((10, 10)), np.ones((10, 1))
|
||||
my_cb = MyCallback()
|
||||
# Will run 5 batches if `stop_training` doesn't work.
|
||||
model.fit(x, y, batch_size=2, callbacks=[my_cb])
|
||||
self.assertEqual(my_cb.batch_counter, 3)
|
||||
|
||||
|
||||
# A summary that was emitted during a test. Fields:
|
||||
# logdir: str. The logdir of the FileWriter to which the summary was
|
||||
|
@ -1078,6 +1078,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
logs = tmp_logs # No error, now safe to assign to logs.
|
||||
end_step = step + data_handler.step_increment
|
||||
callbacks.on_train_batch_end(end_step, logs)
|
||||
if self.stop_training:
|
||||
break
|
||||
|
||||
if logs is None:
|
||||
raise ValueError('Expect x to be a non-empty array or dataset.')
|
||||
|
Loading…
Reference in New Issue
Block a user