Add checks to ensure all arguments have valid values during input canonicalization.

Before this update, the behavior of tf.function is not consistent with a "naked" function call:

```
>>> def func(position_arg, named_arg=None):
      return position_arg, named_arg

>>> func(named_arg="Hello")
TypeError: func() missing 1 required positional argument: 'position_arg'

>>> tf_func = tf.function(func)

>>> tf_func(named_arg="Hello")
(<tf.Tensor: shape=(), dtype=string, numpy=b'Hello'>, None)
```

After the update:
```
>>> def func(position_arg, named_arg=None):
      return position_arg, named_arg

>>> tf_func = tf.function(func)

>>> tf_func(named_arg="Hello")
TypeError: func(position_arg, named_arg) missing 1 required argument: position_arg
```
PiperOrigin-RevId: 350703016
Change-Id: I371cfd041ab466c486fc1dc8a2e746e9a1721280
This commit is contained in:
A. Unique TensorFlower 2021-01-07 22:32:30 -08:00 committed by TensorFlower Gardener
parent fe004863b7
commit d40e992f9f
2 changed files with 65 additions and 8 deletions

View File

@ -1000,6 +1000,46 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
def testFunctionArgumentChecking(self):
class A(object):
def func(self, position_arg1, position_arg2):
return position_arg1, position_arg2
def func_pos(position_arg1, position_arg2):
return position_arg1, position_arg2
def func_named(position_arg, named_arg=None):
return position_arg, named_arg
def func_pos_3args(position_arg1, position_arg2, position_arg3):
return position_arg1, position_arg2, position_arg3
a_instance = A()
tf_method_pos = def_function.function(a_instance.func)
tf_func_pos = def_function.function(func_pos)
tf_func_named = def_function.function(func_named)
tf_func_pos_3args = def_function.function(func_pos_3args)
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_method_pos(position_arg2='foo')
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_func_pos(position_arg2='foo')
with self.assertRaisesRegex(TypeError,
'.* missing 1 required argument: position_arg'):
tf_func_named(named_arg='foo')
with self.assertRaisesRegex(
TypeError,
'.* missing required arguments: position_arg1, position_arg3'):
tf_func_pos_3args(position_arg2='foo')
tf_func_named(position_arg='bar')
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@ -2494,6 +2494,8 @@ class FunctionSpec(object):
offset + index: default
for index, default in enumerate(default_values or [])
}
self._arg_indices_no_default_values = set(range(len(args))) - set(
self._arg_indices_to_default_values)
if input_signature is None:
self._input_signature = None
else:
@ -2654,12 +2656,14 @@ class FunctionSpec(object):
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
if self._experimental_follow_type_hints:
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
# Pre-calculate to reduce overhead
arglen = len(args)
if self._input_signature is not None:
if len(args) > len(self._input_signature):
if arglen > len(self._input_signature):
raise TypeError("{} takes {} positional arguments (as specified by the "
"input_signature) but {} were given".format(
self.signature_summary(),
len(self._input_signature), len(args)))
len(self._input_signature), arglen))
for arg in six.iterkeys(kwargs):
index = self._args_to_indices.get(arg, None)
if index is None:
@ -2674,13 +2678,12 @@ class FunctionSpec(object):
inputs = args
if self._arg_indices_to_default_values:
try:
inputs += tuple(
self._arg_indices_to_default_values[i]
for i in range(len(args), len(self._arg_names)))
inputs += tuple(self._arg_indices_to_default_values[i]
for i in range(arglen, len(self._arg_names)))
except KeyError:
missing_args = [
self._arg_names[i]
for i in range(len(args), len(self._arg_names))
for i in range(arglen, len(self._arg_names))
if i not in self._arg_indices_to_default_values
]
raise TypeError("{} missing required arguments: {}".format(
@ -2694,22 +2697,36 @@ class FunctionSpec(object):
# aren't in `args`.
arg_indices_to_values = {
index: default for index, default in six.iteritems(
self._arg_indices_to_default_values) if index >= len(args)
self._arg_indices_to_default_values) if index >= arglen
}
consumed_args = []
missing_arg_indices = self._arg_indices_no_default_values - set(
range(arglen))
for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
if index < len(args):
if index < arglen:
raise TypeError("{} got two values for argument '{}'".format(
self.signature_summary(), arg))
arg_indices_to_values[index] = value
# These arguments in 'kwargs' might also belong to
# positional arguments
missing_arg_indices.discard(index)
consumed_args.append(arg)
for arg in consumed_args:
# After this loop, `kwargs` will only contain keyword_only arguments,
# and all positional_or_keyword arguments have been moved to `inputs`.
kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
# Exclude positional args with values
if missing_arg_indices:
missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
if len(missing_args) == 1:
raise TypeError("{} missing 1 required argument: {}".format(
self.signature_summary(), missing_args[0]))
else:
raise TypeError("{} missing required arguments: {}".format(
self.signature_summary(), ", ".join(missing_args)))
if kwargs and self._input_signature is not None:
raise TypeError(