From 2be895ce466fa3e3956e23452e191b909700d201 Mon Sep 17 00:00:00 2001 From: Francois Chollet <fchollet@google.com> Date: Wed, 12 Aug 2020 16:51:20 -0700 Subject: [PATCH] Improve callback timing check by restricting it to custom callbacks and using the mean of 5 batches to get more accurate results. PiperOrigin-RevId: 326342300 Change-Id: I192001a64da11c5fa66de46a9a4e01b2b090a184 --- tensorflow/python/keras/BUILD | 2 +- tensorflow/python/keras/callbacks.py | 42 +++++++++++++++-------- tensorflow/python/keras/callbacks_test.py | 35 +++++++++++++++---- 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 24c5b9de8ca..d8eff0f2260 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -522,7 +522,7 @@ tf_py_test( size = "medium", srcs = ["callbacks_test.py"], python_version = "PY3", - shard_count = 4, + shard_count = 6, tags = [ "no_oss", "notsan", diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 5a191263241..ff3eef8b6e9 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -241,9 +241,12 @@ class CallbackList(object): # pylint: enable=protected-access # Performance check: Check batch hooks for slowness compared to batch time. - self._timing = {} - self._check_timing = False + # Only run check for custom callbacks (i.e. not present in this file). + self._check_timing = self.__class__ not in globals() + self._num_batches_for_timing_check = 5 + self._hook_times = {} self._batch_start_time = None + self._batch_times = [] def _add_default_callbacks(self, add_history, add_progbar): """Adds `Callback`s that are always present.""" @@ -294,7 +297,6 @@ class CallbackList(object): def _call_batch_begin_hook(self, mode, batch, logs): """Helper function for `on_*_batch_begin` methods.""" hook_name = 'on_{mode}_batch_begin'.format(mode=mode) - self._check_timing = batch == 1 and hook_name not in self._timing self._call_batch_hook_helper(hook_name, batch, logs) if self._check_timing: @@ -304,31 +306,39 @@ class CallbackList(object): """Helper function for `on_*_batch_end` methods.""" hook_name = 'on_{mode}_batch_end'.format(mode=mode) - if self._check_timing: + if self._check_timing and batch >= 1: batch_time = time.time() - self._batch_start_time + self._batch_times.append(batch_time) self._call_batch_hook_helper(hook_name, batch, logs) - if self._check_timing: + if len(self._batch_times) >= self._num_batches_for_timing_check: end_hook_name = hook_name begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode) + avg_batch_time = sum(self._batch_times) / len(self._batch_times) + avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len( + self._hook_times[end_hook_name]) + avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len( + self._hook_times[begin_hook_name]) - threshold_time = 1.5 * batch_time - warning_msg = ('Callbacks method `{hook}` is slow compared to ' + threshold_time = 1.5 * avg_batch_time + warning_msg = ('Callback method `{hook}` is slow compared to ' 'the batch time (batch time: {batch_time:.4f}s vs ' - '`{hook}` time: {cbk_time:.4f}s). Check your callbacks.') - if self._timing[begin_hook_name] > threshold_time: + '`{hook}` time: {hook_time:.4f}s). Check your callbacks.') + if avg_begin_hook_time > threshold_time: logging.warning(warning_msg.format( hook=begin_hook_name, - batch_time=batch_time, - cbk_time=self._timing[begin_hook_name])) - if self._timing[end_hook_name] > threshold_time: + batch_time=avg_batch_time, + hook_time=avg_begin_hook_time)) + if avg_end_hook_time > threshold_time: logging.warning(warning_msg.format( hook=end_hook_name, - batch_time=batch_time, - cbk_time=self._timing[end_hook_name])) + batch_time=avg_batch_time, + hook_time=avg_end_hook_time)) self._check_timing = False self._batch_start_time = None + self._batch_times = [] + self._hook_times = {} def _call_batch_hook_helper(self, hook_name, batch, logs): """Helper function for `on_*_batch_*` methods.""" @@ -347,7 +357,9 @@ class CallbackList(object): hook(batch, numpy_logs) if self._check_timing: - self._timing[hook_name] = time.time() - start_time + if hook_name not in self._hook_times: + self._hook_times[hook_name] = [] + self._hook_times[hook_name].append(time.time() - start_time) def _call_begin_hook(self, mode): """Helper function for on_{train|test|predict}_begin methods.""" diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 1ac933135b9..828c78ebf15 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -285,10 +285,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase): time.sleep(1) model = sequential.Sequential() - model.add(keras.layers.Dense(1, activation='sigmoid')) + model.add(keras.layers.Dense(1)) model.compile( 'sgd', - loss='binary_crossentropy', + loss='mse', run_eagerly=testing_utils.should_run_eagerly()) warning_messages = [] @@ -298,15 +298,38 @@ class KerasCallbacksTest(keras_parameterized.TestCase): with test.mock.patch.object(logging, 'warning', warning): model.fit( - np.ones((10, 10), 'float32'), - np.ones((10, 1), 'float32'), - batch_size=5, + np.ones((20, 1), 'float32'), + np.ones((20, 1), 'float32'), + batch_size=3, epochs=10, callbacks=[SleepCallback()]) - warning_msg = ('Callbacks method `on_train_batch_end` is slow compared ' + warning_msg = ('Callback method `on_train_batch_end` is slow compared ' 'to the batch time') self.assertIn(warning_msg, '\n'.join(warning_messages)) + @keras_parameterized.run_all_keras_modes + def test__default_callbacks_no_warning(self): + # Test that without the callback no warning is raised + model = sequential.Sequential() + model.add(keras.layers.Dense(1)) + model.compile( + 'sgd', + loss='mse', + run_eagerly=testing_utils.should_run_eagerly()) + + warning_messages = [] + + def warning(msg): + warning_messages.append(msg) + + with test.mock.patch.object(logging, 'warning', warning): + model.fit( + np.ones((20, 1), 'float32'), + np.ones((20, 1), 'float32'), + batch_size=3, + epochs=10) + self.assertListEqual(warning_messages, []) + @keras_parameterized.run_with_all_model_types(exclude_models='functional') @keras_parameterized.run_all_keras_modes def test_progbar_logging_deferred_model_build(self):