From c979f5a4244fa390deb545b58156f0e60eb22975 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 13 Aug 2020 11:06:16 -0700 Subject: [PATCH] Fix issue with callback timing check PiperOrigin-RevId: 326483013 Change-Id: I9b9a20d8881ecfb44be5641690a753a8cd18c821 --- tensorflow/python/keras/callbacks.py | 5 +++-- tensorflow/python/keras/callbacks_test.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index ff3eef8b6e9..a7e3a404f4d 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -242,7 +242,8 @@ class CallbackList(object): # Performance check: Check batch hooks for slowness compared to batch time. # Only run check for custom callbacks (i.e. not present in this file). - self._check_timing = self.__class__ not in globals() + self._check_timing = any([cbk.__class__.__name__ not in globals() + for cbk in self.callbacks]) self._num_batches_for_timing_check = 5 self._hook_times = {} self._batch_start_time = None @@ -321,7 +322,7 @@ class CallbackList(object): avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len( self._hook_times[begin_hook_name]) - threshold_time = 1.5 * avg_batch_time + threshold_time = 1.0 * avg_batch_time warning_msg = ('Callback method `{hook}` is slow compared to ' 'the batch time (batch time: {batch_time:.4f}s vs ' '`{hook}` time: {hook_time:.4f}s). Check your callbacks.') diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 828c78ebf15..9fd8bf86609 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -282,7 +282,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase): class SleepCallback(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): - time.sleep(1) + time.sleep(0.1) model = sequential.Sequential() model.add(keras.layers.Dense(1)) @@ -298,17 +298,17 @@ class KerasCallbacksTest(keras_parameterized.TestCase): with test.mock.patch.object(logging, 'warning', warning): model.fit( - np.ones((20, 1), 'float32'), - np.ones((20, 1), 'float32'), + np.ones((16, 1), 'float32'), + np.ones((16, 1), 'float32'), batch_size=3, - epochs=10, + epochs=1, callbacks=[SleepCallback()]) warning_msg = ('Callback method `on_train_batch_end` is slow compared ' 'to the batch time') self.assertIn(warning_msg, '\n'.join(warning_messages)) @keras_parameterized.run_all_keras_modes - def test__default_callbacks_no_warning(self): + def test_default_callbacks_no_warning(self): # Test that without the callback no warning is raised model = sequential.Sequential() model.add(keras.layers.Dense(1)) @@ -324,10 +324,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase): with test.mock.patch.object(logging, 'warning', warning): model.fit( - np.ones((20, 1), 'float32'), - np.ones((20, 1), 'float32'), + np.ones((16, 1), 'float32'), + np.ones((16, 1), 'float32'), batch_size=3, - epochs=10) + epochs=1) self.assertListEqual(warning_messages, []) @keras_parameterized.run_with_all_model_types(exclude_models='functional')