Merge pull request #44667 from monicadsong/cherrypicks_NGZKF
2.4.0rc1 cherry-pick request: Reduce frequency of logging for untraced functions while loading and saving from SavedModel
This commit is contained in:
commit
ef82f4c66c
@ -216,10 +216,6 @@ def recreate_function(saved_function, concrete_functions):
|
||||
Returns:
|
||||
A `Function`.
|
||||
"""
|
||||
if not saved_function.concrete_functions:
|
||||
logging.warning("Could not find any concrete functions to restore for this "
|
||||
"SavedFunction object while loading. The function will not "
|
||||
"be callable.")
|
||||
# TODO(andresp): Construct a `Function` with the cache populated
|
||||
# instead of creating a new `Function` backed by a Python layer to
|
||||
# glue things together. Current approach is nesting functions deeper for each
|
||||
|
@ -2070,14 +2070,10 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
loaded = cycle(root, 1)
|
||||
|
||||
expected_save_message = (
|
||||
"WARNING:absl:No concrete functions found for untraced function `foo` "
|
||||
"while saving. This function will not be callable after loading.")
|
||||
expected_load_message = (
|
||||
"WARNING:absl:Could not find any concrete functions to restore for "
|
||||
"this SavedFunction object while loading. The function will not be "
|
||||
"callable.")
|
||||
"WARNING:absl:Found untraced functions such as foo while saving "
|
||||
"(showing 1 of 1). These functions will not be directly callable after "
|
||||
"loading.")
|
||||
self.assertIn(expected_save_message, logs.output)
|
||||
self.assertIn(expected_load_message, logs.output)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Found zero restored functions for caller function."):
|
||||
|
@ -74,6 +74,9 @@ _UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
|
||||
_CapturedConstant = collections.namedtuple("_CapturedConstant",
|
||||
["eager_tensor", "graph_tensor"])
|
||||
|
||||
# Number of untraced functions to display to user in warning message.
|
||||
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
|
||||
|
||||
|
||||
class _AugmentedGraphView(graph_view.ObjectGraphView):
|
||||
"""An extendable graph which also tracks functions attached to objects.
|
||||
@ -187,6 +190,7 @@ class _SaveableView(object):
|
||||
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
|
||||
self.slot_variables = slot_variables
|
||||
self.concrete_functions = []
|
||||
self.untraced_functions = []
|
||||
|
||||
self.saveable_objects_for_node, all_saveable_functions = (
|
||||
self._add_saveable_objects())
|
||||
@ -223,15 +227,19 @@ class _SaveableView(object):
|
||||
else:
|
||||
concrete_functions = [function]
|
||||
if not concrete_functions:
|
||||
logging.warning(
|
||||
"No concrete functions found for untraced function `%s` while "
|
||||
"saving. This function will not be callable after loading.",
|
||||
function._name)
|
||||
self.untraced_functions.append(function._name)
|
||||
|
||||
for concrete_function in concrete_functions:
|
||||
if concrete_function.name not in seen_function_names:
|
||||
seen_function_names.add(concrete_function.name)
|
||||
self.concrete_functions.append(concrete_function)
|
||||
if self.untraced_functions:
|
||||
logging.warning(
|
||||
"Found untraced functions such as %s while saving (showing %d of %d)."
|
||||
" These functions will not be directly callable after loading.",
|
||||
", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
|
||||
min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)),
|
||||
len(self.untraced_functions))
|
||||
|
||||
def _add_saveable_objects(self):
|
||||
"""Retrieves SaveablesObjects and traces their save/restore functions."""
|
||||
|
@ -308,8 +308,9 @@ class SaveTest(test.TestCase, parameterized.TestCase):
|
||||
save.save(root, save_dir)
|
||||
|
||||
expected_message = (
|
||||
"WARNING:absl:No concrete functions found for untraced function `foo` "
|
||||
"while saving. This function will not be callable after loading.")
|
||||
"WARNING:absl:Found untraced functions such as foo while saving "
|
||||
"(showing 1 of 1). These functions will not be directly callable after "
|
||||
"loading.")
|
||||
self.assertIn(expected_message, logs.output)
|
||||
|
||||
def test_find_default_save_function(self):
|
||||
|
Loading…
Reference in New Issue
Block a user