Fix issue with callback timing check
PiperOrigin-RevId: 326483013 Change-Id: I9b9a20d8881ecfb44be5641690a753a8cd18c821
This commit is contained in:
parent
3d87c9d297
commit
c979f5a424
@ -242,7 +242,8 @@ class CallbackList(object):
|
|||||||
|
|
||||||
# Performance check: Check batch hooks for slowness compared to batch time.
|
# Performance check: Check batch hooks for slowness compared to batch time.
|
||||||
# Only run check for custom callbacks (i.e. not present in this file).
|
# Only run check for custom callbacks (i.e. not present in this file).
|
||||||
self._check_timing = self.__class__ not in globals()
|
self._check_timing = any([cbk.__class__.__name__ not in globals()
|
||||||
|
for cbk in self.callbacks])
|
||||||
self._num_batches_for_timing_check = 5
|
self._num_batches_for_timing_check = 5
|
||||||
self._hook_times = {}
|
self._hook_times = {}
|
||||||
self._batch_start_time = None
|
self._batch_start_time = None
|
||||||
@ -321,7 +322,7 @@ class CallbackList(object):
|
|||||||
avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
|
avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
|
||||||
self._hook_times[begin_hook_name])
|
self._hook_times[begin_hook_name])
|
||||||
|
|
||||||
threshold_time = 1.5 * avg_batch_time
|
threshold_time = 1.0 * avg_batch_time
|
||||||
warning_msg = ('Callback method `{hook}` is slow compared to '
|
warning_msg = ('Callback method `{hook}` is slow compared to '
|
||||||
'the batch time (batch time: {batch_time:.4f}s vs '
|
'the batch time (batch time: {batch_time:.4f}s vs '
|
||||||
'`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
|
'`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
|
||||||
|
@ -282,7 +282,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
|||||||
class SleepCallback(keras.callbacks.Callback):
|
class SleepCallback(keras.callbacks.Callback):
|
||||||
|
|
||||||
def on_train_batch_end(self, batch, logs=None):
|
def on_train_batch_end(self, batch, logs=None):
|
||||||
time.sleep(1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
model = sequential.Sequential()
|
model = sequential.Sequential()
|
||||||
model.add(keras.layers.Dense(1))
|
model.add(keras.layers.Dense(1))
|
||||||
@ -298,17 +298,17 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
with test.mock.patch.object(logging, 'warning', warning):
|
with test.mock.patch.object(logging, 'warning', warning):
|
||||||
model.fit(
|
model.fit(
|
||||||
np.ones((20, 1), 'float32'),
|
np.ones((16, 1), 'float32'),
|
||||||
np.ones((20, 1), 'float32'),
|
np.ones((16, 1), 'float32'),
|
||||||
batch_size=3,
|
batch_size=3,
|
||||||
epochs=10,
|
epochs=1,
|
||||||
callbacks=[SleepCallback()])
|
callbacks=[SleepCallback()])
|
||||||
warning_msg = ('Callback method `on_train_batch_end` is slow compared '
|
warning_msg = ('Callback method `on_train_batch_end` is slow compared '
|
||||||
'to the batch time')
|
'to the batch time')
|
||||||
self.assertIn(warning_msg, '\n'.join(warning_messages))
|
self.assertIn(warning_msg, '\n'.join(warning_messages))
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
def test__default_callbacks_no_warning(self):
|
def test_default_callbacks_no_warning(self):
|
||||||
# Test that without the callback no warning is raised
|
# Test that without the callback no warning is raised
|
||||||
model = sequential.Sequential()
|
model = sequential.Sequential()
|
||||||
model.add(keras.layers.Dense(1))
|
model.add(keras.layers.Dense(1))
|
||||||
@ -324,10 +324,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
with test.mock.patch.object(logging, 'warning', warning):
|
with test.mock.patch.object(logging, 'warning', warning):
|
||||||
model.fit(
|
model.fit(
|
||||||
np.ones((20, 1), 'float32'),
|
np.ones((16, 1), 'float32'),
|
||||||
np.ones((20, 1), 'float32'),
|
np.ones((16, 1), 'float32'),
|
||||||
batch_size=3,
|
batch_size=3,
|
||||||
epochs=10)
|
epochs=1)
|
||||||
self.assertListEqual(warning_messages, [])
|
self.assertListEqual(warning_messages, [])
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
|
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
|
||||||
|
Loading…
Reference in New Issue
Block a user