Merge pull request #44667 from monicadsong/cherrypicks_NGZKF

2.4.0rc1 cherry-pick request: Reduce frequency of logging for untraced functions while loading and saving from SavedModel
This commit is contained in:
Goldie Gadde 2020-11-06 17:46:19 -08:00 committed by GitHub
commit ef82f4c66c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 17 deletions

View File

@ -216,10 +216,6 @@ def recreate_function(saved_function, concrete_functions):
Returns:
A `Function`.
"""
if not saved_function.concrete_functions:
logging.warning("Could not find any concrete functions to restore for this "
"SavedFunction object while loading. The function will not "
"be callable.")
# TODO(andresp): Construct a `Function` with the cache populated
# instead of creating a new `Function` backed by a Python layer to
# glue things together. Current approach is nesting functions deeper for each

View File

@ -2070,14 +2070,10 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
loaded = cycle(root, 1)
expected_save_message = (
"WARNING:absl:No concrete functions found for untraced function `foo` "
"while saving. This function will not be callable after loading.")
expected_load_message = (
"WARNING:absl:Could not find any concrete functions to restore for "
"this SavedFunction object while loading. The function will not be "
"callable.")
"WARNING:absl:Found untraced functions such as foo while saving "
"(showing 1 of 1). These functions will not be directly callable after "
"loading.")
self.assertIn(expected_save_message, logs.output)
self.assertIn(expected_load_message, logs.output)
with self.assertRaisesRegex(
ValueError, "Found zero restored functions for caller function."):

View File

@ -74,6 +74,9 @@ _UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
_CapturedConstant = collections.namedtuple("_CapturedConstant",
["eager_tensor", "graph_tensor"])
# Number of untraced functions to display to user in warning message.
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
class _AugmentedGraphView(graph_view.ObjectGraphView):
"""An extendable graph which also tracks functions attached to objects.
@ -187,6 +190,7 @@ class _SaveableView(object):
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
self.slot_variables = slot_variables
self.concrete_functions = []
self.untraced_functions = []
self.saveable_objects_for_node, all_saveable_functions = (
self._add_saveable_objects())
@ -223,15 +227,19 @@ class _SaveableView(object):
else:
concrete_functions = [function]
if not concrete_functions:
logging.warning(
"No concrete functions found for untraced function `%s` while "
"saving. This function will not be callable after loading.",
function._name)
self.untraced_functions.append(function._name)
for concrete_function in concrete_functions:
if concrete_function.name not in seen_function_names:
seen_function_names.add(concrete_function.name)
self.concrete_functions.append(concrete_function)
if self.untraced_functions:
logging.warning(
"Found untraced functions such as %s while saving (showing %d of %d)."
" These functions will not be directly callable after loading.",
", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)),
len(self.untraced_functions))
def _add_saveable_objects(self):
"""Retrieves SaveablesObjects and traces their save/restore functions."""

View File

@ -308,8 +308,9 @@ class SaveTest(test.TestCase, parameterized.TestCase):
save.save(root, save_dir)
expected_message = (
"WARNING:absl:No concrete functions found for untraced function `foo` "
"while saving. This function will not be callable after loading.")
"WARNING:absl:Found untraced functions such as foo while saving "
"(showing 1 of 1). These functions will not be directly callable after "
"loading.")
self.assertIn(expected_message, logs.output)
def test_find_default_save_function(self):