Improve callback timing check by restricting it to custom callbacks
and using the mean of 5 batches to get more accurate results. PiperOrigin-RevId: 326342300 Change-Id: I192001a64da11c5fa66de46a9a4e01b2b090a184
This commit is contained in:
parent
f08bbe942b
commit
2be895ce46
tensorflow/python/keras
@ -522,7 +522,7 @@ tf_py_test(
|
||||
size = "medium",
|
||||
srcs = ["callbacks_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
shard_count = 6,
|
||||
tags = [
|
||||
"no_oss",
|
||||
"notsan",
|
||||
|
@ -241,9 +241,12 @@ class CallbackList(object):
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Performance check: Check batch hooks for slowness compared to batch time.
|
||||
self._timing = {}
|
||||
self._check_timing = False
|
||||
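# Intended to run the check only for custom callbacks (i.e. those not defined in this file).
# NOTE(review): `globals()` is keyed by name strings, while `self.__class__` is a class object, so the membership test below is always True and the check is never disabled. Compare each callback's `__class__.__name__` against `globals()` instead.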
|
||||
self._check_timing = self.__class__ not in globals()
|
||||
self._num_batches_for_timing_check = 5
|
||||
self._hook_times = {}
|
||||
self._batch_start_time = None
|
||||
self._batch_times = []
|
||||
|
||||
def _add_default_callbacks(self, add_history, add_progbar):
|
||||
"""Adds `Callback`s that are always present."""
|
||||
@ -294,7 +297,6 @@ class CallbackList(object):
|
||||
def _call_batch_begin_hook(self, mode, batch, logs):
|
||||
"""Helper function for `on_*_batch_begin` methods."""
|
||||
hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
|
||||
self._check_timing = batch == 1 and hook_name not in self._timing
|
||||
self._call_batch_hook_helper(hook_name, batch, logs)
|
||||
|
||||
if self._check_timing:
|
||||
@ -304,31 +306,39 @@ class CallbackList(object):
|
||||
"""Helper function for `on_*_batch_end` methods."""
|
||||
hook_name = 'on_{mode}_batch_end'.format(mode=mode)
|
||||
|
||||
if self._check_timing:
|
||||
if self._check_timing and batch >= 1:
|
||||
batch_time = time.time() - self._batch_start_time
|
||||
self._batch_times.append(batch_time)
|
||||
|
||||
self._call_batch_hook_helper(hook_name, batch, logs)
|
||||
|
||||
if self._check_timing:
|
||||
if len(self._batch_times) >= self._num_batches_for_timing_check:
|
||||
end_hook_name = hook_name
|
||||
begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
|
||||
avg_batch_time = sum(self._batch_times) / len(self._batch_times)
|
||||
avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
|
||||
self._hook_times[end_hook_name])
|
||||
avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
|
||||
self._hook_times[begin_hook_name])
|
||||
|
||||
threshold_time = 1.5 * batch_time
|
||||
warning_msg = ('Callbacks method `{hook}` is slow compared to '
|
||||
threshold_time = 1.5 * avg_batch_time
|
||||
warning_msg = ('Callback method `{hook}` is slow compared to '
|
||||
'the batch time (batch time: {batch_time:.4f}s vs '
|
||||
'`{hook}` time: {cbk_time:.4f}s). Check your callbacks.')
|
||||
if self._timing[begin_hook_name] > threshold_time:
|
||||
'`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
|
||||
if avg_begin_hook_time > threshold_time:
|
||||
logging.warning(warning_msg.format(
|
||||
hook=begin_hook_name,
|
||||
batch_time=batch_time,
|
||||
cbk_time=self._timing[begin_hook_name]))
|
||||
if self._timing[end_hook_name] > threshold_time:
|
||||
batch_time=avg_batch_time,
|
||||
hook_time=avg_begin_hook_time))
|
||||
if avg_end_hook_time > threshold_time:
|
||||
logging.warning(warning_msg.format(
|
||||
hook=end_hook_name,
|
||||
batch_time=batch_time,
|
||||
cbk_time=self._timing[end_hook_name]))
|
||||
batch_time=avg_batch_time,
|
||||
hook_time=avg_end_hook_time))
|
||||
self._check_timing = False
|
||||
self._batch_start_time = None
|
||||
self._batch_times = []
|
||||
self._hook_times = {}
|
||||
|
||||
def _call_batch_hook_helper(self, hook_name, batch, logs):
|
||||
"""Helper function for `on_*_batch_*` methods."""
|
||||
@ -347,7 +357,9 @@ class CallbackList(object):
|
||||
hook(batch, numpy_logs)
|
||||
|
||||
if self._check_timing:
|
||||
self._timing[hook_name] = time.time() - start_time
|
||||
if hook_name not in self._hook_times:
|
||||
self._hook_times[hook_name] = []
|
||||
self._hook_times[hook_name].append(time.time() - start_time)
|
||||
|
||||
def _call_begin_hook(self, mode):
|
||||
"""Helper function for on_{train|test|predict}_begin methods."""
|
||||
|
@ -285,10 +285,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
time.sleep(1)
|
||||
|
||||
model = sequential.Sequential()
|
||||
model.add(keras.layers.Dense(1, activation='sigmoid'))
|
||||
model.add(keras.layers.Dense(1))
|
||||
model.compile(
|
||||
'sgd',
|
||||
loss='binary_crossentropy',
|
||||
loss='mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
|
||||
warning_messages = []
|
||||
@ -298,15 +298,38 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
|
||||
with test.mock.patch.object(logging, 'warning', warning):
|
||||
model.fit(
|
||||
np.ones((10, 10), 'float32'),
|
||||
np.ones((10, 1), 'float32'),
|
||||
batch_size=5,
|
||||
np.ones((20, 1), 'float32'),
|
||||
np.ones((20, 1), 'float32'),
|
||||
batch_size=3,
|
||||
epochs=10,
|
||||
callbacks=[SleepCallback()])
|
||||
warning_msg = ('Callbacks method `on_train_batch_end` is slow compared '
|
||||
warning_msg = ('Callback method `on_train_batch_end` is slow compared '
|
||||
'to the batch time')
|
||||
self.assertIn(warning_msg, '\n'.join(warning_messages))
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
def test__default_callbacks_no_warning(self):
  """Fits a model with no user callbacks and expects no logged warnings."""
  net = sequential.Sequential()
  net.add(keras.layers.Dense(1))
  net.compile(
      'sgd', loss='mse', run_eagerly=testing_utils.should_run_eagerly())

  # Every message sent to the patched `logging.warning` is collected here.
  warning_messages = []

  def record_warning(msg):
    warning_messages.append(msg)

  with test.mock.patch.object(logging, 'warning', record_warning):
    features = np.ones((20, 1), 'float32')
    targets = np.ones((20, 1), 'float32')
    net.fit(features, targets, batch_size=3, epochs=10)

  # No callbacks were passed to `fit`, so nothing should have been logged.
  self.assertListEqual(warning_messages, [])
|
||||
|
||||
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_progbar_logging_deferred_model_build(self):
|
||||
|
Loading…
Reference in New Issue
Block a user