Make restored tf.functions throw an error if there are extra arguments passed to it.
Before extra arguments, e.g. extra fields in a nested structured were silently ignored. When fixing this, errors in wiring the "self" showed up. "self" was being passed (due to "tf_inspect.ismethod" not being wired by tf_decorator.make_decorator). The code was fixed by always restoring tf.function as if they are not methods. PiperOrigin-RevId: 240597612
This commit is contained in:
parent
ac88b10f3f
commit
5770e124af
@ -1000,6 +1000,8 @@ class FunctionSpec(object):
|
|||||||
|
|
||||||
if self._is_method:
|
if self._is_method:
|
||||||
# Remove `self`: default arguments shouldn't be matched to it.
|
# Remove `self`: default arguments shouldn't be matched to it.
|
||||||
|
# TODO(b/127938157): Should this error out if there is no arg to
|
||||||
|
# be removed?
|
||||||
args = fullargspec.args[1:]
|
args = fullargspec.args[1:]
|
||||||
else:
|
else:
|
||||||
args = fullargspec.args
|
args = fullargspec.args
|
||||||
|
@ -89,6 +89,13 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
|
|||||||
flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
|
flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return False
|
return False
|
||||||
|
try:
|
||||||
|
# Verify that no input elements were dropped during flattening.
|
||||||
|
repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
|
||||||
|
nest.assert_same_structure(inputs, repacked)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
|
||||||
for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
|
for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
|
||||||
if isinstance(expected, tensor_spec.TensorSpec):
|
if isinstance(expected, tensor_spec.TensorSpec):
|
||||||
if allow_conversion:
|
if allow_conversion:
|
||||||
@ -105,23 +112,36 @@ def _concrete_function_callable_with(function, inputs, allow_conversion):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _deserialize_function_spec(function_spec_proto, coder):
|
def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
|
||||||
"""Deserialize a FunctionSpec object from its proto representation."""
|
"""Deserialize a FunctionSpec object from its proto representation."""
|
||||||
typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
|
typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
|
||||||
|
|
||||||
|
# Convert a method function into a non method.
|
||||||
|
if function_spec_proto.is_method:
|
||||||
|
if not typeless_fullargspec.args:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Missing support to deserialize a method function without a named "
|
||||||
|
"'self' argument.")
|
||||||
|
args = typeless_fullargspec.args[1:]
|
||||||
|
else:
|
||||||
|
args = typeless_fullargspec.args
|
||||||
|
|
||||||
fullargspec = tf_inspect.FullArgSpec(
|
fullargspec = tf_inspect.FullArgSpec(
|
||||||
args=typeless_fullargspec.args,
|
args=args,
|
||||||
varargs=typeless_fullargspec.varargs,
|
varargs=typeless_fullargspec.varargs,
|
||||||
varkw=typeless_fullargspec.varkw,
|
varkw=typeless_fullargspec.varkw,
|
||||||
defaults=typeless_fullargspec.defaults,
|
defaults=typeless_fullargspec.defaults,
|
||||||
kwonlyargs=typeless_fullargspec.kwonlyargs,
|
kwonlyargs=typeless_fullargspec.kwonlyargs,
|
||||||
kwonlydefaults=typeless_fullargspec.kwonlydefaults,
|
kwonlydefaults=typeless_fullargspec.kwonlydefaults,
|
||||||
annotations=typeless_fullargspec.annotations)
|
annotations=typeless_fullargspec.annotations)
|
||||||
is_method = function_spec_proto.is_method
|
|
||||||
args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
|
args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
|
||||||
kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
|
kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
|
||||||
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
||||||
return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend,
|
return function_lib.FunctionSpec(fullargspec=fullargspec,
|
||||||
kwargs_to_include, input_signature)
|
is_method=False,
|
||||||
|
args_to_prepend=args_to_prepend,
|
||||||
|
kwargs_to_include=kwargs_to_include,
|
||||||
|
input_signature=input_signature)
|
||||||
|
|
||||||
|
|
||||||
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
|
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
|
||||||
@ -178,7 +198,18 @@ def recreate_function(saved_function, concrete_functions):
|
|||||||
# serialization cycle.
|
# serialization cycle.
|
||||||
|
|
||||||
coder = nested_structure_coder.StructureCoder()
|
coder = nested_structure_coder.StructureCoder()
|
||||||
function_spec = _deserialize_function_spec(saved_function.function_spec,
|
|
||||||
|
# Note: handling method functions is tricky since make_decorator does not
|
||||||
|
# allows control of "ismethod". Additionally since restored functions do
|
||||||
|
# not behave as methods i.e. they always use the same captured tensors
|
||||||
|
# independent of the object they are bound to, there is little value on
|
||||||
|
# propagating that correctly.
|
||||||
|
#
|
||||||
|
# Ideally this conversion should happen at serialization time. But since
|
||||||
|
# there are SavedModels which have "ismethod" populated and have an extra
|
||||||
|
# argument that they expect to be ignored, we do it at deserialization.
|
||||||
|
function_spec = _deserialize_function_spec_as_nonmethod(
|
||||||
|
saved_function.function_spec,
|
||||||
coder)
|
coder)
|
||||||
|
|
||||||
def restored_function_body(*args, **kwargs):
|
def restored_function_body(*args, **kwargs):
|
||||||
|
@ -25,6 +25,10 @@ from tensorflow.python.saved_model import nested_structure_coder
|
|||||||
|
|
||||||
def _serialize_function_spec(function_spec, coder):
|
def _serialize_function_spec(function_spec, coder):
|
||||||
"""Serialize a FunctionSpec object into its proto representation."""
|
"""Serialize a FunctionSpec object into its proto representation."""
|
||||||
|
if function_spec.is_method and not function_spec.fullargspec.args:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Missing support to serialize a method function without a named "
|
||||||
|
"'self' argument.")
|
||||||
proto = saved_object_graph_pb2.FunctionSpec()
|
proto = saved_object_graph_pb2.FunctionSpec()
|
||||||
proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
|
proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
|
||||||
proto.is_method = function_spec.is_method
|
proto.is_method = function_spec.is_method
|
||||||
|
@ -1218,6 +1218,24 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertEqual([2], root.f([2]).numpy())
|
self.assertEqual([2], root.f([2]).numpy())
|
||||||
|
|
||||||
|
def test_extra_args(self, cycles):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x):
|
||||||
|
return math_ops.add(x["a"], 1.)
|
||||||
|
# Trigger a trace.
|
||||||
|
f({"a": constant_op.constant(2.0)})
|
||||||
|
|
||||||
|
obj = tracking.AutoTrackable()
|
||||||
|
obj.__call__ = f
|
||||||
|
imported = self.cycle(obj)
|
||||||
|
|
||||||
|
self.assertEqual(4.0, imported({"a": 3.0}).numpy())
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"Could not find matching function to call"):
|
||||||
|
imported({"a": 2.0, "b": 3.0})
|
||||||
|
|
||||||
def test_shapes_available(self, cycles):
|
def test_shapes_available(self, cycles):
|
||||||
|
|
||||||
@def_function.function(input_signature=[
|
@def_function.function(input_signature=[
|
||||||
|
Loading…
Reference in New Issue
Block a user