Add checks to ensure all arguments have valid values during input canonicalization.
Before this update, the behavior of tf.function is not consistent with a "naked" function call: ``` >>> def func(position_arg, named_arg=None): return position_arg, named_arg >>> func(named_arg="Hello") TypeError: func() missing 1 required positional argument: 'position_arg' >>> tf_func = tf.function(func) >>> tf_func(named_arg="Hello") (<tf.Tensor: shape=(), dtype=string, numpy=b'Hello'>, None) ``` After the update: ``` >>> def func(position_arg, named_arg=None): return position_arg, named_arg >>> tf_func = tf.function(func) >>> tf_func(named_arg="Hello") TypeError: func(position_arg, named_arg) missing 1 required argument: position_arg ``` PiperOrigin-RevId: 350703016 Change-Id: I371cfd041ab466c486fc1dc8a2e746e9a1721280
This commit is contained in:
parent
fe004863b7
commit
d40e992f9f
@ -1000,6 +1000,46 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
|
||||
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
|
||||
|
||||
def testFunctionArgumentChecking(self):
|
||||
|
||||
class A(object):
|
||||
|
||||
def func(self, position_arg1, position_arg2):
|
||||
return position_arg1, position_arg2
|
||||
|
||||
def func_pos(position_arg1, position_arg2):
|
||||
return position_arg1, position_arg2
|
||||
|
||||
def func_named(position_arg, named_arg=None):
|
||||
return position_arg, named_arg
|
||||
|
||||
def func_pos_3args(position_arg1, position_arg2, position_arg3):
|
||||
return position_arg1, position_arg2, position_arg3
|
||||
|
||||
a_instance = A()
|
||||
tf_method_pos = def_function.function(a_instance.func)
|
||||
tf_func_pos = def_function.function(func_pos)
|
||||
tf_func_named = def_function.function(func_named)
|
||||
tf_func_pos_3args = def_function.function(func_pos_3args)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||
tf_method_pos(position_arg2='foo')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||
tf_func_pos(position_arg2='foo')
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'.* missing 1 required argument: position_arg'):
|
||||
tf_func_named(named_arg='foo')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
'.* missing required arguments: position_arg1, position_arg3'):
|
||||
tf_func_pos_3args(position_arg2='foo')
|
||||
|
||||
tf_func_named(position_arg='bar')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -2494,6 +2494,8 @@ class FunctionSpec(object):
|
||||
offset + index: default
|
||||
for index, default in enumerate(default_values or [])
|
||||
}
|
||||
self._arg_indices_no_default_values = set(range(len(args))) - set(
|
||||
self._arg_indices_to_default_values)
|
||||
if input_signature is None:
|
||||
self._input_signature = None
|
||||
else:
|
||||
@ -2654,12 +2656,14 @@ class FunctionSpec(object):
|
||||
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
||||
if self._experimental_follow_type_hints:
|
||||
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
|
||||
# Pre-calculate to reduce overhead
|
||||
arglen = len(args)
|
||||
if self._input_signature is not None:
|
||||
if len(args) > len(self._input_signature):
|
||||
if arglen > len(self._input_signature):
|
||||
raise TypeError("{} takes {} positional arguments (as specified by the "
|
||||
"input_signature) but {} were given".format(
|
||||
self.signature_summary(),
|
||||
len(self._input_signature), len(args)))
|
||||
len(self._input_signature), arglen))
|
||||
for arg in six.iterkeys(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is None:
|
||||
@ -2674,13 +2678,12 @@ class FunctionSpec(object):
|
||||
inputs = args
|
||||
if self._arg_indices_to_default_values:
|
||||
try:
|
||||
inputs += tuple(
|
||||
self._arg_indices_to_default_values[i]
|
||||
for i in range(len(args), len(self._arg_names)))
|
||||
inputs += tuple(self._arg_indices_to_default_values[i]
|
||||
for i in range(arglen, len(self._arg_names)))
|
||||
except KeyError:
|
||||
missing_args = [
|
||||
self._arg_names[i]
|
||||
for i in range(len(args), len(self._arg_names))
|
||||
for i in range(arglen, len(self._arg_names))
|
||||
if i not in self._arg_indices_to_default_values
|
||||
]
|
||||
raise TypeError("{} missing required arguments: {}".format(
|
||||
@ -2694,22 +2697,36 @@ class FunctionSpec(object):
|
||||
# aren't in `args`.
|
||||
arg_indices_to_values = {
|
||||
index: default for index, default in six.iteritems(
|
||||
self._arg_indices_to_default_values) if index >= len(args)
|
||||
self._arg_indices_to_default_values) if index >= arglen
|
||||
}
|
||||
consumed_args = []
|
||||
missing_arg_indices = self._arg_indices_no_default_values - set(
|
||||
range(arglen))
|
||||
for arg, value in six.iteritems(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is not None:
|
||||
if index < len(args):
|
||||
if index < arglen:
|
||||
raise TypeError("{} got two values for argument '{}'".format(
|
||||
self.signature_summary(), arg))
|
||||
arg_indices_to_values[index] = value
|
||||
# These arguments in 'kwargs' might also belong to
|
||||
# positional arguments
|
||||
missing_arg_indices.discard(index)
|
||||
consumed_args.append(arg)
|
||||
for arg in consumed_args:
|
||||
# After this loop, `kwargs` will only contain keyword_only arguments,
|
||||
# and all positional_or_keyword arguments have been moved to `inputs`.
|
||||
kwargs.pop(arg)
|
||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||
# Exclude positional args with values
|
||||
if missing_arg_indices:
|
||||
missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
|
||||
if len(missing_args) == 1:
|
||||
raise TypeError("{} missing 1 required argument: {}".format(
|
||||
self.signature_summary(), missing_args[0]))
|
||||
else:
|
||||
raise TypeError("{} missing required arguments: {}".format(
|
||||
self.signature_summary(), ", ".join(missing_args)))
|
||||
|
||||
if kwargs and self._input_signature is not None:
|
||||
raise TypeError(
|
||||
|
Loading…
Reference in New Issue
Block a user