Add checks to ensure all arguments have valid values during input canonicalization.
Before this update, the behavior of tf.function is not consistent with a "naked" function call:
```
>>> def func(position_arg, named_arg=None):
return position_arg, named_arg
>>> func(named_arg="Hello")
TypeError: func() missing 1 required positional argument: 'position_arg'
>>> tf_func = tf.function(func)
>>> tf_func(named_arg="Hello")
(<tf.Tensor: shape=(), dtype=string, numpy=b'Hello'>, None)
```
After the update:
```
>>> def func(position_arg, named_arg=None):
return position_arg, named_arg
>>> tf_func = tf.function(func)
>>> tf_func(named_arg="Hello")
TypeError: func(position_arg, named_arg) missing 1 required argument: position_arg
```
PiperOrigin-RevId: 350703016
Change-Id: I371cfd041ab466c486fc1dc8a2e746e9a1721280
This commit is contained in:
parent
fe004863b7
commit
d40e992f9f
@ -1000,6 +1000,46 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
|
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
|
||||||
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
|
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
|
||||||
|
|
||||||
|
def testFunctionArgumentChecking(self):
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
|
||||||
|
def func(self, position_arg1, position_arg2):
|
||||||
|
return position_arg1, position_arg2
|
||||||
|
|
||||||
|
def func_pos(position_arg1, position_arg2):
|
||||||
|
return position_arg1, position_arg2
|
||||||
|
|
||||||
|
def func_named(position_arg, named_arg=None):
|
||||||
|
return position_arg, named_arg
|
||||||
|
|
||||||
|
def func_pos_3args(position_arg1, position_arg2, position_arg3):
|
||||||
|
return position_arg1, position_arg2, position_arg3
|
||||||
|
|
||||||
|
a_instance = A()
|
||||||
|
tf_method_pos = def_function.function(a_instance.func)
|
||||||
|
tf_func_pos = def_function.function(func_pos)
|
||||||
|
tf_func_named = def_function.function(func_named)
|
||||||
|
tf_func_pos_3args = def_function.function(func_pos_3args)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||||
|
tf_method_pos(position_arg2='foo')
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||||
|
tf_func_pos(position_arg2='foo')
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(TypeError,
|
||||||
|
'.* missing 1 required argument: position_arg'):
|
||||||
|
tf_func_named(named_arg='foo')
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError,
|
||||||
|
'.* missing required arguments: position_arg1, position_arg3'):
|
||||||
|
tf_func_pos_3args(position_arg2='foo')
|
||||||
|
|
||||||
|
tf_func_named(position_arg='bar')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
|||||||
@ -2494,6 +2494,8 @@ class FunctionSpec(object):
|
|||||||
offset + index: default
|
offset + index: default
|
||||||
for index, default in enumerate(default_values or [])
|
for index, default in enumerate(default_values or [])
|
||||||
}
|
}
|
||||||
|
self._arg_indices_no_default_values = set(range(len(args))) - set(
|
||||||
|
self._arg_indices_to_default_values)
|
||||||
if input_signature is None:
|
if input_signature is None:
|
||||||
self._input_signature = None
|
self._input_signature = None
|
||||||
else:
|
else:
|
||||||
@ -2654,12 +2656,14 @@ class FunctionSpec(object):
|
|||||||
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
||||||
if self._experimental_follow_type_hints:
|
if self._experimental_follow_type_hints:
|
||||||
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
|
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
|
||||||
|
# Pre-calculate to reduce overhead
|
||||||
|
arglen = len(args)
|
||||||
if self._input_signature is not None:
|
if self._input_signature is not None:
|
||||||
if len(args) > len(self._input_signature):
|
if arglen > len(self._input_signature):
|
||||||
raise TypeError("{} takes {} positional arguments (as specified by the "
|
raise TypeError("{} takes {} positional arguments (as specified by the "
|
||||||
"input_signature) but {} were given".format(
|
"input_signature) but {} were given".format(
|
||||||
self.signature_summary(),
|
self.signature_summary(),
|
||||||
len(self._input_signature), len(args)))
|
len(self._input_signature), arglen))
|
||||||
for arg in six.iterkeys(kwargs):
|
for arg in six.iterkeys(kwargs):
|
||||||
index = self._args_to_indices.get(arg, None)
|
index = self._args_to_indices.get(arg, None)
|
||||||
if index is None:
|
if index is None:
|
||||||
@ -2674,13 +2678,12 @@ class FunctionSpec(object):
|
|||||||
inputs = args
|
inputs = args
|
||||||
if self._arg_indices_to_default_values:
|
if self._arg_indices_to_default_values:
|
||||||
try:
|
try:
|
||||||
inputs += tuple(
|
inputs += tuple(self._arg_indices_to_default_values[i]
|
||||||
self._arg_indices_to_default_values[i]
|
for i in range(arglen, len(self._arg_names)))
|
||||||
for i in range(len(args), len(self._arg_names)))
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
missing_args = [
|
missing_args = [
|
||||||
self._arg_names[i]
|
self._arg_names[i]
|
||||||
for i in range(len(args), len(self._arg_names))
|
for i in range(arglen, len(self._arg_names))
|
||||||
if i not in self._arg_indices_to_default_values
|
if i not in self._arg_indices_to_default_values
|
||||||
]
|
]
|
||||||
raise TypeError("{} missing required arguments: {}".format(
|
raise TypeError("{} missing required arguments: {}".format(
|
||||||
@ -2694,22 +2697,36 @@ class FunctionSpec(object):
|
|||||||
# aren't in `args`.
|
# aren't in `args`.
|
||||||
arg_indices_to_values = {
|
arg_indices_to_values = {
|
||||||
index: default for index, default in six.iteritems(
|
index: default for index, default in six.iteritems(
|
||||||
self._arg_indices_to_default_values) if index >= len(args)
|
self._arg_indices_to_default_values) if index >= arglen
|
||||||
}
|
}
|
||||||
consumed_args = []
|
consumed_args = []
|
||||||
|
missing_arg_indices = self._arg_indices_no_default_values - set(
|
||||||
|
range(arglen))
|
||||||
for arg, value in six.iteritems(kwargs):
|
for arg, value in six.iteritems(kwargs):
|
||||||
index = self._args_to_indices.get(arg, None)
|
index = self._args_to_indices.get(arg, None)
|
||||||
if index is not None:
|
if index is not None:
|
||||||
if index < len(args):
|
if index < arglen:
|
||||||
raise TypeError("{} got two values for argument '{}'".format(
|
raise TypeError("{} got two values for argument '{}'".format(
|
||||||
self.signature_summary(), arg))
|
self.signature_summary(), arg))
|
||||||
arg_indices_to_values[index] = value
|
arg_indices_to_values[index] = value
|
||||||
|
# These arguments in 'kwargs' might also belong to
|
||||||
|
# positional arguments
|
||||||
|
missing_arg_indices.discard(index)
|
||||||
consumed_args.append(arg)
|
consumed_args.append(arg)
|
||||||
for arg in consumed_args:
|
for arg in consumed_args:
|
||||||
# After this loop, `kwargs` will only contain keyword_only arguments,
|
# After this loop, `kwargs` will only contain keyword_only arguments,
|
||||||
# and all positional_or_keyword arguments have been moved to `inputs`.
|
# and all positional_or_keyword arguments have been moved to `inputs`.
|
||||||
kwargs.pop(arg)
|
kwargs.pop(arg)
|
||||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||||
|
# Exclude positional args with values
|
||||||
|
if missing_arg_indices:
|
||||||
|
missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
|
||||||
|
if len(missing_args) == 1:
|
||||||
|
raise TypeError("{} missing 1 required argument: {}".format(
|
||||||
|
self.signature_summary(), missing_args[0]))
|
||||||
|
else:
|
||||||
|
raise TypeError("{} missing required arguments: {}".format(
|
||||||
|
self.signature_summary(), ", ".join(missing_args)))
|
||||||
|
|
||||||
if kwargs and self._input_signature is not None:
|
if kwargs and self._input_signature is not None:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user