Add checks to ensure all arguments have valid values during input canonicalization.
Before this update, the behavior of tf.function is not consistent with a "naked" function call: ``` >>> def func(position_arg, named_arg=None): return position_arg, named_arg >>> func(named_arg="Hello") TypeError: func() missing 1 required positional argument: 'position_arg' >>> tf_func = tf.function(func) >>> tf_func(named_arg="Hello") (<tf.Tensor: shape=(), dtype=string, numpy=b'Hello'>, None) After the update: >>> def func(position_arg, named_arg=None): return position_arg, named_arg >>> tf_func = tf.function(func) >>> tf_func(named_arg="Hello") TypeError: func(position_arg, named_arg) missing 1 required argument: position_arg ``` PiperOrigin-RevId: 353041646 Change-Id: Ia08d1177f152054281455fc04511ca54482fd22d
This commit is contained in:
parent
8cb8c460a3
commit
f48ae3edac
@ -2412,7 +2412,25 @@ class FunctionSpec(object):
|
||||
kwonlyargs=[],
|
||||
kwonlydefaults={},
|
||||
annotations=fullargspec.annotations)
|
||||
is_method = tf_inspect.ismethod(python_function)
|
||||
|
||||
# inspect.ismethod() and inspect.isfunction() both return False on a
|
||||
# functools.partial-wrapped function. We set it to False to
|
||||
# maintain consistency with prior versions.
|
||||
is_method = False
|
||||
|
||||
else:
|
||||
# Instead of using tf_inspect.ismethod() which only checks the
|
||||
# final unwrapped target, we check if any decorated target along the chain
|
||||
# is a method.
|
||||
is_method = tf_inspect.isanytargetmethod(python_function)
|
||||
|
||||
# In the following scenario, 'python_function' is a callable object.
|
||||
# python_function(...) is equal to python_function.__call__(self, ...)
|
||||
if not is_method and not tf_inspect.isfunction(
|
||||
python_function) and hasattr(
|
||||
python_function, "__class__") and hasattr(
|
||||
python_function.__class__, "__call__"):
|
||||
is_method = True
|
||||
|
||||
# Get the function's name. Remove functools.partial wrappers if necessary.
|
||||
while isinstance(python_function, functools.partial):
|
||||
@ -2477,6 +2495,8 @@ class FunctionSpec(object):
|
||||
offset + index: default
|
||||
for index, default in enumerate(default_values or [])
|
||||
}
|
||||
self._arg_indices_no_default_values = set(range(len(args))) - set(
|
||||
self._arg_indices_to_default_values)
|
||||
if input_signature is None:
|
||||
self._input_signature = None
|
||||
else:
|
||||
@ -2633,12 +2653,14 @@ class FunctionSpec(object):
|
||||
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
||||
if self._experimental_follow_type_hints:
|
||||
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
|
||||
# Pre-calculate to reduce overhead
|
||||
arglen = len(args)
|
||||
if self._input_signature is not None:
|
||||
if len(args) > len(self._input_signature):
|
||||
if arglen > len(self._input_signature):
|
||||
raise TypeError("{} takes {} positional arguments (as specified by the "
|
||||
"input_signature) but {} were given".format(
|
||||
self.signature_summary(),
|
||||
len(self._input_signature), len(args)))
|
||||
len(self._input_signature), arglen))
|
||||
for arg in six.iterkeys(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is None:
|
||||
@ -2653,13 +2675,12 @@ class FunctionSpec(object):
|
||||
inputs = args
|
||||
if self._arg_indices_to_default_values:
|
||||
try:
|
||||
inputs += tuple(
|
||||
self._arg_indices_to_default_values[i]
|
||||
for i in range(len(args), len(self._arg_names)))
|
||||
inputs += tuple(self._arg_indices_to_default_values[i]
|
||||
for i in range(arglen, len(self._arg_names)))
|
||||
except KeyError:
|
||||
missing_args = [
|
||||
self._arg_names[i]
|
||||
for i in range(len(args), len(self._arg_names))
|
||||
for i in range(arglen, len(self._arg_names))
|
||||
if i not in self._arg_indices_to_default_values
|
||||
]
|
||||
raise TypeError("{} missing required arguments: {}".format(
|
||||
@ -2673,22 +2694,36 @@ class FunctionSpec(object):
|
||||
# aren't in `args`.
|
||||
arg_indices_to_values = {
|
||||
index: default for index, default in six.iteritems(
|
||||
self._arg_indices_to_default_values) if index >= len(args)
|
||||
self._arg_indices_to_default_values) if index >= arglen
|
||||
}
|
||||
consumed_args = []
|
||||
missing_arg_indices = self._arg_indices_no_default_values - set(
|
||||
range(arglen))
|
||||
for arg, value in six.iteritems(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is not None:
|
||||
if index < len(args):
|
||||
if index < arglen:
|
||||
raise TypeError("{} got two values for argument '{}'".format(
|
||||
self.signature_summary(), arg))
|
||||
arg_indices_to_values[index] = value
|
||||
# These arguments in 'kwargs' might also belong to
|
||||
# positional arguments
|
||||
missing_arg_indices.discard(index)
|
||||
consumed_args.append(arg)
|
||||
for arg in consumed_args:
|
||||
# After this loop, `kwargs` will only contain keyword_only arguments,
|
||||
# and all positional_or_keyword arguments have been moved to `inputs`.
|
||||
kwargs.pop(arg)
|
||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||
# Exclude positional args with values
|
||||
if missing_arg_indices:
|
||||
missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
|
||||
if len(missing_args) == 1:
|
||||
raise TypeError("{} missing 1 required argument: {}".format(
|
||||
self.signature_summary(), missing_args[0]))
|
||||
else:
|
||||
raise TypeError("{} missing required arguments: {}".format(
|
||||
self.signature_summary(), ", ".join(missing_args)))
|
||||
|
||||
if kwargs and self._input_signature is not None:
|
||||
raise TypeError(
|
||||
@ -3911,9 +3946,9 @@ def class_method_to_instance_method(original_function, instance):
|
||||
jit_compile=original_function._jit_compile)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# And we wrap the function with tf_decorator so inspection works correctly
|
||||
wrapped_instance_func = tf_decorator.make_decorator(
|
||||
original_function.python_function, instance_func)
|
||||
# We wrap the the bound method with tf_decorator so inspection works correctly
|
||||
wrapped_instance_func = tf_decorator.make_decorator(bound_method,
|
||||
instance_func)
|
||||
return wrapped_instance_func
|
||||
|
||||
|
||||
|
@ -83,6 +83,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import training_ops
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
try:
|
||||
@ -119,6 +120,15 @@ def _spec_for_value(value):
|
||||
return value
|
||||
|
||||
|
||||
# This dummy decorator imitates ordinary decorators utilizing tf_decorator.
|
||||
def dummy_tf_decorator(method):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return tf_decorator.make_decorator(method, wrapper)
|
||||
|
||||
|
||||
class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -4194,8 +4204,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
trace_count[0] = 0
|
||||
disabled = def_function.function(func, experimental_follow_type_hints=False)
|
||||
disabled(x=1, y=2)
|
||||
disabled(x=2, y=2,) # Retrace
|
||||
disabled(0, 0, x=1, y=2)
|
||||
disabled(0, 0, x=2, y=2,) # Retrace
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
|
||||
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
|
||||
@ -4319,8 +4329,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace
|
||||
self.assertEqual(trace_count[0], 4)
|
||||
|
||||
def testWithModuleNameScope(self):
|
||||
self.skipTest('b/166158748:function does not handle this case correctly.')
|
||||
def testWithExtraWrapper(self):
|
||||
|
||||
class Foo(module.Module):
|
||||
|
||||
@ -4329,16 +4338,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.var = None
|
||||
|
||||
@def_function.function
|
||||
@module.Module.with_name_scope
|
||||
@dummy_tf_decorator
|
||||
def add(self, x, y, z=1):
|
||||
if self.var is None:
|
||||
return x + y + z
|
||||
|
||||
foo = Foo()
|
||||
self.assertEqual(foo.add(2, 3), 6)
|
||||
self.assertEqual(foo.add(2, 3).numpy(), 6)
|
||||
|
||||
def testWithModuleNameScopeRedundantArgs(self):
|
||||
self.skipTest('b/166158748:function does not handle this case correctly.')
|
||||
@parameterized.parameters([(def_function.function, dummy_tf_decorator),
|
||||
(dummy_tf_decorator, def_function.function),
|
||||
(def_function.function, def_function.function)])
|
||||
def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2):
|
||||
|
||||
class Foo(module.Module):
|
||||
|
||||
@ -4346,18 +4357,17 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
super().__init__()
|
||||
self.var = None
|
||||
|
||||
@def_function.function
|
||||
@module.Module.with_name_scope
|
||||
def add(self, x, y):
|
||||
@decorator1
|
||||
@decorator2
|
||||
def add1(self, x, y):
|
||||
if self.var is None:
|
||||
return x + y
|
||||
|
||||
foo = Foo()
|
||||
with self.assertRaisesRegex(TypeError, 'got two values for argument'):
|
||||
foo.add(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
|
||||
foo.add1(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
|
||||
|
||||
def testWithModuleNameScopeMissingArgs(self):
|
||||
self.skipTest('b/166158748:function does not handle this case correctly.')
|
||||
def testWithExtraWrapperMissingArgs(self):
|
||||
|
||||
class Foo(module.Module):
|
||||
|
||||
@ -4366,14 +4376,115 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.var = None
|
||||
|
||||
@def_function.function
|
||||
@module.Module.with_name_scope
|
||||
def add(self, x, y):
|
||||
@dummy_tf_decorator
|
||||
def add1(self, x, y):
|
||||
if self.var is None:
|
||||
return x + y
|
||||
|
||||
@def_function.function
|
||||
@dummy_tf_decorator
|
||||
def add2(self, x, y):
|
||||
if self.var is None:
|
||||
return x + y
|
||||
|
||||
@def_function.function
|
||||
@def_function.function
|
||||
def add3(self, x, y):
|
||||
if self.var is None:
|
||||
return x + y
|
||||
|
||||
foo = Foo()
|
||||
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
|
||||
foo.add(2) # pylint: disable=no-value-for-parameter
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'missing 1 required positional argument: \'y\''):
|
||||
foo.add1(2) # pylint: disable=no-value-for-parameter
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
|
||||
foo.add1(y=2) # pylint: disable=no-value-for-parameter
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'missing 1 required positional argument: \'y\''):
|
||||
foo.add2(2) # pylint: disable=no-value-for-parameter
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
|
||||
foo.add2(y=2) # pylint: disable=no-value-for-parameter
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'missing 1 required positional argument: \'y\''):
|
||||
foo.add3(2) # pylint: disable=no-value-for-parameter
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
|
||||
foo.add3(y=2) # pylint: disable=no-value-for-parameter
|
||||
|
||||
def testMissingArgsTfFunctionedMethod(self):
|
||||
|
||||
class A(object):
|
||||
|
||||
def func(self, position_arg1, position_arg2):
|
||||
return position_arg1, position_arg2
|
||||
|
||||
@def_function.function
|
||||
def decorated_method(self, position_arg1, position_arg2):
|
||||
return position_arg1, position_arg2
|
||||
|
||||
a_instance = A()
|
||||
tf_method_pos = def_function.function(a_instance.func)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||
tf_method_pos(position_arg2='foo')
|
||||
|
||||
# tf.function-decorated instance methods need to be tested because of
|
||||
# the __get__ method implementation.
|
||||
tf_func_decorated_method = def_function.function(
|
||||
a_instance.decorated_method)
|
||||
tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||
tf_func_decorated_method(position_arg2='bar')
|
||||
|
||||
def testMissingArgsTfFunctionedObject(self):
|
||||
|
||||
class A(object):
|
||||
|
||||
def __call__(self, position_arg1, position_arg2):
|
||||
return position_arg1, position_arg2
|
||||
|
||||
a_instance = A()
|
||||
|
||||
# A tf.function-decorated callable object needs to be tested because of
|
||||
# the special inspect results.
|
||||
tf_func_obj = def_function.function(a_instance)
|
||||
tf_func_obj(position_arg1=1, position_arg2=2)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||
tf_func_obj(position_arg2='bar')
|
||||
|
||||
def testMissingArgsTfFunctionedFunctions(self):
|
||||
|
||||
def func_pos(position_arg1, position_arg2):
|
||||
return position_arg1, position_arg2
|
||||
|
||||
def func_with_default(position_arg, named_arg=None):
|
||||
return position_arg, named_arg
|
||||
|
||||
def func_pos_3args(position_arg1, position_arg2, position_arg3):
|
||||
return position_arg1, position_arg2, position_arg3
|
||||
|
||||
tf_func_pos = def_function.function(func_pos)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||
tf_func_pos(position_arg2='foo')
|
||||
|
||||
tf_func_with_default = def_function.function(func_with_default)
|
||||
tf_func_with_default(position_arg='bar')
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'.* missing 1 required argument: position_arg'):
|
||||
tf_func_with_default(named_arg='foo')
|
||||
|
||||
tf_func_pos_3args = def_function.function(func_pos_3args)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
'.* missing required arguments: position_arg1, position_arg3'):
|
||||
tf_func_pos_3args(position_arg2='foo')
|
||||
|
||||
def testShapeInferencePropagateConstNestedStack(self):
|
||||
|
||||
|
@ -406,6 +406,20 @@ def ismethod(object): # pylint: disable=redefined-builtin
|
||||
return _inspect.ismethod(tf_decorator.unwrap(object)[1])
|
||||
|
||||
|
||||
def isanytargetmethod(object): # pylint: disable=redefined-builtin
|
||||
# pylint: disable=g-doc-args,g-doc-return-or-yield
|
||||
"""Checks all the decorated targets along the chain of decorators.
|
||||
|
||||
Returns True if any of the decorated targets in the chain is a method.
|
||||
"""
|
||||
decorators, _ = tf_decorator.unwrap(object)
|
||||
for decorator in decorators:
|
||||
if _inspect.ismethod(decorator.decorated_target):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def ismodule(object): # pylint: disable=redefined-builtin
|
||||
"""TFDecorator-aware replacement for inspect.ismodule."""
|
||||
return _inspect.ismodule(tf_decorator.unwrap(object)[1])
|
||||
|
Loading…
x
Reference in New Issue
Block a user