Make restored tf.functions throw an error if there are extra arguments passed to it.
Before extra arguments, e.g. extra fields in a nested structured were silently ignored. When fixing this, errors in wiring the "self" showed up. "self" was being passed (due to "tf_inspect.ismethod" not being wired by tf_decorator.make_decorator). The code was fixed by always restoring tf.function as if they are not methods. PiperOrigin-RevId: 240597612
This commit is contained in:
parent
ac88b10f3f
commit
5770e124af
@ -1000,6 +1000,8 @@ class FunctionSpec(object):
|
||||
|
||||
if self._is_method:
|
||||
# Remove `self`: default arguments shouldn't be matched to it.
|
||||
# TODO(b/127938157): Should this error out if there is no arg to
|
||||
# be removed?
|
||||
args = fullargspec.args[1:]
|
||||
else:
|
||||
args = fullargspec.args
|
||||
|
@ -89,6 +89,13 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
|
||||
flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
try:
|
||||
# Verify that no input elements were dropped during flattening.
|
||||
repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
|
||||
nest.assert_same_structure(inputs, repacked)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
|
||||
if isinstance(expected, tensor_spec.TensorSpec):
|
||||
if allow_conversion:
|
||||
@ -105,23 +112,36 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
|
||||
return True
|
||||
|
||||
|
||||
def _deserialize_function_spec(function_spec_proto, coder):
|
||||
def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
|
||||
"""Deserialize a FunctionSpec object from its proto representation."""
|
||||
typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
|
||||
|
||||
# Convert a method function into a non method.
|
||||
if function_spec_proto.is_method:
|
||||
if not typeless_fullargspec.args:
|
||||
raise NotImplementedError(
|
||||
"Missing support to deserialize a method function without a named "
|
||||
"'self' argument.")
|
||||
args = typeless_fullargspec.args[1:]
|
||||
else:
|
||||
args = typeless_fullargspec.args
|
||||
|
||||
fullargspec = tf_inspect.FullArgSpec(
|
||||
args=typeless_fullargspec.args,
|
||||
args=args,
|
||||
varargs=typeless_fullargspec.varargs,
|
||||
varkw=typeless_fullargspec.varkw,
|
||||
defaults=typeless_fullargspec.defaults,
|
||||
kwonlyargs=typeless_fullargspec.kwonlyargs,
|
||||
kwonlydefaults=typeless_fullargspec.kwonlydefaults,
|
||||
annotations=typeless_fullargspec.annotations)
|
||||
is_method = function_spec_proto.is_method
|
||||
args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
|
||||
kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
|
||||
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
||||
return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend,
|
||||
kwargs_to_include, input_signature)
|
||||
return function_lib.FunctionSpec(fullargspec=fullargspec,
|
||||
is_method=False,
|
||||
args_to_prepend=args_to_prepend,
|
||||
kwargs_to_include=kwargs_to_include,
|
||||
input_signature=input_signature)
|
||||
|
||||
|
||||
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
|
||||
@ -178,8 +198,19 @@ def recreate_function(saved_function, concrete_functions):
|
||||
# serialization cycle.
|
||||
|
||||
coder = nested_structure_coder.StructureCoder()
|
||||
function_spec = _deserialize_function_spec(saved_function.function_spec,
|
||||
coder)
|
||||
|
||||
# Note: handling method functions is tricky since make_decorator does not
|
||||
# allows control of "ismethod". Additionally since restored functions do
|
||||
# not behave as methods i.e. they always use the same captured tensors
|
||||
# independent of the object they are bound to, there is little value on
|
||||
# propagating that correctly.
|
||||
#
|
||||
# Ideally this conversion should happen at serialization time. But since
|
||||
# there are SavedModels which have "ismethod" populated and have an extra
|
||||
# argument that they expect to be ignored, we do it at deserialization.
|
||||
function_spec = _deserialize_function_spec_as_nonmethod(
|
||||
saved_function.function_spec,
|
||||
coder)
|
||||
|
||||
def restored_function_body(*args, **kwargs):
|
||||
"""Calls a restored function."""
|
||||
|
@ -25,6 +25,10 @@ from tensorflow.python.saved_model import nested_structure_coder
|
||||
|
||||
def _serialize_function_spec(function_spec, coder):
|
||||
"""Serialize a FunctionSpec object into its proto representation."""
|
||||
if function_spec.is_method and not function_spec.fullargspec.args:
|
||||
raise NotImplementedError(
|
||||
"Missing support to serialize a method function without a named "
|
||||
"'self' argument.")
|
||||
proto = saved_object_graph_pb2.FunctionSpec()
|
||||
proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
|
||||
proto.is_method = function_spec.is_method
|
||||
|
@ -1218,6 +1218,24 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertEqual([2], root.f([2]).numpy())
|
||||
|
||||
def test_extra_args(self, cycles):
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
return math_ops.add(x["a"], 1.)
|
||||
# Trigger a trace.
|
||||
f({"a": constant_op.constant(2.0)})
|
||||
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.__call__ = f
|
||||
imported = self.cycle(obj)
|
||||
|
||||
self.assertEqual(4.0, imported({"a": 3.0}).numpy())
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Could not find matching function to call"):
|
||||
imported({"a": 2.0, "b": 3.0})
|
||||
|
||||
def test_shapes_available(self, cycles):
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
|
Loading…
Reference in New Issue
Block a user