Don't run restored functions eagerly.
We can't *really* do it anyway, and it's breaking calls with optional arguments when the user does not specify the argument explicitly. PiperOrigin-RevId: 355651854 Change-Id: Ia8bad470f9ef6f0290855350fef10b3fb879007d
This commit is contained in:
parent
ad1df4bc4f
commit
a02ad9725c
tensorflow/python
@ -807,9 +807,13 @@ class Function(object):
|
||||
result += self._stateful_fn.tracing_count if self._stateful_fn else 0
|
||||
return result
|
||||
|
||||
@property
|
||||
def _run_functions_eagerly(self):
|
||||
return RUN_FUNCTIONS_EAGERLY
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
"""Calls the graph function and warn too frequent tracings."""
|
||||
if RUN_FUNCTIONS_EAGERLY:
|
||||
if self._run_functions_eagerly:
|
||||
with trace.Trace(self._name, tf_function_call="eager"):
|
||||
return self._python_function(*args, **kwds)
|
||||
|
||||
|
@ -199,6 +199,20 @@ class RestoredFunction(def_function.Function):
|
||||
# warnings.
|
||||
self._omit_frequent_tracing_warning = True
|
||||
|
||||
@property
|
||||
def _run_functions_eagerly(self):
|
||||
# We do not have access to the original python function, and thus, we
|
||||
# cannot meaningfully do anything but call our concrete function graphs
|
||||
# under the hood.
|
||||
#
|
||||
# Attempting to call our bespoke python function (i.e.
|
||||
# `restored_function_body`) will work so long as the user passes in all
|
||||
# required and optional arguments. If an optional argument is missing,
|
||||
# however, the call will break. For this reason, we instead skip the
|
||||
# eager call path altogether if a user has enabled eager function execution
|
||||
# via `tf.config.run_functions_eagerly`.
|
||||
return False
|
||||
|
||||
def _list_all_concrete_functions_for_serialization(self):
|
||||
return self.concrete_functions
|
||||
|
||||
|
@ -2178,5 +2178,26 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
ValueError, "Found zero restored functions for caller function."):
|
||||
loaded.foo(1)
|
||||
|
||||
def test_restored_function_execute_eagerly(self):
|
||||
try:
|
||||
def_function.run_functions_eagerly(True)
|
||||
|
||||
class MyModel(module.Module):
|
||||
|
||||
@def_function.function
|
||||
def __call__(self, inputs, training=False):
|
||||
return math_ops.multiply(0.5, inputs)
|
||||
|
||||
model = MyModel()
|
||||
model.__call__.get_concrete_function(
|
||||
tensor_spec.TensorSpec([None], dtypes.float32))
|
||||
loaded = cycle(model, 1)
|
||||
|
||||
# Calling the function should not throw an exception.
|
||||
loaded(constant_op.constant([1.0]))
|
||||
|
||||
finally:
|
||||
def_function.run_functions_eagerly(False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user