From 2be895ce466fa3e3956e23452e191b909700d201 Mon Sep 17 00:00:00 2001
From: Francois Chollet <fchollet@google.com>
Date: Wed, 12 Aug 2020 16:51:20 -0700
Subject: [PATCH] Improve callback timing check by restricting it to custom
 callbacks and using the mean of 5 batches to get more accurate results.

PiperOrigin-RevId: 326342300
Change-Id: I192001a64da11c5fa66de46a9a4e01b2b090a184
---
 tensorflow/python/keras/BUILD             |  2 +-
 tensorflow/python/keras/callbacks.py      | 42 +++++++++++++++--------
 tensorflow/python/keras/callbacks_test.py | 35 +++++++++++++++----
 3 files changed, 57 insertions(+), 22 deletions(-)

diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 24c5b9de8ca..d8eff0f2260 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -522,7 +522,7 @@ tf_py_test(
     size = "medium",
     srcs = ["callbacks_test.py"],
     python_version = "PY3",
-    shard_count = 4,
+    shard_count = 6,
     tags = [
         "no_oss",
         "notsan",
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 5a191263241..ff3eef8b6e9 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -241,9 +241,12 @@ class CallbackList(object):
     # pylint: enable=protected-access
 
     # Performance check: Check batch hooks for slowness compared to batch time.
-    self._timing = {}
-    self._check_timing = False
+    # Only run check for custom callbacks (i.e. not present in this file).
+    self._check_timing = self.__class__ not in globals()
+    self._num_batches_for_timing_check = 5
+    self._hook_times = {}
     self._batch_start_time = None
+    self._batch_times = []
 
   def _add_default_callbacks(self, add_history, add_progbar):
     """Adds `Callback`s that are always present."""
@@ -294,7 +297,6 @@ class CallbackList(object):
   def _call_batch_begin_hook(self, mode, batch, logs):
     """Helper function for `on_*_batch_begin` methods."""
     hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
-    self._check_timing = batch == 1 and hook_name not in self._timing
     self._call_batch_hook_helper(hook_name, batch, logs)
 
     if self._check_timing:
@@ -304,31 +306,39 @@ class CallbackList(object):
     """Helper function for `on_*_batch_end` methods."""
     hook_name = 'on_{mode}_batch_end'.format(mode=mode)
 
-    if self._check_timing:
+    if self._check_timing and batch >= 1:
       batch_time = time.time() - self._batch_start_time
+      self._batch_times.append(batch_time)
 
     self._call_batch_hook_helper(hook_name, batch, logs)
 
-    if self._check_timing:
+    if len(self._batch_times) >= self._num_batches_for_timing_check:
       end_hook_name = hook_name
       begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
+      avg_batch_time = sum(self._batch_times) / len(self._batch_times)
+      avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
+          self._hook_times[end_hook_name])
+      avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
+          self._hook_times[begin_hook_name])
 
-      threshold_time = 1.5 * batch_time
-      warning_msg = ('Callbacks method `{hook}` is slow compared to '
+      threshold_time = 1.5 * avg_batch_time
+      warning_msg = ('Callback method `{hook}` is slow compared to '
                      'the batch time (batch time: {batch_time:.4f}s vs '
-                     '`{hook}` time: {cbk_time:.4f}s). Check your callbacks.')
-      if self._timing[begin_hook_name] > threshold_time:
+                     '`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
+      if avg_begin_hook_time > threshold_time:
         logging.warning(warning_msg.format(
             hook=begin_hook_name,
-            batch_time=batch_time,
-            cbk_time=self._timing[begin_hook_name]))
-      if self._timing[end_hook_name] > threshold_time:
+            batch_time=avg_batch_time,
+            hook_time=avg_begin_hook_time))
+      if avg_end_hook_time > threshold_time:
         logging.warning(warning_msg.format(
             hook=end_hook_name,
-            batch_time=batch_time,
-            cbk_time=self._timing[end_hook_name]))
+            batch_time=avg_batch_time,
+            hook_time=avg_end_hook_time))
       self._check_timing = False
       self._batch_start_time = None
+      self._batch_times = []
+      self._hook_times = {}
 
   def _call_batch_hook_helper(self, hook_name, batch, logs):
     """Helper function for `on_*_batch_*` methods."""
@@ -347,7 +357,9 @@ class CallbackList(object):
         hook(batch, numpy_logs)
 
     if self._check_timing:
-      self._timing[hook_name] = time.time() - start_time
+      if hook_name not in self._hook_times:
+        self._hook_times[hook_name] = []
+      self._hook_times[hook_name].append(time.time() - start_time)
 
   def _call_begin_hook(self, mode):
     """Helper function for on_{train|test|predict}_begin methods."""
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 1ac933135b9..828c78ebf15 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -285,10 +285,10 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
         time.sleep(1)
 
     model = sequential.Sequential()
-    model.add(keras.layers.Dense(1, activation='sigmoid'))
+    model.add(keras.layers.Dense(1))
     model.compile(
         'sgd',
-        loss='binary_crossentropy',
+        loss='mse',
         run_eagerly=testing_utils.should_run_eagerly())
 
     warning_messages = []
@@ -298,15 +298,38 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
 
     with test.mock.patch.object(logging, 'warning', warning):
       model.fit(
-          np.ones((10, 10), 'float32'),
-          np.ones((10, 1), 'float32'),
-          batch_size=5,
+          np.ones((20, 1), 'float32'),
+          np.ones((20, 1), 'float32'),
+          batch_size=3,
           epochs=10,
           callbacks=[SleepCallback()])
-    warning_msg = ('Callbacks method `on_train_batch_end` is slow compared '
+    warning_msg = ('Callback method `on_train_batch_end` is slow compared '
                    'to the batch time')
     self.assertIn(warning_msg, '\n'.join(warning_messages))
 
+  @keras_parameterized.run_all_keras_modes
+  def test__default_callbacks_no_warning(self):
+    # Test that without the callback no warning is raised
+    model = sequential.Sequential()
+    model.add(keras.layers.Dense(1))
+    model.compile(
+        'sgd',
+        loss='mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+
+    warning_messages = []
+
+    def warning(msg):
+      warning_messages.append(msg)
+
+    with test.mock.patch.object(logging, 'warning', warning):
+      model.fit(
+          np.ones((20, 1), 'float32'),
+          np.ones((20, 1), 'float32'),
+          batch_size=3,
+          epochs=10)
+    self.assertListEqual(warning_messages, [])
+
   @keras_parameterized.run_with_all_model_types(exclude_models='functional')
   @keras_parameterized.run_all_keras_modes
   def test_progbar_logging_deferred_model_build(self):