Fix issue with callback timing check

PiperOrigin-RevId: 326483013
Change-Id: I9b9a20d8881ecfb44be5641690a753a8cd18c821
This commit is contained in:
Francois Chollet 2020-08-13 11:06:16 -07:00 committed by TensorFlower Gardener
parent 3d87c9d297
commit c979f5a424
2 changed files with 11 additions and 10 deletions

View File

@ -242,7 +242,8 @@ class CallbackList(object):
# Performance check: Check batch hooks for slowness compared to batch time. # Performance check: Check batch hooks for slowness compared to batch time.
# Only run check for custom callbacks (i.e. not present in this file). # Only run check for custom callbacks (i.e. not present in this file).
self._check_timing = self.__class__ not in globals() self._check_timing = any([cbk.__class__.__name__ not in globals()
for cbk in self.callbacks])
self._num_batches_for_timing_check = 5 self._num_batches_for_timing_check = 5
self._hook_times = {} self._hook_times = {}
self._batch_start_time = None self._batch_start_time = None
@ -321,7 +322,7 @@ class CallbackList(object):
avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len( avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
self._hook_times[begin_hook_name]) self._hook_times[begin_hook_name])
threshold_time = 1.5 * avg_batch_time threshold_time = 1.0 * avg_batch_time
warning_msg = ('Callback method `{hook}` is slow compared to ' warning_msg = ('Callback method `{hook}` is slow compared to '
'the batch time (batch time: {batch_time:.4f}s vs ' 'the batch time (batch time: {batch_time:.4f}s vs '
'`{hook}` time: {hook_time:.4f}s). Check your callbacks.') '`{hook}` time: {hook_time:.4f}s). Check your callbacks.')

View File

@ -282,7 +282,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
class SleepCallback(keras.callbacks.Callback): class SleepCallback(keras.callbacks.Callback):
def on_train_batch_end(self, batch, logs=None): def on_train_batch_end(self, batch, logs=None):
time.sleep(1) time.sleep(0.1)
model = sequential.Sequential() model = sequential.Sequential()
model.add(keras.layers.Dense(1)) model.add(keras.layers.Dense(1))
@ -298,17 +298,17 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
with test.mock.patch.object(logging, 'warning', warning): with test.mock.patch.object(logging, 'warning', warning):
model.fit( model.fit(
np.ones((20, 1), 'float32'), np.ones((16, 1), 'float32'),
np.ones((20, 1), 'float32'), np.ones((16, 1), 'float32'),
batch_size=3, batch_size=3,
epochs=10, epochs=1,
callbacks=[SleepCallback()]) callbacks=[SleepCallback()])
warning_msg = ('Callback method `on_train_batch_end` is slow compared ' warning_msg = ('Callback method `on_train_batch_end` is slow compared '
'to the batch time') 'to the batch time')
self.assertIn(warning_msg, '\n'.join(warning_messages)) self.assertIn(warning_msg, '\n'.join(warning_messages))
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
def test__default_callbacks_no_warning(self): def test_default_callbacks_no_warning(self):
# Test that without the callback no warning is raised # Test that without the callback no warning is raised
model = sequential.Sequential() model = sequential.Sequential()
model.add(keras.layers.Dense(1)) model.add(keras.layers.Dense(1))
@ -324,10 +324,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
with test.mock.patch.object(logging, 'warning', warning): with test.mock.patch.object(logging, 'warning', warning):
model.fit( model.fit(
np.ones((20, 1), 'float32'), np.ones((16, 1), 'float32'),
np.ones((20, 1), 'float32'), np.ones((16, 1), 'float32'),
batch_size=3, batch_size=3,
epochs=10) epochs=1)
self.assertListEqual(warning_messages, []) self.assertListEqual(warning_messages, [])
@keras_parameterized.run_with_all_model_types(exclude_models='functional') @keras_parameterized.run_with_all_model_types(exclude_models='functional')