Only check Callback batch timings and throw warning once per method.
PiperOrigin-RevId: 307911020 Change-Id: I5f63312261f0dfe425b2d561fd3740e24b073667
This commit is contained in:
parent
93d545e8b1
commit
f21a440cde
@ -224,12 +224,7 @@ class CallbackList(object):
|
||||
if params:
|
||||
self.set_params(params)
|
||||
|
||||
self._queue_length = 10
|
||||
self._reset_batch_timing()
|
||||
|
||||
# Determines if batch-level hooks need to be called.
|
||||
# This is important for performance, because processing batch-level logs
|
||||
# will cause async eager to block on each batch.
|
||||
# Performance optimization: determines if batch hooks need to be called.
|
||||
# pylint: disable=protected-access
|
||||
self._should_call_train_batch_hooks = any(
|
||||
cb._implements_train_batch_hooks() for cb in self.callbacks)
|
||||
@ -239,6 +234,11 @@ class CallbackList(object):
|
||||
cb._implements_predict_batch_hooks() for cb in self.callbacks)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Performance check: Check batch hooks for slowness compared to batch time.
|
||||
self._timing = {}
|
||||
self._check_timing = False
|
||||
self._batch_start_time = None
|
||||
|
||||
def _add_default_callbacks(self, add_history, add_progbar):
|
||||
"""Adds `Callback`s that are always present."""
|
||||
self._progbar = None
|
||||
@ -258,11 +258,6 @@ class CallbackList(object):
|
||||
self._history = History()
|
||||
self.callbacks.append(self._history)
|
||||
|
||||
def _reset_batch_timing(self):
|
||||
self._delta_t_batch = 0.
|
||||
self._delta_ts = collections.defaultdict(
|
||||
lambda: collections.deque([], maxlen=self._queue_length))
|
||||
|
||||
def append(self, callback):
|
||||
self.callbacks.append(callback)
|
||||
|
||||
@ -282,33 +277,65 @@ class CallbackList(object):
|
||||
"""Helper function for all batch_{begin | end} methods."""
|
||||
if not self.callbacks:
|
||||
return
|
||||
hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook)
|
||||
if hook == 'begin':
|
||||
self._t_enter_batch = time.time()
|
||||
if hook == 'end':
|
||||
# Batch is ending, calculate batch time.
|
||||
self._delta_t_batch = time.time() - self._t_enter_batch
|
||||
|
||||
if hook == 'begin':
|
||||
self._call_batch_begin_hook(mode, batch, logs)
|
||||
elif hook == 'end':
|
||||
self._call_batch_end_hook(mode, batch, logs)
|
||||
else:
|
||||
raise ValueError('Unrecognized hook: {}'.format(hook))
|
||||
|
||||
def _call_batch_begin_hook(self, mode, batch, logs):
|
||||
"""Helper function for `on_*_batch_begin` methods."""
|
||||
hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
|
||||
self._check_timing = batch == 1 and hook_name not in self._timing
|
||||
self._call_batch_hook_helper(hook_name, batch, logs)
|
||||
|
||||
if self._check_timing:
|
||||
self._batch_start_time = time.time()
|
||||
|
||||
def _call_batch_end_hook(self, mode, batch, logs):
|
||||
"""Helper function for `on_*_batch_end` methods."""
|
||||
hook_name = 'on_{mode}_batch_end'.format(mode=mode)
|
||||
|
||||
if self._check_timing:
|
||||
batch_time = time.time() - self._batch_start_time
|
||||
|
||||
self._call_batch_hook_helper(hook_name, batch, logs)
|
||||
|
||||
if self._check_timing:
|
||||
end_hook_name = hook_name
|
||||
begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
|
||||
|
||||
threshold_time = 0.5 * batch_time
|
||||
warning_msg = ('Callbacks method `{hook}` is slow compared to '
|
||||
'the batch time. Check your callbacks.')
|
||||
if self._timing[begin_hook_name] > threshold_time:
|
||||
logging.warning(warning_msg.format(hook=begin_hook_name))
|
||||
if self._timing[end_hook_name] > threshold_time:
|
||||
logging.warning(warning_msg.format(hook=end_hook_name))
|
||||
|
||||
self._check_timing = False
|
||||
self._batch_start_time = None
|
||||
|
||||
def _call_batch_hook_helper(self, hook_name, batch, logs):
|
||||
"""Helper function for `on_*_batch_*` methods."""
|
||||
logs = logs or {}
|
||||
t_before_callbacks = time.time()
|
||||
numpy_logs = None
|
||||
if self._check_timing:
|
||||
start_time = time.time()
|
||||
|
||||
for callback in self.callbacks:
|
||||
batch_hook = getattr(callback, hook_name)
|
||||
hook = getattr(callback, hook_name)
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
batch_hook(batch, logs)
|
||||
hook(batch, logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.to_numpy_or_python_type(logs)
|
||||
batch_hook(batch, numpy_logs)
|
||||
self._delta_ts[hook_name].append(time.time() - t_before_callbacks)
|
||||
hook(batch, numpy_logs)
|
||||
|
||||
delta_t_median = np.median(self._delta_ts[hook_name])
|
||||
if (self._delta_t_batch > 0. and
|
||||
delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
|
||||
logging.warning(
|
||||
'Method (%s) is slow compared '
|
||||
'to the batch update (%f). Check your callbacks.', hook_name,
|
||||
delta_t_median)
|
||||
if self._check_timing:
|
||||
self._timing[hook_name] = time.time() - start_time
|
||||
|
||||
def _call_begin_hook(self, mode):
|
||||
"""Helper function for on_{train|test|predict}_begin methods."""
|
||||
@ -355,7 +382,6 @@ class CallbackList(object):
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.to_numpy_or_python_type(logs)
|
||||
callback.on_epoch_begin(epoch, numpy_logs)
|
||||
self._reset_batch_timing()
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
"""Calls the `on_epoch_end` methods of its callbacks.
|
||||
|
@ -274,6 +274,37 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
model.fit(dataset, epochs=2, steps_per_epoch=10)
|
||||
self.assertRegexpMatches(printed.contents(), expected_log)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_callback_warning(self):
|
||||
|
||||
class SleepCallback(keras.callbacks.Callback):
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
time.sleep(1)
|
||||
|
||||
model = sequential.Sequential()
|
||||
model.add(keras.layers.Dense(1, activation='sigmoid'))
|
||||
model.compile(
|
||||
'sgd',
|
||||
loss='binary_crossentropy',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
|
||||
warning_messages = []
|
||||
|
||||
def warning(msg):
|
||||
warning_messages.append(msg)
|
||||
|
||||
with test.mock.patch.object(logging, 'warning', warning):
|
||||
model.fit(
|
||||
np.ones((10, 10), 'float32'),
|
||||
np.ones((10, 1), 'float32'),
|
||||
batch_size=5,
|
||||
epochs=10,
|
||||
callbacks=[SleepCallback()])
|
||||
warning_msg = ('Callbacks method `on_train_batch_end` is slow compared '
|
||||
'to the batch time. Check your callbacks.')
|
||||
self.assertIn(warning_msg, warning_messages)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_progbar_logging_deferred_model_build(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user