Check Callback batch-hook timings, and log the slow-callback warning, only once per method.

PiperOrigin-RevId: 307911020
Change-Id: I5f63312261f0dfe425b2d561fd3740e24b073667
This commit is contained in:
Thomas O'Malley 2020-04-22 15:14:02 -07:00 committed by TensorFlower Gardener
parent 93d545e8b1
commit f21a440cde
2 changed files with 87 additions and 30 deletions

View File

@ -224,12 +224,7 @@ class CallbackList(object):
if params:
self.set_params(params)
self._queue_length = 10
self._reset_batch_timing()
# Determines if batch-level hooks need to be called.
# This is important for performance, because processing batch-level logs
# will cause async eager to block on each batch.
# Performance optimization: determines if batch hooks need to be called.
# pylint: disable=protected-access
self._should_call_train_batch_hooks = any(
cb._implements_train_batch_hooks() for cb in self.callbacks)
@ -239,6 +234,11 @@ class CallbackList(object):
cb._implements_predict_batch_hooks() for cb in self.callbacks)
# pylint: enable=protected-access
# Performance check: Check batch hooks for slowness compared to batch time.
self._timing = {}
self._check_timing = False
self._batch_start_time = None
def _add_default_callbacks(self, add_history, add_progbar):
"""Adds `Callback`s that are always present."""
self._progbar = None
@ -258,11 +258,6 @@ class CallbackList(object):
self._history = History()
self.callbacks.append(self._history)
def _reset_batch_timing(self):
  """Resets the batch duration and the per-hook duration windows.

  `_delta_t_batch` goes back to zero and `_delta_ts` becomes a fresh mapping
  whose missing keys are filled with an empty deque bounded by
  `self._queue_length`.
  """

  def _new_window():
    # `_queue_length` is read when a window is created, not when the mapping
    # is built.
    return collections.deque([], maxlen=self._queue_length)

  self._delta_ts = collections.defaultdict(_new_window)
  self._delta_t_batch = 0.
def append(self, callback):
"""Appends `callback` to the `callbacks` list held by this object."""
self.callbacks.append(callback)
@ -282,33 +277,65 @@ class CallbackList(object):
"""Helper function for all batch_{begin | end} methods."""
if not self.callbacks:
return
hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook)
if hook == 'begin':
self._t_enter_batch = time.time()
if hook == 'end':
# Batch is ending, calculate batch time.
self._delta_t_batch = time.time() - self._t_enter_batch
if hook == 'begin':
self._call_batch_begin_hook(mode, batch, logs)
elif hook == 'end':
self._call_batch_end_hook(mode, batch, logs)
else:
raise ValueError('Unrecognized hook: {}'.format(hook))
def _call_batch_begin_hook(self, mode, batch, logs):
  """Dispatches the `on_{mode}_batch_begin` hook to the callbacks.

  Timing is armed only when `batch` equals 1 and no duration has been stored
  yet under this hook's name in `self._timing`. When armed, the wall-clock
  time taken right after the callbacks have run is kept in
  `self._batch_start_time`.

  Args:
    mode: String substituted into the hook name.
    batch: Batch index forwarded to the callbacks.
    logs: Logs forwarded to the callbacks.
  """
  begin_hook = 'on_{mode}_batch_begin'.format(mode=mode)
  self._check_timing = batch == 1 and begin_hook not in self._timing
  self._call_batch_hook_helper(begin_hook, batch, logs)
  if not self._check_timing:
    return
  # The batch clock starts after the begin hooks so their cost is excluded.
  self._batch_start_time = time.time()
