Make restored tf.functions throw an error if there are extra arguments passed to it.

Before extra arguments, e.g. extra fields in a nested structured were silently ignored.

When fixing this, errors in wiring the "self" showed up. "self" was being passed
(due to "tf_inspect.ismethod" not being wired by tf_decorator.make_decorator).

The code was fixed by always restoring tf.function as if they are not methods.

PiperOrigin-RevId: 240597612
This commit is contained in:
Andr? Susano Pinto 2019-03-27 10:57:03 -07:00 committed by TensorFlower Gardener
parent ac88b10f3f
commit 5770e124af
4 changed files with 62 additions and 7 deletions

View File

@ -1000,6 +1000,8 @@ class FunctionSpec(object):
if self._is_method:
# Remove `self`: default arguments shouldn't be matched to it.
# TODO(b/127938157): Should this error out if there is no arg to
# be removed?
args = fullargspec.args[1:]
else:
args = fullargspec.args

View File

@ -89,6 +89,13 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
except (TypeError, ValueError):
return False
try:
# Verify that no input elements were dropped during flattening.
repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
nest.assert_same_structure(inputs, repacked)
except (TypeError, ValueError):
return False
for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
if isinstance(expected, tensor_spec.TensorSpec):
if allow_conversion:
@ -105,23 +112,36 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
return True
def _deserialize_function_spec(function_spec_proto, coder):
def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
"""Deserialize a FunctionSpec object from its proto representation."""
typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
# Convert a method function into a non method.
if function_spec_proto.is_method:
if not typeless_fullargspec.args:
raise NotImplementedError(
"Missing support to deserialize a method function without a named "
"'self' argument.")
args = typeless_fullargspec.args[1:]
else:
args = typeless_fullargspec.args
fullargspec = tf_inspect.FullArgSpec(
args=typeless_fullargspec.args,
args=args,
varargs=typeless_fullargspec.varargs,
varkw=typeless_fullargspec.varkw,
defaults=typeless_fullargspec.defaults,
kwonlyargs=typeless_fullargspec.kwonlyargs,
kwonlydefaults=typeless_fullargspec.kwonlydefaults,
annotations=typeless_fullargspec.annotations)
is_method = function_spec_proto.is_method
args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
input_signature = coder.decode_proto(function_spec_proto.input_signature)
return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend,
kwargs_to_include, input_signature)
return function_lib.FunctionSpec(fullargspec=fullargspec,
is_method=False,
args_to_prepend=args_to_prepend,
kwargs_to_include=kwargs_to_include,
input_signature=input_signature)
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
@ -178,7 +198,18 @@ def recreate_function(saved_function, concrete_functions):
# serialization cycle.
coder = nested_structure_coder.StructureCoder()
function_spec = _deserialize_function_spec(saved_function.function_spec,
# Note: handling method functions is tricky since make_decorator does not
# allows control of "ismethod". Additionally since restored functions do
# not behave as methods i.e. they always use the same captured tensors
# independent of the object they are bound to, there is little value on
# propagating that correctly.
#
# Ideally this conversion should happen at serialization time. But since
# there are SavedModels which have "ismethod" populated and have an extra
# argument that they expect to be ignored, we do it at deserialization.
function_spec = _deserialize_function_spec_as_nonmethod(
saved_function.function_spec,
coder)
def restored_function_body(*args, **kwargs):

View File

@ -25,6 +25,10 @@ from tensorflow.python.saved_model import nested_structure_coder
def _serialize_function_spec(function_spec, coder):
"""Serialize a FunctionSpec object into its proto representation."""
if function_spec.is_method and not function_spec.fullargspec.args:
raise NotImplementedError(
"Missing support to serialize a method function without a named "
"'self' argument.")
proto = saved_object_graph_pb2.FunctionSpec()
proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
proto.is_method = function_spec.is_method

View File

@ -1218,6 +1218,24 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual([2], root.f([2]).numpy())
def test_extra_args(self, cycles):
@def_function.function
def f(x):
return math_ops.add(x["a"], 1.)
# Trigger a trace.
f({"a": constant_op.constant(2.0)})
obj = tracking.AutoTrackable()
obj.__call__ = f
imported = self.cycle(obj)
self.assertEqual(4.0, imported({"a": 3.0}).numpy())
with self.assertRaisesRegexp(ValueError,
"Could not find matching function to call"):
imported({"a": 2.0, "b": 3.0})
def test_shapes_available(self, cycles):
@def_function.function(input_signature=[