Fix issue with callback timing check
PiperOrigin-RevId: 326483013 Change-Id: I9b9a20d8881ecfb44be5641690a753a8cd18c821
This commit is contained in:
parent
3d87c9d297
commit
c979f5a424
@ -242,7 +242,8 @@ class CallbackList(object):
|
||||
|
||||
# Performance check: Check batch hooks for slowness compared to batch time.
|
||||
# Only run check for custom callbacks (i.e. not present in this file).
|
||||
self._check_timing = self.__class__ not in globals()
|
||||
self._check_timing = any([cbk.__class__.__name__ not in globals()
|
||||
for cbk in self.callbacks])
|
||||
self._num_batches_for_timing_check = 5
|
||||
self._hook_times = {}
|
||||
self._batch_start_time = None
|
||||
@ -321,7 +322,7 @@ class CallbackList(object):
|
||||
avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
|
||||
self._hook_times[begin_hook_name])
|
||||
|
||||
threshold_time = 1.5 * avg_batch_time
|
||||
threshold_time = 1.0 * avg_batch_time
|
||||
warning_msg = ('Callback method `{hook}` is slow compared to '
|
||||
'the batch time (batch time: {batch_time:.4f}s vs '
|
||||
'`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
|
||||
|
@ -282,7 +282,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
class SleepCallback(keras.callbacks.Callback):
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
time.sleep(1)
|
||||
time.sleep(0.1)
|
||||
|
||||
model = sequential.Sequential()
|
||||
model.add(keras.layers.Dense(1))
|
||||
@ -298,17 +298,17 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
|
||||
with test.mock.patch.object(logging, 'warning', warning):
|
||||
model.fit(
|
||||
np.ones((20, 1), 'float32'),
|
||||
np.ones((20, 1), 'float32'),
|
||||
np.ones((16, 1), 'float32'),
|
||||
np.ones((16, 1), 'float32'),
|
||||
batch_size=3,
|
||||
epochs=10,
|
||||
epochs=1,
|
||||
callbacks=[SleepCallback()])
|
||||
warning_msg = ('Callback method `on_train_batch_end` is slow compared '
|
||||
'to the batch time')
|
||||
self.assertIn(warning_msg, '\n'.join(warning_messages))
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test__default_callbacks_no_warning(self):
|
||||
def test_default_callbacks_no_warning(self):
|
||||
# Test that without the callback no warning is raised
|
||||
model = sequential.Sequential()
|
||||
model.add(keras.layers.Dense(1))
|
||||
@ -324,10 +324,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
|
||||
with test.mock.patch.object(logging, 'warning', warning):
|
||||
model.fit(
|
||||
np.ones((20, 1), 'float32'),
|
||||
np.ones((20, 1), 'float32'),
|
||||
np.ones((16, 1), 'float32'),
|
||||
np.ones((16, 1), 'float32'),
|
||||
batch_size=3,
|
||||
epochs=10)
|
||||
epochs=1)
|
||||
self.assertListEqual(warning_messages, [])
|
||||
|
||||
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
|
||||
|
Loading…
Reference in New Issue
Block a user