From 37f5570f8028fee1f017348e74e38c9083dea6f1 Mon Sep 17 00:00:00 2001
From: Monica Song <monicadsong@google.com>
Date: Thu, 22 Oct 2020 11:59:01 -0700
Subject: [PATCH] Decrease frequency of logged warning for untraced functions
 while saving. Remove warning if attempting to restore an untraced function
 while loading.

PiperOrigin-RevId: 338519892
Change-Id: If328725d39560a2a2c37a87bb6d7e2c9ed06e979
---
 .../saved_model/function_deserialization.py      |  4 ----
 tensorflow/python/saved_model/load_test.py       | 10 +++-------
 tensorflow/python/saved_model/save.py            | 16 ++++++++++++----
 tensorflow/python/saved_model/save_test.py       |  5 +++--
 4 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index ed7eabeb519..4d24f7dd009 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -216,10 +216,6 @@ def recreate_function(saved_function, concrete_functions):
   Returns:
     A `Function`.
   """
-  if not saved_function.concrete_functions:
-    logging.warning("Could not find any concrete functions to restore for this "
-                    "SavedFunction object while loading. The function will not "
-                    "be callable.")
   # TODO(andresp): Construct a `Function` with the cache populated
   # instead of creating a new `Function` backed by a Python layer to
   # glue things together. Current approach is nesting functions deeper for each
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index 5733b554ea9..5d3315b0d68 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -2070,14 +2070,10 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
       loaded = cycle(root, 1)
 
     expected_save_message = (
-        "WARNING:absl:No concrete functions found for untraced function `foo` "
-        "while saving. This function will not be callable after loading.")
-    expected_load_message = (
-        "WARNING:absl:Could not find any concrete functions to restore for "
-        "this SavedFunction object while loading. The function will not be "
-        "callable.")
+        "WARNING:absl:Found untraced functions such as foo while saving "
+        "(showing 1 of 1). These functions will not be directly callable after "
+        "loading.")
     self.assertIn(expected_save_message, logs.output)
-    self.assertIn(expected_load_message, logs.output)
 
     with self.assertRaisesRegex(
         ValueError, "Found zero restored functions for caller function."):
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 27a2867dcd4..87a65724ab9 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -74,6 +74,9 @@ _UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
 _CapturedConstant = collections.namedtuple("_CapturedConstant",
                                            ["eager_tensor", "graph_tensor"])
 
+# Number of untraced functions to display to user in warning message.
+_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
+
 
 class _AugmentedGraphView(graph_view.ObjectGraphView):
   """An extendable graph which also tracks functions attached to objects.
@@ -187,6 +190,7 @@ class _SaveableView(object):
     self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
     self.slot_variables = slot_variables
     self.concrete_functions = []
+    self.untraced_functions = []
 
     self.saveable_objects_for_node, all_saveable_functions = (
         self._add_saveable_objects())
@@ -223,15 +227,19 @@ class _SaveableView(object):
         else:
           concrete_functions = [function]
         if not concrete_functions:
-          logging.warning(
-              "No concrete functions found for untraced function `%s` while "
-              "saving. This function will not be callable after loading.",
-              function._name)
+          self.untraced_functions.append(function._name)
 
         for concrete_function in concrete_functions:
           if concrete_function.name not in seen_function_names:
             seen_function_names.add(concrete_function.name)
             self.concrete_functions.append(concrete_function)
+    if self.untraced_functions:
+      logging.warning(
+          "Found untraced functions such as %s while saving (showing %d of %d)."
+          " These functions will not be directly callable after loading.",
+          ", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
+          min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)),
+          len(self.untraced_functions))
 
   def _add_saveable_objects(self):
     """Retrieves SaveablesObjects and traces their save/restore functions."""
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index 3de9d70e9dc..77e19b8a8ec 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -308,8 +308,9 @@ class SaveTest(test.TestCase, parameterized.TestCase):
       save.save(root, save_dir)
 
     expected_message = (
-        "WARNING:absl:No concrete functions found for untraced function `foo` "
-        "while saving. This function will not be callable after loading.")
+        "WARNING:absl:Found untraced functions such as foo while saving "
+        "(showing 1 of 1). These functions will not be directly callable after "
+        "loading.")
     self.assertIn(expected_message, logs.output)
 
   def test_find_default_save_function(self):