Improve callback timing check by restricting it to custom callbacks and using the mean of 5 batches to get more accurate results.

PiperOrigin-RevId: 326342300
Change-Id: I192001a64da11c5fa66de46a9a4e01b2b090a184
This commit is contained in:
Francois Chollet 2020-08-12 16:51:20 -07:00 committed by TensorFlower Gardener
parent f08bbe942b
commit 2be895ce46
3 changed files with 57 additions and 22 deletions
tensorflow/python/keras

View File

@ -522,7 +522,7 @@ tf_py_test(
size = "medium",
srcs = ["callbacks_test.py"],
python_version = "PY3",
shard_count = 4,
shard_count = 6,
tags = [
"no_oss",
"notsan",

View File

@ -241,9 +241,12 @@ class CallbackList(object):
# pylint: enable=protected-access
# Performance check: Check batch hooks for slowness compared to batch time.
self._timing = {}
self._check_timing = False
# Only run check for custom callbacks (i.e. not present in this file).
self._check_timing = self.__class__ not in globals()
self._num_batches_for_timing_check = 5
self._hook_times = {}
self._batch_start_time = None
self._batch_times = []
def _add_default_callbacks(self, add_history, add_progbar):
"""Adds `Callback`s that are always present."""
@ -294,7 +297,6 @@ class CallbackList(object):
def _call_batch_begin_hook(self, mode, batch, logs):
"""Helper function for `on_*_batch_begin` methods."""
hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
self._check_timing = batch == 1 and hook_name not in self._timing
self._call_batch_hook_helper(hook_name, batch, logs)
if self._check_timing:
@ -304,31 +306,39 @@ class CallbackList(object):
"""Helper function for `on_*_batch_end` methods."""
hook_name = 'on_{mode}_batch_end'.format(mode=mode)
if self._check_timing:
if self._check_timing and batch >= 1:
batch_time = time.time() - self._batch_start_time
self._batch_times.append(batch_time)
self._call_batch_hook_helper(hook_name, batch, logs)
if self._check_timing:
if len(self._batch_times) >= self._num_batches_for_timing_check:
end_hook_name = hook_name
begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
avg_batch_time = sum(self._batch_times) / len(self._batch_times)
avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
self._hook_times[end_hook_name])
avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
self._hook_times[begin_hook_name])
threshold_time = 1.5 * batch_time
warning_msg = ('Callbacks method `{hook}` is slow compared to '
threshold_time = 1.5 * avg_batch_time
warning_msg = ('Callback method `{hook}` is slow compared to '
'the batch time (batch time: {batch_time:.4f}s vs '
'`{hook}` time: {cbk_time:.4f}s). Check your callbacks.')
if self._timing[begin_hook_name] > threshold_time:
'`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
if avg_begin_hook_time > threshold_time:
logging.warning(warning_msg.format(
hook=begin_hook_name,
batch_time=batch_time,
cbk_time=self._timing[begin_hook_name]))
if self._timing[end_hook_name] > threshold_time:
batch_time=avg_batch_time,
hook_time=avg_begin_hook_time))
if avg_end_hook_time > threshold_time:
logging.warning(warning_msg.format(
hook=end_hook_name,
batch_time=batch_time,
cbk_time=self._timing[end_hook_name]))
batch_time=avg_batch_time,
hook_time=avg_end_hook_time))
self._check_timing = False
self._batch_start_time = None
self._batch_times = []
self._hook_times = {}
def _call_batch_hook_helper(self, hook_name, batch, logs):
"""Helper function for `on_*_batch_*` methods."""
@ -347,7 +357,9 @@ class CallbackList(object):
hook(batch, numpy_logs)
if self._check_timing:
self._timing[hook_name] = time.time() - start_time
if hook_name not in self._hook_times:
self._hook_times[hook_name] = []
self._hook_times[hook_name].append(time.time() - start_time)
def _call_begin_hook(self, mode):
"""Helper function for on_{train|test|predict}_begin methods."""

View File

@ -285,10 +285,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
time.sleep(1)
model = sequential.Sequential()
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.add(keras.layers.Dense(1))
model.compile(
'sgd',
loss='binary_crossentropy',
loss='mse',
run_eagerly=testing_utils.should_run_eagerly())
warning_messages = []
@ -298,15 +298,38 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
with test.mock.patch.object(logging, 'warning', warning):
model.fit(
np.ones((10, 10), 'float32'),
np.ones((10, 1), 'float32'),
batch_size=5,
np.ones((20, 1), 'float32'),
np.ones((20, 1), 'float32'),
batch_size=3,
epochs=10,
callbacks=[SleepCallback()])
warning_msg = ('Callbacks method `on_train_batch_end` is slow compared '
warning_msg = ('Callback method `on_train_batch_end` is slow compared '
'to the batch time')
self.assertIn(warning_msg, '\n'.join(warning_messages))
@keras_parameterized.run_all_keras_modes
def test__default_callbacks_no_warning(self):
  """Checks that `fit` without user-supplied callbacks logs no warning."""
  # Minimal regression model: one Dense(1) layer, SGD optimizer, MSE loss.
  model = sequential.Sequential()
  model.add(keras.layers.Dense(1))
  model.compile(
      'sgd',
      loss='mse',
      run_eagerly=testing_utils.should_run_eagerly())

  # Every message passed to `logging.warning` while patched is collected here.
  captured_warnings = []

  def record_warning(msg):
    captured_warnings.append(msg)

  features = np.ones((20, 1), 'float32')
  targets = np.ones((20, 1), 'float32')
  with test.mock.patch.object(logging, 'warning', record_warning):
    # No `callbacks=` argument is passed to `fit`.
    model.fit(features, targets, batch_size=3, epochs=10)

  self.assertListEqual(captured_warnings, [])
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
@keras_parameterized.run_all_keras_modes
def test_progbar_logging_deferred_model_build(self):