Decrease frequency of logged warning for untraced functions while saving. Remove warning if attempting to restore an untraced function while loading.

PiperOrigin-RevId: 338519892
Change-Id: If328725d39560a2a2c37a87bb6d7e2c9ed06e979
This commit is contained in:
Monica Song 2020-10-22 11:59:01 -07:00 committed by TensorFlower Gardener
parent 8bfaaccad9
commit 4e281e9a21
4 changed files with 18 additions and 17 deletions

View File

@ -216,10 +216,6 @@ def recreate_function(saved_function, concrete_functions):
Returns: Returns:
A `Function`. A `Function`.
""" """
if not saved_function.concrete_functions:
logging.warning("Could not find any concrete functions to restore for this "
"SavedFunction object while loading. The function will not "
"be callable.")
# TODO(andresp): Construct a `Function` with the cache populated # TODO(andresp): Construct a `Function` with the cache populated
# instead of creating a new `Function` backed by a Python layer to # instead of creating a new `Function` backed by a Python layer to
# glue things together. Current approach is nesting functions deeper for each # glue things together. Current approach is nesting functions deeper for each

View File

@ -2070,14 +2070,10 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
loaded = cycle(root, 1) loaded = cycle(root, 1)
expected_save_message = ( expected_save_message = (
"WARNING:absl:No concrete functions found for untraced function `foo` " "WARNING:absl:Found untraced functions such as foo while saving "
"while saving. This function will not be callable after loading.") "(showing 1 of 1). These functions will not be directly callable after "
expected_load_message = ( "loading.")
"WARNING:absl:Could not find any concrete functions to restore for "
"this SavedFunction object while loading. The function will not be "
"callable.")
self.assertIn(expected_save_message, logs.output) self.assertIn(expected_save_message, logs.output)
self.assertIn(expected_load_message, logs.output)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "Found zero restored functions for caller function."): ValueError, "Found zero restored functions for caller function."):

View File

@ -74,6 +74,9 @@ _UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
_CapturedConstant = collections.namedtuple("_CapturedConstant", _CapturedConstant = collections.namedtuple("_CapturedConstant",
["eager_tensor", "graph_tensor"]) ["eager_tensor", "graph_tensor"])
# Number of untraced functions to display to user in warning message.
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
class _AugmentedGraphView(graph_view.ObjectGraphView): class _AugmentedGraphView(graph_view.ObjectGraphView):
"""An extendable graph which also tracks functions attached to objects. """An extendable graph which also tracks functions attached to objects.
@ -187,6 +190,7 @@ class _SaveableView(object):
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
self.slot_variables = slot_variables self.slot_variables = slot_variables
self.concrete_functions = [] self.concrete_functions = []
self.untraced_functions = []
self.saveable_objects_for_node, all_saveable_functions = ( self.saveable_objects_for_node, all_saveable_functions = (
self._add_saveable_objects()) self._add_saveable_objects())
@ -223,15 +227,19 @@ class _SaveableView(object):
else: else:
concrete_functions = [function] concrete_functions = [function]
if not concrete_functions: if not concrete_functions:
logging.warning( self.untraced_functions.append(function._name)
"No concrete functions found for untraced function `%s` while "
"saving. This function will not be callable after loading.",
function._name)
for concrete_function in concrete_functions: for concrete_function in concrete_functions:
if concrete_function.name not in seen_function_names: if concrete_function.name not in seen_function_names:
seen_function_names.add(concrete_function.name) seen_function_names.add(concrete_function.name)
self.concrete_functions.append(concrete_function) self.concrete_functions.append(concrete_function)
if self.untraced_functions:
logging.warning(
"Found untraced functions such as %s while saving (showing %d of %d)."
" These functions will not be directly callable after loading.",
", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)),
len(self.untraced_functions))
def _add_saveable_objects(self): def _add_saveable_objects(self):
"""Retrieves SaveablesObjects and traces their save/restore functions.""" """Retrieves SaveablesObjects and traces their save/restore functions."""

View File

@ -308,8 +308,9 @@ class SaveTest(test.TestCase, parameterized.TestCase):
save.save(root, save_dir) save.save(root, save_dir)
expected_message = ( expected_message = (
"WARNING:absl:No concrete functions found for untraced function `foo` " "WARNING:absl:Found untraced functions such as foo while saving "
"while saving. This function will not be callable after loading.") "(showing 1 of 1). These functions will not be directly callable after "
"loading.")
self.assertIn(expected_message, logs.output) self.assertIn(expected_message, logs.output)
def test_find_default_save_function(self): def test_find_default_save_function(self):