def _call_batch_end_hook(self, mode, batch, logs):
  """Dispatches the `on_{mode}_batch_end` hook to the callbacks.

  While timing is armed, the elapsed time since `self._batch_start_time` is
  measured before the callbacks run. Afterwards the durations stored in
  `self._timing` for the begin and end hooks are each compared against half
  of that elapsed time, and a warning is logged for every hook that exceeds
  it. Timing is then disarmed.

  Args:
    mode: String substituted into the hook names.
    batch: Batch index forwarded to the callbacks.
    logs: Logs forwarded to the callbacks.
  """
  end_hook = 'on_{mode}_batch_end'.format(mode=mode)

  if self._check_timing:
    # Taken before the end hooks run so their cost is not part of the batch.
    elapsed = time.time() - self._batch_start_time

  self._call_batch_hook_helper(end_hook, batch, logs)

  if not self._check_timing:
    return

  begin_hook = 'on_{mode}_batch_begin'.format(mode=mode)
  limit = 0.5 * elapsed
  template = ('Callbacks method `{hook}` is slow compared to '
              'the batch time. Check your callbacks.')
  for name in (begin_hook, end_hook):
    if self._timing[name] > limit:
      logging.warning(template.format(hook=name))

  # Disarm so the comparison is not repeated for this hook pair.
  self._check_timing = False
  self._batch_start_time = None
def _call_batch_hook_helper(self, hook_name, batch, logs):
"""Helper function for `on_*_batch_*` methods."""
# NOTE(review): this span appears to be a rendered diff with indentation and
# +/- markers stripped, so removed and added lines sit side by side (e.g.
# `batch_hook` next to `hook`, `_delta_ts` next to `_timing`). Confirm against
# the real file before relying on the statement order shown here.
logs = logs or {}
t_before_callbacks = time.time()
# Filled lazily below: the converted form of `logs` is built at most once.
numpy_logs = None
if self._check_timing:
start_time = time.time()
for callback in self.callbacks:
batch_hook = getattr(callback, hook_name)
hook = getattr(callback, hook_name)
# Callbacks flagged with `_supports_tf_logs` receive `logs` unconverted.
if getattr(callback, '_supports_tf_logs', False):
batch_hook(batch, logs)
hook(batch, logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.to_numpy_or_python_type(logs)
batch_hook(batch, numpy_logs)
self._delta_ts[hook_name].append(time.time() - t_before_callbacks)
hook(batch, numpy_logs)
delta_t_median = np.median(self._delta_ts[hook_name])
if (self._delta_t_batch > 0. and
delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
logging.warning(
'Method (%s) is slow compared '
'to the batch update (%f). Check your callbacks.', hook_name,
delta_t_median)
# When timing is armed, store how long this hook took, keyed by hook name.
if self._check_timing:
self._timing[hook_name] = time.time() - start_time
def _call_begin_hook(self, mode):
"""Helper function for on_{train|test|predict}_begin methods."""
@ -355,7 +382,6 @@ class CallbackList(object):
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.to_numpy_or_python_type(logs)
callback.on_epoch_begin(epoch, numpy_logs)
self._reset_batch_timing()
def on_epoch_end(self, epoch, logs=None):
"""Calls the `on_epoch_end` methods of its callbacks.

View File

@ -274,6 +274,37 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
model.fit(dataset, epochs=2, steps_per_epoch=10)
self.assertRegexpMatches(printed.contents(), expected_log)
@keras_parameterized.run_all_keras_modes
def test_callback_warning(self):
  """Checks that a slow `on_train_batch_end` hook is reported via a warning."""

  class SleepCallback(keras.callbacks.Callback):
    """Callback whose batch-end hook sleeps for one second."""

    def on_train_batch_end(self, batch, logs=None):
      time.sleep(1)

  captured = []

  def record_warning(msg):
    captured.append(msg)

  model = sequential.Sequential()
  model.add(keras.layers.Dense(1, activation='sigmoid'))
  model.compile(
      'sgd',
      loss='binary_crossentropy',
      run_eagerly=testing_utils.should_run_eagerly())

  features = np.ones((10, 10), 'float32')
  targets = np.ones((10, 1), 'float32')

  # Swap out `logging.warning` for the duration of `fit` to collect messages.
  with test.mock.patch.object(logging, 'warning', record_warning):
    model.fit(
        features,
        targets,
        batch_size=5,
        epochs=10,
        callbacks=[SleepCallback()])

  expected = ('Callbacks method `on_train_batch_end` is slow compared '
              'to the batch time. Check your callbacks.')
  self.assertIn(expected, captured)
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
@keras_parameterized.run_all_keras_modes
def test_progbar_logging_deferred_model_build(self):