From 37f5570f8028fee1f017348e74e38c9083dea6f1 Mon Sep 17 00:00:00 2001 From: Monica Song <monicadsong@google.com> Date: Thu, 22 Oct 2020 11:59:01 -0700 Subject: [PATCH] Decrease frequency of logged warning for untraced functions while saving. Remove warning if attempting to restore an untraced function while loading. PiperOrigin-RevId: 338519892 Change-Id: If328725d39560a2a2c37a87bb6d7e2c9ed06e979 --- .../saved_model/function_deserialization.py | 4 ---- tensorflow/python/saved_model/load_test.py | 10 +++------- tensorflow/python/saved_model/save.py | 16 ++++++++++++---- tensorflow/python/saved_model/save_test.py | 5 +++-- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index ed7eabeb519..4d24f7dd009 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -216,10 +216,6 @@ def recreate_function(saved_function, concrete_functions): Returns: A `Function`. """ - if not saved_function.concrete_functions: - logging.warning("Could not find any concrete functions to restore for this " - "SavedFunction object while loading. The function will not " - "be callable.") # TODO(andresp): Construct a `Function` with the cache populated # instead of creating a new `Function` backed by a Python layer to # glue things together. Current approach is nesting functions deeper for each diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 5733b554ea9..5d3315b0d68 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -2070,14 +2070,10 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase): loaded = cycle(root, 1) expected_save_message = ( - "WARNING:absl:No concrete functions found for untraced function `foo` " - "while saving. This function will not be callable after loading.") - expected_load_message = ( - "WARNING:absl:Could not find any concrete functions to restore for " - "this SavedFunction object while loading. The function will not be " - "callable.") + "WARNING:absl:Found untraced functions such as foo while saving " + "(showing 1 of 1). These functions will not be directly callable after " + "loading.") self.assertIn(expected_save_message, logs.output) - self.assertIn(expected_load_message, logs.output) with self.assertRaisesRegex( ValueError, "Found zero restored functions for caller function."): diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 27a2867dcd4..87a65724ab9 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -74,6 +74,9 @@ _UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant)) _CapturedConstant = collections.namedtuple("_CapturedConstant", ["eager_tensor", "graph_tensor"]) +# Number of untraced functions to display to user in warning message. +_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5 + class _AugmentedGraphView(graph_view.ObjectGraphView): """An extendable graph which also tracks functions attached to objects. @@ -187,6 +190,7 @@ class _SaveableView(object): self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() self.slot_variables = slot_variables self.concrete_functions = [] + self.untraced_functions = [] self.saveable_objects_for_node, all_saveable_functions = ( self._add_saveable_objects()) @@ -223,15 +227,19 @@ class _SaveableView(object): else: concrete_functions = [function] if not concrete_functions: - logging.warning( - "No concrete functions found for untraced function `%s` while " - "saving. This function will not be callable after loading.", - function._name) + self.untraced_functions.append(function._name) for concrete_function in concrete_functions: if concrete_function.name not in seen_function_names: seen_function_names.add(concrete_function.name) self.concrete_functions.append(concrete_function) + if self.untraced_functions: + logging.warning( + "Found untraced functions such as %s while saving (showing %d of %d)." + " These functions will not be directly callable after loading.", + ", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]), + min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)), + len(self.untraced_functions)) def _add_saveable_objects(self): """Retrieves SaveablesObjects and traces their save/restore functions.""" diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 3de9d70e9dc..77e19b8a8ec 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -308,8 +308,9 @@ class SaveTest(test.TestCase, parameterized.TestCase): save.save(root, save_dir) expected_message = ( - "WARNING:absl:No concrete functions found for untraced function `foo` " - "while saving. This function will not be callable after loading.") + "WARNING:absl:Found untraced functions such as foo while saving " + "(showing 1 of 1). These functions will not be directly callable after " + "loading.") self.assertIn(expected_message, logs.output) def test_find_default_save_function(self):