Decrease frequency of logged warning for untraced functions while saving. Remove warning if attempting to restore an untraced function while loading.
PiperOrigin-RevId: 338519892 Change-Id: If328725d39560a2a2c37a87bb6d7e2c9ed06e979
This commit is contained in:
parent
8bfaaccad9
commit
4e281e9a21
@ -216,10 +216,6 @@ def recreate_function(saved_function, concrete_functions):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Function`.
|
A `Function`.
|
||||||
"""
|
"""
|
||||||
if not saved_function.concrete_functions:
|
|
||||||
logging.warning("Could not find any concrete functions to restore for this "
|
|
||||||
"SavedFunction object while loading. The function will not "
|
|
||||||
"be callable.")
|
|
||||||
# TODO(andresp): Construct a `Function` with the cache populated
|
# TODO(andresp): Construct a `Function` with the cache populated
|
||||||
# instead of creating a new `Function` backed by a Python layer to
|
# instead of creating a new `Function` backed by a Python layer to
|
||||||
# glue things together. Current approach is nesting functions deeper for each
|
# glue things together. Current approach is nesting functions deeper for each
|
||||||
|
@ -2070,14 +2070,10 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
|||||||
loaded = cycle(root, 1)
|
loaded = cycle(root, 1)
|
||||||
|
|
||||||
expected_save_message = (
|
expected_save_message = (
|
||||||
"WARNING:absl:No concrete functions found for untraced function `foo` "
|
"WARNING:absl:Found untraced functions such as foo while saving "
|
||||||
"while saving. This function will not be callable after loading.")
|
"(showing 1 of 1). These functions will not be directly callable after "
|
||||||
expected_load_message = (
|
"loading.")
|
||||||
"WARNING:absl:Could not find any concrete functions to restore for "
|
|
||||||
"this SavedFunction object while loading. The function will not be "
|
|
||||||
"callable.")
|
|
||||||
self.assertIn(expected_save_message, logs.output)
|
self.assertIn(expected_save_message, logs.output)
|
||||||
self.assertIn(expected_load_message, logs.output)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "Found zero restored functions for caller function."):
|
ValueError, "Found zero restored functions for caller function."):
|
||||||
|
@ -74,6 +74,9 @@ _UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
|
|||||||
_CapturedConstant = collections.namedtuple("_CapturedConstant",
|
_CapturedConstant = collections.namedtuple("_CapturedConstant",
|
||||||
["eager_tensor", "graph_tensor"])
|
["eager_tensor", "graph_tensor"])
|
||||||
|
|
||||||
|
# Number of untraced functions to display to user in warning message.
|
||||||
|
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
|
||||||
|
|
||||||
|
|
||||||
class _AugmentedGraphView(graph_view.ObjectGraphView):
|
class _AugmentedGraphView(graph_view.ObjectGraphView):
|
||||||
"""An extendable graph which also tracks functions attached to objects.
|
"""An extendable graph which also tracks functions attached to objects.
|
||||||
@ -187,6 +190,7 @@ class _SaveableView(object):
|
|||||||
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
|
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
|
||||||
self.slot_variables = slot_variables
|
self.slot_variables = slot_variables
|
||||||
self.concrete_functions = []
|
self.concrete_functions = []
|
||||||
|
self.untraced_functions = []
|
||||||
|
|
||||||
self.saveable_objects_for_node, all_saveable_functions = (
|
self.saveable_objects_for_node, all_saveable_functions = (
|
||||||
self._add_saveable_objects())
|
self._add_saveable_objects())
|
||||||
@ -223,15 +227,19 @@ class _SaveableView(object):
|
|||||||
else:
|
else:
|
||||||
concrete_functions = [function]
|
concrete_functions = [function]
|
||||||
if not concrete_functions:
|
if not concrete_functions:
|
||||||
logging.warning(
|
self.untraced_functions.append(function._name)
|
||||||
"No concrete functions found for untraced function `%s` while "
|
|
||||||
"saving. This function will not be callable after loading.",
|
|
||||||
function._name)
|
|
||||||
|
|
||||||
for concrete_function in concrete_functions:
|
for concrete_function in concrete_functions:
|
||||||
if concrete_function.name not in seen_function_names:
|
if concrete_function.name not in seen_function_names:
|
||||||
seen_function_names.add(concrete_function.name)
|
seen_function_names.add(concrete_function.name)
|
||||||
self.concrete_functions.append(concrete_function)
|
self.concrete_functions.append(concrete_function)
|
||||||
|
if self.untraced_functions:
|
||||||
|
logging.warning(
|
||||||
|
"Found untraced functions such as %s while saving (showing %d of %d)."
|
||||||
|
" These functions will not be directly callable after loading.",
|
||||||
|
", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
|
||||||
|
min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)),
|
||||||
|
len(self.untraced_functions))
|
||||||
|
|
||||||
def _add_saveable_objects(self):
|
def _add_saveable_objects(self):
|
||||||
"""Retrieves SaveablesObjects and traces their save/restore functions."""
|
"""Retrieves SaveablesObjects and traces their save/restore functions."""
|
||||||
|
@ -308,8 +308,9 @@ class SaveTest(test.TestCase, parameterized.TestCase):
|
|||||||
save.save(root, save_dir)
|
save.save(root, save_dir)
|
||||||
|
|
||||||
expected_message = (
|
expected_message = (
|
||||||
"WARNING:absl:No concrete functions found for untraced function `foo` "
|
"WARNING:absl:Found untraced functions such as foo while saving "
|
||||||
"while saving. This function will not be callable after loading.")
|
"(showing 1 of 1). These functions will not be directly callable after "
|
||||||
|
"loading.")
|
||||||
self.assertIn(expected_message, logs.output)
|
self.assertIn(expected_message, logs.output)
|
||||||
|
|
||||||
def test_find_default_save_function(self):
|
def test_find_default_save_function(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user