Add checks to ensure all arguments have valid values during input canonicalization.

Before this update, the behavior of tf.function is not consistent with a "naked" function call:
```
>>> def func(position_arg, named_arg=None):
      return position_arg, named_arg

>>> func(named_arg="Hello")
TypeError: func() missing 1 required positional argument: 'position_arg'

>>> tf_func = tf.function(func)

>>> tf_func(named_arg="Hello")
(<tf.Tensor: shape=(), dtype=string, numpy=b'Hello'>, None)
After the update:

>>> def func(position_arg, named_arg=None):
      return position_arg, named_arg

>>> tf_func = tf.function(func)

>>> tf_func(named_arg="Hello")
TypeError: func(position_arg, named_arg) missing 1 required argument: position_arg
```
PiperOrigin-RevId: 353041646
Change-Id: Ia08d1177f152054281455fc04511ca54482fd22d
This commit is contained in:
A. Unique TensorFlower 2021-01-21 09:55:12 -08:00 committed by TensorFlower Gardener
parent 8cb8c460a3
commit f48ae3edac
3 changed files with 190 additions and 30 deletions

View File

@ -2412,7 +2412,25 @@ class FunctionSpec(object):
kwonlyargs=[],
kwonlydefaults={},
annotations=fullargspec.annotations)
is_method = tf_inspect.ismethod(python_function)
# inspect.ismethod() and inspect.isfunction() both return False on a
# functools.partial-wrapped function. We set it to False to
# maintain consistency with prior versions.
is_method = False
else:
# Instead of using tf_inspect.ismethod() which only checks the
# final unwrapped target, we check if any decorated target along the chain
# is a method.
is_method = tf_inspect.isanytargetmethod(python_function)
# In the following scenario, 'python_function' is a callable object.
# python_function(...) is equal to python_function.__call__(self, ...)
if not is_method and not tf_inspect.isfunction(
python_function) and hasattr(
python_function, "__class__") and hasattr(
python_function.__class__, "__call__"):
is_method = True
# Get the function's name. Remove functools.partial wrappers if necessary.
while isinstance(python_function, functools.partial):
@ -2477,6 +2495,8 @@ class FunctionSpec(object):
offset + index: default
for index, default in enumerate(default_values or [])
}
self._arg_indices_no_default_values = set(range(len(args))) - set(
self._arg_indices_to_default_values)
if input_signature is None:
self._input_signature = None
else:
@ -2633,12 +2653,14 @@ class FunctionSpec(object):
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
if self._experimental_follow_type_hints:
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
# Pre-calculate to reduce overhead
arglen = len(args)
if self._input_signature is not None:
if len(args) > len(self._input_signature):
if arglen > len(self._input_signature):
raise TypeError("{} takes {} positional arguments (as specified by the "
"input_signature) but {} were given".format(
self.signature_summary(),
len(self._input_signature), len(args)))
len(self._input_signature), arglen))
for arg in six.iterkeys(kwargs):
index = self._args_to_indices.get(arg, None)
if index is None:
@ -2653,13 +2675,12 @@ class FunctionSpec(object):
inputs = args
if self._arg_indices_to_default_values:
try:
inputs += tuple(
self._arg_indices_to_default_values[i]
for i in range(len(args), len(self._arg_names)))
inputs += tuple(self._arg_indices_to_default_values[i]
for i in range(arglen, len(self._arg_names)))
except KeyError:
missing_args = [
self._arg_names[i]
for i in range(len(args), len(self._arg_names))
for i in range(arglen, len(self._arg_names))
if i not in self._arg_indices_to_default_values
]
raise TypeError("{} missing required arguments: {}".format(
@ -2673,22 +2694,36 @@ class FunctionSpec(object):
# aren't in `args`.
arg_indices_to_values = {
index: default for index, default in six.iteritems(
self._arg_indices_to_default_values) if index >= len(args)
self._arg_indices_to_default_values) if index >= arglen
}
consumed_args = []
missing_arg_indices = self._arg_indices_no_default_values - set(
range(arglen))
for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
if index < len(args):
if index < arglen:
raise TypeError("{} got two values for argument '{}'".format(
self.signature_summary(), arg))
arg_indices_to_values[index] = value
# These arguments in 'kwargs' might also belong to
# positional arguments
missing_arg_indices.discard(index)
consumed_args.append(arg)
for arg in consumed_args:
# After this loop, `kwargs` will only contain keyword_only arguments,
# and all positional_or_keyword arguments have been moved to `inputs`.
kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
# Exclude positional args with values
if missing_arg_indices:
missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
if len(missing_args) == 1:
raise TypeError("{} missing 1 required argument: {}".format(
self.signature_summary(), missing_args[0]))
else:
raise TypeError("{} missing required arguments: {}".format(
self.signature_summary(), ", ".join(missing_args)))
if kwargs and self._input_signature is not None:
raise TypeError(
@ -3911,9 +3946,9 @@ def class_method_to_instance_method(original_function, instance):
jit_compile=original_function._jit_compile)
# pylint: enable=protected-access
# And we wrap the function with tf_decorator so inspection works correctly
wrapped_instance_func = tf_decorator.make_decorator(
original_function.python_function, instance_func)
# We wrap the the bound method with tf_decorator so inspection works correctly
wrapped_instance_func = tf_decorator.make_decorator(bound_method,
instance_func)
return wrapped_instance_func

View File

@ -83,6 +83,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
try:
@ -119,6 +120,15 @@ def _spec_for_value(value):
return value
# This dummy decorator imitates ordinary decorators utilizing tf_decorator.
def dummy_tf_decorator(method):
def wrapper(*args, **kwargs):
return method(*args, **kwargs)
return tf_decorator.make_decorator(method, wrapper)
class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self):
@ -4194,8 +4204,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
trace_count[0] = 0
disabled = def_function.function(func, experimental_follow_type_hints=False)
disabled(x=1, y=2)
disabled(x=2, y=2,) # Retrace
disabled(0, 0, x=1, y=2)
disabled(0, 0, x=2, y=2,) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
@ -4319,8 +4329,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace
self.assertEqual(trace_count[0], 4)
def testWithModuleNameScope(self):
self.skipTest('b/166158748:function does not handle this case correctly.')
def testWithExtraWrapper(self):
class Foo(module.Module):
@ -4329,16 +4338,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.var = None
@def_function.function
@module.Module.with_name_scope
@dummy_tf_decorator
def add(self, x, y, z=1):
if self.var is None:
return x + y + z
foo = Foo()
self.assertEqual(foo.add(2, 3), 6)
self.assertEqual(foo.add(2, 3).numpy(), 6)
def testWithModuleNameScopeRedundantArgs(self):
self.skipTest('b/166158748:function does not handle this case correctly.')
@parameterized.parameters([(def_function.function, dummy_tf_decorator),
(dummy_tf_decorator, def_function.function),
(def_function.function, def_function.function)])
def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2):
class Foo(module.Module):
@ -4346,18 +4357,17 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
super().__init__()
self.var = None
@def_function.function
@module.Module.with_name_scope
def add(self, x, y):
@decorator1
@decorator2
def add1(self, x, y):
if self.var is None:
return x + y
foo = Foo()
with self.assertRaisesRegex(TypeError, 'got two values for argument'):
foo.add(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
foo.add1(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
def testWithModuleNameScopeMissingArgs(self):
self.skipTest('b/166158748:function does not handle this case correctly.')
def testWithExtraWrapperMissingArgs(self):
class Foo(module.Module):
@ -4366,14 +4376,115 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.var = None
@def_function.function
@module.Module.with_name_scope
def add(self, x, y):
@dummy_tf_decorator
def add1(self, x, y):
if self.var is None:
return x + y
@def_function.function
@dummy_tf_decorator
def add2(self, x, y):
if self.var is None:
return x + y
@def_function.function
@def_function.function
def add3(self, x, y):
if self.var is None:
return x + y
foo = Foo()
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
foo.add(2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(
TypeError, 'missing 1 required positional argument: \'y\''):
foo.add1(2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
foo.add1(y=2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(
TypeError, 'missing 1 required positional argument: \'y\''):
foo.add2(2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
foo.add2(y=2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(
TypeError, 'missing 1 required positional argument: \'y\''):
foo.add3(2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
foo.add3(y=2) # pylint: disable=no-value-for-parameter
def testMissingArgsTfFunctionedMethod(self):
class A(object):
def func(self, position_arg1, position_arg2):
return position_arg1, position_arg2
@def_function.function
def decorated_method(self, position_arg1, position_arg2):
return position_arg1, position_arg2
a_instance = A()
tf_method_pos = def_function.function(a_instance.func)
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_method_pos(position_arg2='foo')
# tf.function-decorated instance methods need to be tested because of
# the __get__ method implementation.
tf_func_decorated_method = def_function.function(
a_instance.decorated_method)
tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_func_decorated_method(position_arg2='bar')
def testMissingArgsTfFunctionedObject(self):
class A(object):
def __call__(self, position_arg1, position_arg2):
return position_arg1, position_arg2
a_instance = A()
# A tf.function-decorated callable object needs to be tested because of
# the special inspect results.
tf_func_obj = def_function.function(a_instance)
tf_func_obj(position_arg1=1, position_arg2=2)
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_func_obj(position_arg2='bar')
def testMissingArgsTfFunctionedFunctions(self):
def func_pos(position_arg1, position_arg2):
return position_arg1, position_arg2
def func_with_default(position_arg, named_arg=None):
return position_arg, named_arg
def func_pos_3args(position_arg1, position_arg2, position_arg3):
return position_arg1, position_arg2, position_arg3
tf_func_pos = def_function.function(func_pos)
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_func_pos(position_arg2='foo')
tf_func_with_default = def_function.function(func_with_default)
tf_func_with_default(position_arg='bar')
with self.assertRaisesRegex(TypeError,
'.* missing 1 required argument: position_arg'):
tf_func_with_default(named_arg='foo')
tf_func_pos_3args = def_function.function(func_pos_3args)
with self.assertRaisesRegex(
TypeError,
'.* missing required arguments: position_arg1, position_arg3'):
tf_func_pos_3args(position_arg2='foo')
def testShapeInferencePropagateConstNestedStack(self):

View File

@ -406,6 +406,20 @@ def ismethod(object): # pylint: disable=redefined-builtin
return _inspect.ismethod(tf_decorator.unwrap(object)[1])
def isanytargetmethod(object): # pylint: disable=redefined-builtin
# pylint: disable=g-doc-args,g-doc-return-or-yield
"""Checks all the decorated targets along the chain of decorators.
Returns True if any of the decorated targets in the chain is a method.
"""
decorators, _ = tf_decorator.unwrap(object)
for decorator in decorators:
if _inspect.ismethod(decorator.decorated_target):
return True
return False
def ismodule(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.ismodule."""
return _inspect.ismodule(tf_decorator.unwrap(object)[1])