Don't run restored functions eagerly.

We can't *really* do it anyway, and it's breaking calls with optional arguments when the user does not specify the argument explicitly.

PiperOrigin-RevId: 355651854
Change-Id: Ia8bad470f9ef6f0290855350fef10b3fb879007d
This commit is contained in:
Daniel Ellis 2021-02-04 10:04:54 -08:00 committed by TensorFlower Gardener
parent ad1df4bc4f
commit a02ad9725c
3 changed files with 40 additions and 1 deletions

View File

@ -807,9 +807,13 @@ class Function(object):
result += self._stateful_fn.tracing_count if self._stateful_fn else 0
return result
@property
def _run_functions_eagerly(self):
return RUN_FUNCTIONS_EAGERLY
def __call__(self, *args, **kwds):
"""Calls the graph function and warn too frequent tracings."""
if RUN_FUNCTIONS_EAGERLY:
if self._run_functions_eagerly:
with trace.Trace(self._name, tf_function_call="eager"):
return self._python_function(*args, **kwds)

View File

@ -199,6 +199,20 @@ class RestoredFunction(def_function.Function):
# warnings.
self._omit_frequent_tracing_warning = True
@property
def _run_functions_eagerly(self):
# We do not have access to the original python function, and thus, we
# cannot meaningfully do anything but call our concrete function graphs
# under the hood.
#
# Attempting to call our bespoke python function (i.e.
# `restored_function_body`) will work so long as the user passes in all
# required and optional arguments. If an optional argument is missing,
# however, the call will break. For this reason, we instead skip the
# eager call path altogether if a user has enabled eager function execution
# via `tf.config.run_functions_eagerly`.
return False
def _list_all_concrete_functions_for_serialization(self):
return self.concrete_functions

View File

@ -2178,5 +2178,26 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
ValueError, "Found zero restored functions for caller function."):
loaded.foo(1)
def test_restored_function_execute_eagerly(self):
try:
def_function.run_functions_eagerly(True)
class MyModel(module.Module):
@def_function.function
def __call__(self, inputs, training=False):
return math_ops.multiply(0.5, inputs)
model = MyModel()
model.__call__.get_concrete_function(
tensor_spec.TensorSpec([None], dtypes.float32))
loaded = cycle(model, 1)
# Calling the function should not throw an exception.
loaded(constant_op.constant([1.0]))
finally:
def_function.run_functions_eagerly(False)
if __name__ == "__main__":
test.main()