diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index dbc9461e146..91e00c42b6e 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1000,6 +1000,8 @@ class FunctionSpec(object):
     if self._is_method:
       # Remove `self`: default arguments shouldn't be matched to it.
+      # TODO(b/127938157): Should this error out if there is no arg to
+      # be removed?
       args = fullargspec.args[1:]
     else:
       args = fullargspec.args
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index 5924b2e0e16..04136048a31 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -89,6 +89,13 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
     flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
   except (TypeError, ValueError):
     return False
+  try:
+    # Verify that no input elements were dropped during flattening.
+    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
+    nest.assert_same_structure(inputs, repacked)
+  except (TypeError, ValueError):
+    return False
+
   for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
     if isinstance(expected, tensor_spec.TensorSpec):
       if allow_conversion:
@@ -105,23 +112,36 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
   return True
 
 
-def _deserialize_function_spec(function_spec_proto, coder):
+def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
   """Deserialize a FunctionSpec object from its proto representation."""
   typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
+
+  # Convert a method function into a non-method.
+  if function_spec_proto.is_method:
+    if not typeless_fullargspec.args:
+      raise NotImplementedError(
+          "Missing support to deserialize a method function without a named "
+          "'self' argument.")
+    args = typeless_fullargspec.args[1:]
+  else:
+    args = typeless_fullargspec.args
+
   fullargspec = tf_inspect.FullArgSpec(
-      args=typeless_fullargspec.args,
+      args=args,
       varargs=typeless_fullargspec.varargs,
       varkw=typeless_fullargspec.varkw,
       defaults=typeless_fullargspec.defaults,
       kwonlyargs=typeless_fullargspec.kwonlyargs,
       kwonlydefaults=typeless_fullargspec.kwonlydefaults,
       annotations=typeless_fullargspec.annotations)
-  is_method = function_spec_proto.is_method
   args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
   kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
   input_signature = coder.decode_proto(function_spec_proto.input_signature)
-  return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend,
-                                   kwargs_to_include, input_signature)
+  return function_lib.FunctionSpec(fullargspec=fullargspec,
+                                   is_method=False,
+                                   args_to_prepend=args_to_prepend,
+                                   kwargs_to_include=kwargs_to_include,
+                                   input_signature=input_signature)
 
 
 # TODO(allenl): The fact that we can't derive ConcreteFunction calling
@@ -178,8 +198,19 @@ def recreate_function(saved_function, concrete_functions):
   # serialization cycle.
   coder = nested_structure_coder.StructureCoder()
-  function_spec = _deserialize_function_spec(saved_function.function_spec,
-                                             coder)
+
+  # Note: handling method functions is tricky since make_decorator does not
+  # allow control of "ismethod". Additionally, since restored functions do
+  # not behave as methods (i.e. they always use the same captured tensors
+  # independent of the object they are bound to), there is little value in
+  # propagating that correctly.
+  #
+  # Ideally this conversion should happen at serialization time. But since
+  # there are SavedModels which have "ismethod" populated and have an extra
+  # argument that they expect to be ignored, we do it at deserialization.
+  function_spec = _deserialize_function_spec_as_nonmethod(
+      saved_function.function_spec,
+      coder)
 
   def restored_function_body(*args, **kwargs):
     """Calls a restored function."""
diff --git a/tensorflow/python/saved_model/function_serialization.py b/tensorflow/python/saved_model/function_serialization.py
index e876eef8b34..9915a7e4842 100644
--- a/tensorflow/python/saved_model/function_serialization.py
+++ b/tensorflow/python/saved_model/function_serialization.py
@@ -25,6 +25,10 @@ from tensorflow.python.saved_model import nested_structure_coder
 
 def _serialize_function_spec(function_spec, coder):
   """Serialize a FunctionSpec object into its proto representation."""
+  if function_spec.is_method and not function_spec.fullargspec.args:
+    raise NotImplementedError(
+        "Missing support to serialize a method function without a named "
+        "'self' argument.")
   proto = saved_object_graph_pb2.FunctionSpec()
   proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
   proto.is_method = function_spec.is_method
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index bd7bceceb84..c772cb066f2 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -1218,6 +1218,24 @@ class LoadTest(test.TestCase, parameterized.TestCase):
 
     self.assertEqual([2], root.f([2]).numpy())
 
+  def test_extra_args(self, cycles):
+
+    @def_function.function
+    def f(x):
+      return math_ops.add(x["a"], 1.)
+    # Trigger a trace.
+    f({"a": constant_op.constant(2.0)})
+
+    obj = tracking.AutoTrackable()
+    obj.__call__ = f
+    imported = self.cycle(obj)
+
+    self.assertEqual(4.0, imported({"a": 3.0}).numpy())
+
+    with self.assertRaisesRegexp(ValueError,
+                                 "Could not find matching function to call"):
+      imported({"a": 2.0, "b": 3.0})
+
   def test_shapes_available(self, cycles):
 
     @def_function.function(input_signature=[