Support batch-level Model.stop_training check in Model.fit

PiperOrigin-RevId: 333371130
Change-Id: I88382860c565277e8b61ac5bd389502f1122cb9e
Thomas O'Malley 2020-09-23 14:17:25 -07:00 committed by TensorFlower Gardener
parent 938cc7bf9c
commit 181aba715d
2 changed files with 25 additions and 1 deletion
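For context, a minimal usage sketch of the behavior this commit enables (not part of the change itself; the callback name and stopping rule below are illustrative): a callback can now end training from a batch-level hook instead of waiting for the epoch boundary.

import numpy as np
from tensorflow import keras


class HaltOnNonFiniteLoss(keras.callbacks.Callback):
  # Illustrative callback: abort training as soon as a batch reports a
  # NaN or Inf loss (similar in spirit to keras.callbacks.TerminateOnNaN).

  def on_train_batch_end(self, batch, logs=None):
    loss = (logs or {}).get('loss')
    if loss is not None and not np.isfinite(loss):
      # With this commit, Model.fit checks the flag after every batch,
      # so training stops at the next batch boundary.
      self.model.stop_training = True


model = keras.Sequential([keras.layers.Dense(1)])
model.compile('sgd', 'mse')
x, y = np.ones((100, 10)), np.ones((100, 1))
model.fit(x, y, batch_size=2, callbacks=[HaltOnNonFiniteLoss()])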

@@ -1468,7 +1468,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
         epochs=20)
     loss = history.history['loss']
     self.assertEqual(len(loss), 1)
-    self.assertTrue(np.isnan(loss[0]))
+    self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
 
   @unittest.skipIf(
       os.name == 'nt',
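The loosened assertion above appears to belong to the TerminateOnNaN test: once stop_training is honored at the batch level, fit halts at the first non-finite batch instead of finishing the epoch, so the single recorded epoch loss can be inf rather than nan.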
@@ -1769,6 +1769,28 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, 'New function '):
       model.predict(x, batch_size=2)
 
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_stop_training_batch_level(self):
+
+    class MyCallback(keras.callbacks.Callback):
+
+      def __init__(self):
+        super(MyCallback, self).__init__()
+        self.batch_counter = 0
+
+      def on_train_batch_end(self, batch, logs=None):
+        self.batch_counter += 1
+        if batch == 2:
+          self.model.stop_training = True
+
+    model = keras.Sequential([keras.layers.Dense(1)])
+    model.compile('sgd', 'mse')
+    x, y = np.ones((10, 10)), np.ones((10, 1))
+    my_cb = MyCallback()
+    # Will run 5 batches if `stop_training` doesn't work.
+    model.fit(x, y, batch_size=2, callbacks=[my_cb])
+    self.assertEqual(my_cb.batch_counter, 3)
+
 # A summary that was emitted during a test. Fields:
 # logdir: str. The logdir of the FileWriter to which the summary was
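For reference on the expected count in the new test: 10 samples with batch_size=2 give 5 steps per epoch, and the callback flips stop_training after batch index 2 (the third step), so on_train_batch_end fires three times and batch_counter ends at 3 rather than 5.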

@@ -1078,6 +1078,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
               logs = tmp_logs  # No error, now safe to assign to logs.
               end_step = step + data_handler.step_increment
               callbacks.on_train_batch_end(end_step, logs)
+              if self.stop_training:
+                break
 
         if logs is None:
           raise ValueError('Expect x to be a non-empty array or dataset.')
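The control flow the two added lines introduce can be mirrored in a hand-written loop; a small self-contained sketch (illustrative only, not the actual training.py code), using train_on_batch and a batch size of 2:

import numpy as np
from tensorflow import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile('sgd', 'mse')
x, y = np.ones((10, 10)), np.ones((10, 1))

stop_training = False
for epoch in range(3):
  for step, start in enumerate(range(0, len(x), 2)):  # 5 batches of size 2
    model.train_on_batch(x[start:start + 2], y[start:start + 2])
    if step == 2:           # a batch-level callback decides to stop here
      stop_training = True
    if stop_training:       # the check this commit adds: break out of the
      break                 # step loop at the batch boundary
  if stop_training:         # the pre-existing epoch-level check
    break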