Add checks to ensure all arguments have valid values during input canonicalization.
Before this update, the behavior of tf.function is not consistent with a "naked" function call: ``` >>> def func(position_arg, named_arg=None): return position_arg, named_arg >>> func(named_arg="Hello") TypeError: func() missing 1 required positional argument: 'position_arg' >>> tf_func = tf.function(func) >>> tf_func(named_arg="Hello") (<tf.Tensor: shape=(), dtype=string, numpy=b'Hello'>, None) After the update: >>> def func(position_arg, named_arg=None): return position_arg, named_arg >>> tf_func = tf.function(func) >>> tf_func(named_arg="Hello") TypeError: func(position_arg, named_arg) missing 1 required argument: position_arg ``` PiperOrigin-RevId: 353041646 Change-Id: Ia08d1177f152054281455fc04511ca54482fd22d
This commit is contained in:
parent
8cb8c460a3
commit
f48ae3edac
@ -2412,7 +2412,25 @@ class FunctionSpec(object):
|
|||||||
kwonlyargs=[],
|
kwonlyargs=[],
|
||||||
kwonlydefaults={},
|
kwonlydefaults={},
|
||||||
annotations=fullargspec.annotations)
|
annotations=fullargspec.annotations)
|
||||||
is_method = tf_inspect.ismethod(python_function)
|
|
||||||
|
# inspect.ismethod() and inspect.isfunction() both return False on a
|
||||||
|
# functools.partial-wrapped function. We set it to False to
|
||||||
|
# maintain consistency with prior versions.
|
||||||
|
is_method = False
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Instead of using tf_inspect.ismethod() which only checks the
|
||||||
|
# final unwrapped target, we check if any decorated target along the chain
|
||||||
|
# is a method.
|
||||||
|
is_method = tf_inspect.isanytargetmethod(python_function)
|
||||||
|
|
||||||
|
# In the following scenario, 'python_function' is a callable object.
|
||||||
|
# python_function(...) is equal to python_function.__call__(self, ...)
|
||||||
|
if not is_method and not tf_inspect.isfunction(
|
||||||
|
python_function) and hasattr(
|
||||||
|
python_function, "__class__") and hasattr(
|
||||||
|
python_function.__class__, "__call__"):
|
||||||
|
is_method = True
|
||||||
|
|
||||||
# Get the function's name. Remove functools.partial wrappers if necessary.
|
# Get the function's name. Remove functools.partial wrappers if necessary.
|
||||||
while isinstance(python_function, functools.partial):
|
while isinstance(python_function, functools.partial):
|
||||||
@ -2477,6 +2495,8 @@ class FunctionSpec(object):
|
|||||||
offset + index: default
|
offset + index: default
|
||||||
for index, default in enumerate(default_values or [])
|
for index, default in enumerate(default_values or [])
|
||||||
}
|
}
|
||||||
|
self._arg_indices_no_default_values = set(range(len(args))) - set(
|
||||||
|
self._arg_indices_to_default_values)
|
||||||
if input_signature is None:
|
if input_signature is None:
|
||||||
self._input_signature = None
|
self._input_signature = None
|
||||||
else:
|
else:
|
||||||
@ -2633,12 +2653,14 @@ class FunctionSpec(object):
|
|||||||
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
||||||
if self._experimental_follow_type_hints:
|
if self._experimental_follow_type_hints:
|
||||||
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
|
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
|
||||||
|
# Pre-calculate to reduce overhead
|
||||||
|
arglen = len(args)
|
||||||
if self._input_signature is not None:
|
if self._input_signature is not None:
|
||||||
if len(args) > len(self._input_signature):
|
if arglen > len(self._input_signature):
|
||||||
raise TypeError("{} takes {} positional arguments (as specified by the "
|
raise TypeError("{} takes {} positional arguments (as specified by the "
|
||||||
"input_signature) but {} were given".format(
|
"input_signature) but {} were given".format(
|
||||||
self.signature_summary(),
|
self.signature_summary(),
|
||||||
len(self._input_signature), len(args)))
|
len(self._input_signature), arglen))
|
||||||
for arg in six.iterkeys(kwargs):
|
for arg in six.iterkeys(kwargs):
|
||||||
index = self._args_to_indices.get(arg, None)
|
index = self._args_to_indices.get(arg, None)
|
||||||
if index is None:
|
if index is None:
|
||||||
@ -2653,13 +2675,12 @@ class FunctionSpec(object):
|
|||||||
inputs = args
|
inputs = args
|
||||||
if self._arg_indices_to_default_values:
|
if self._arg_indices_to_default_values:
|
||||||
try:
|
try:
|
||||||
inputs += tuple(
|
inputs += tuple(self._arg_indices_to_default_values[i]
|
||||||
self._arg_indices_to_default_values[i]
|
for i in range(arglen, len(self._arg_names)))
|
||||||
for i in range(len(args), len(self._arg_names)))
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
missing_args = [
|
missing_args = [
|
||||||
self._arg_names[i]
|
self._arg_names[i]
|
||||||
for i in range(len(args), len(self._arg_names))
|
for i in range(arglen, len(self._arg_names))
|
||||||
if i not in self._arg_indices_to_default_values
|
if i not in self._arg_indices_to_default_values
|
||||||
]
|
]
|
||||||
raise TypeError("{} missing required arguments: {}".format(
|
raise TypeError("{} missing required arguments: {}".format(
|
||||||
@ -2673,22 +2694,36 @@ class FunctionSpec(object):
|
|||||||
# aren't in `args`.
|
# aren't in `args`.
|
||||||
arg_indices_to_values = {
|
arg_indices_to_values = {
|
||||||
index: default for index, default in six.iteritems(
|
index: default for index, default in six.iteritems(
|
||||||
self._arg_indices_to_default_values) if index >= len(args)
|
self._arg_indices_to_default_values) if index >= arglen
|
||||||
}
|
}
|
||||||
consumed_args = []
|
consumed_args = []
|
||||||
|
missing_arg_indices = self._arg_indices_no_default_values - set(
|
||||||
|
range(arglen))
|
||||||
for arg, value in six.iteritems(kwargs):
|
for arg, value in six.iteritems(kwargs):
|
||||||
index = self._args_to_indices.get(arg, None)
|
index = self._args_to_indices.get(arg, None)
|
||||||
if index is not None:
|
if index is not None:
|
||||||
if index < len(args):
|
if index < arglen:
|
||||||
raise TypeError("{} got two values for argument '{}'".format(
|
raise TypeError("{} got two values for argument '{}'".format(
|
||||||
self.signature_summary(), arg))
|
self.signature_summary(), arg))
|
||||||
arg_indices_to_values[index] = value
|
arg_indices_to_values[index] = value
|
||||||
|
# These arguments in 'kwargs' might also belong to
|
||||||
|
# positional arguments
|
||||||
|
missing_arg_indices.discard(index)
|
||||||
consumed_args.append(arg)
|
consumed_args.append(arg)
|
||||||
for arg in consumed_args:
|
for arg in consumed_args:
|
||||||
# After this loop, `kwargs` will only contain keyword_only arguments,
|
# After this loop, `kwargs` will only contain keyword_only arguments,
|
||||||
# and all positional_or_keyword arguments have been moved to `inputs`.
|
# and all positional_or_keyword arguments have been moved to `inputs`.
|
||||||
kwargs.pop(arg)
|
kwargs.pop(arg)
|
||||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||||
|
# Exclude positional args with values
|
||||||
|
if missing_arg_indices:
|
||||||
|
missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
|
||||||
|
if len(missing_args) == 1:
|
||||||
|
raise TypeError("{} missing 1 required argument: {}".format(
|
||||||
|
self.signature_summary(), missing_args[0]))
|
||||||
|
else:
|
||||||
|
raise TypeError("{} missing required arguments: {}".format(
|
||||||
|
self.signature_summary(), ", ".join(missing_args)))
|
||||||
|
|
||||||
if kwargs and self._input_signature is not None:
|
if kwargs and self._input_signature is not None:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -3911,9 +3946,9 @@ def class_method_to_instance_method(original_function, instance):
|
|||||||
jit_compile=original_function._jit_compile)
|
jit_compile=original_function._jit_compile)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
# And we wrap the function with tf_decorator so inspection works correctly
|
# We wrap the the bound method with tf_decorator so inspection works correctly
|
||||||
wrapped_instance_func = tf_decorator.make_decorator(
|
wrapped_instance_func = tf_decorator.make_decorator(bound_method,
|
||||||
original_function.python_function, instance_func)
|
instance_func)
|
||||||
return wrapped_instance_func
|
return wrapped_instance_func
|
||||||
|
|
||||||
|
|
||||||
|
@ -83,6 +83,7 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.training import training_ops
|
from tensorflow.python.training import training_ops
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util import tf_decorator
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -119,6 +120,15 @@ def _spec_for_value(value):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
# This dummy decorator imitates ordinary decorators utilizing tf_decorator.
|
||||||
|
def dummy_tf_decorator(method):
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
return tf_decorator.make_decorator(method, wrapper)
|
||||||
|
|
||||||
|
|
||||||
class FunctionTest(test.TestCase, parameterized.TestCase):
|
class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -4194,8 +4204,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
trace_count[0] = 0
|
trace_count[0] = 0
|
||||||
disabled = def_function.function(func, experimental_follow_type_hints=False)
|
disabled = def_function.function(func, experimental_follow_type_hints=False)
|
||||||
disabled(x=1, y=2)
|
disabled(0, 0, x=1, y=2)
|
||||||
disabled(x=2, y=2,) # Retrace
|
disabled(0, 0, x=2, y=2,) # Retrace
|
||||||
self.assertEqual(trace_count[0], 2)
|
self.assertEqual(trace_count[0], 2)
|
||||||
|
|
||||||
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
|
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
|
||||||
@ -4319,8 +4329,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace
|
enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace
|
||||||
self.assertEqual(trace_count[0], 4)
|
self.assertEqual(trace_count[0], 4)
|
||||||
|
|
||||||
def testWithModuleNameScope(self):
|
def testWithExtraWrapper(self):
|
||||||
self.skipTest('b/166158748:function does not handle this case correctly.')
|
|
||||||
|
|
||||||
class Foo(module.Module):
|
class Foo(module.Module):
|
||||||
|
|
||||||
@ -4329,16 +4338,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.var = None
|
self.var = None
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
@module.Module.with_name_scope
|
@dummy_tf_decorator
|
||||||
def add(self, x, y, z=1):
|
def add(self, x, y, z=1):
|
||||||
if self.var is None:
|
if self.var is None:
|
||||||
return x + y + z
|
return x + y + z
|
||||||
|
|
||||||
foo = Foo()
|
foo = Foo()
|
||||||
self.assertEqual(foo.add(2, 3), 6)
|
self.assertEqual(foo.add(2, 3).numpy(), 6)
|
||||||
|
|
||||||
def testWithModuleNameScopeRedundantArgs(self):
|
@parameterized.parameters([(def_function.function, dummy_tf_decorator),
|
||||||
self.skipTest('b/166158748:function does not handle this case correctly.')
|
(dummy_tf_decorator, def_function.function),
|
||||||
|
(def_function.function, def_function.function)])
|
||||||
|
def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2):
|
||||||
|
|
||||||
class Foo(module.Module):
|
class Foo(module.Module):
|
||||||
|
|
||||||
@ -4346,18 +4357,17 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.var = None
|
self.var = None
|
||||||
|
|
||||||
@def_function.function
|
@decorator1
|
||||||
@module.Module.with_name_scope
|
@decorator2
|
||||||
def add(self, x, y):
|
def add1(self, x, y):
|
||||||
if self.var is None:
|
if self.var is None:
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
foo = Foo()
|
foo = Foo()
|
||||||
with self.assertRaisesRegex(TypeError, 'got two values for argument'):
|
with self.assertRaisesRegex(TypeError, 'got two values for argument'):
|
||||||
foo.add(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
|
foo.add1(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
|
||||||
|
|
||||||
def testWithModuleNameScopeMissingArgs(self):
|
def testWithExtraWrapperMissingArgs(self):
|
||||||
self.skipTest('b/166158748:function does not handle this case correctly.')
|
|
||||||
|
|
||||||
class Foo(module.Module):
|
class Foo(module.Module):
|
||||||
|
|
||||||
@ -4366,14 +4376,115 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.var = None
|
self.var = None
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
@module.Module.with_name_scope
|
@dummy_tf_decorator
|
||||||
def add(self, x, y):
|
def add1(self, x, y):
|
||||||
|
if self.var is None:
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
@dummy_tf_decorator
|
||||||
|
def add2(self, x, y):
|
||||||
|
if self.var is None:
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
@def_function.function
|
||||||
|
def add3(self, x, y):
|
||||||
if self.var is None:
|
if self.var is None:
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
foo = Foo()
|
foo = Foo()
|
||||||
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
|
with self.assertRaisesRegex(
|
||||||
foo.add(2) # pylint: disable=no-value-for-parameter
|
TypeError, 'missing 1 required positional argument: \'y\''):
|
||||||
|
foo.add1(2) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
|
||||||
|
foo.add1(y=2) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, 'missing 1 required positional argument: \'y\''):
|
||||||
|
foo.add2(2) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
|
||||||
|
foo.add2(y=2) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, 'missing 1 required positional argument: \'y\''):
|
||||||
|
foo.add3(2) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
|
||||||
|
foo.add3(y=2) # pylint: disable=no-value-for-parameter
|
||||||
|
|
||||||
|
def testMissingArgsTfFunctionedMethod(self):
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
|
||||||
|
def func(self, position_arg1, position_arg2):
|
||||||
|
return position_arg1, position_arg2
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def decorated_method(self, position_arg1, position_arg2):
|
||||||
|
return position_arg1, position_arg2
|
||||||
|
|
||||||
|
a_instance = A()
|
||||||
|
tf_method_pos = def_function.function(a_instance.func)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||||
|
tf_method_pos(position_arg2='foo')
|
||||||
|
|
||||||
|
# tf.function-decorated instance methods need to be tested because of
|
||||||
|
# the __get__ method implementation.
|
||||||
|
tf_func_decorated_method = def_function.function(
|
||||||
|
a_instance.decorated_method)
|
||||||
|
tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||||
|
tf_func_decorated_method(position_arg2='bar')
|
||||||
|
|
||||||
|
def testMissingArgsTfFunctionedObject(self):
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
|
||||||
|
def __call__(self, position_arg1, position_arg2):
|
||||||
|
return position_arg1, position_arg2
|
||||||
|
|
||||||
|
a_instance = A()
|
||||||
|
|
||||||
|
# A tf.function-decorated callable object needs to be tested because of
|
||||||
|
# the special inspect results.
|
||||||
|
tf_func_obj = def_function.function(a_instance)
|
||||||
|
tf_func_obj(position_arg1=1, position_arg2=2)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||||
|
tf_func_obj(position_arg2='bar')
|
||||||
|
|
||||||
|
def testMissingArgsTfFunctionedFunctions(self):
|
||||||
|
|
||||||
|
def func_pos(position_arg1, position_arg2):
|
||||||
|
return position_arg1, position_arg2
|
||||||
|
|
||||||
|
def func_with_default(position_arg, named_arg=None):
|
||||||
|
return position_arg, named_arg
|
||||||
|
|
||||||
|
def func_pos_3args(position_arg1, position_arg2, position_arg3):
|
||||||
|
return position_arg1, position_arg2, position_arg3
|
||||||
|
|
||||||
|
tf_func_pos = def_function.function(func_pos)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, '.* missing 1 required argument: position_arg1'):
|
||||||
|
tf_func_pos(position_arg2='foo')
|
||||||
|
|
||||||
|
tf_func_with_default = def_function.function(func_with_default)
|
||||||
|
tf_func_with_default(position_arg='bar')
|
||||||
|
with self.assertRaisesRegex(TypeError,
|
||||||
|
'.* missing 1 required argument: position_arg'):
|
||||||
|
tf_func_with_default(named_arg='foo')
|
||||||
|
|
||||||
|
tf_func_pos_3args = def_function.function(func_pos_3args)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError,
|
||||||
|
'.* missing required arguments: position_arg1, position_arg3'):
|
||||||
|
tf_func_pos_3args(position_arg2='foo')
|
||||||
|
|
||||||
def testShapeInferencePropagateConstNestedStack(self):
|
def testShapeInferencePropagateConstNestedStack(self):
|
||||||
|
|
||||||
|
@ -406,6 +406,20 @@ def ismethod(object): # pylint: disable=redefined-builtin
|
|||||||
return _inspect.ismethod(tf_decorator.unwrap(object)[1])
|
return _inspect.ismethod(tf_decorator.unwrap(object)[1])
|
||||||
|
|
||||||
|
|
||||||
|
def isanytargetmethod(object): # pylint: disable=redefined-builtin
|
||||||
|
# pylint: disable=g-doc-args,g-doc-return-or-yield
|
||||||
|
"""Checks all the decorated targets along the chain of decorators.
|
||||||
|
|
||||||
|
Returns True if any of the decorated targets in the chain is a method.
|
||||||
|
"""
|
||||||
|
decorators, _ = tf_decorator.unwrap(object)
|
||||||
|
for decorator in decorators:
|
||||||
|
if _inspect.ismethod(decorator.decorated_target):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def ismodule(object): # pylint: disable=redefined-builtin
|
def ismodule(object): # pylint: disable=redefined-builtin
|
||||||
"""TFDecorator-aware replacement for inspect.ismodule."""
|
"""TFDecorator-aware replacement for inspect.ismodule."""
|
||||||
return _inspect.ismodule(tf_decorator.unwrap(object)[1])
|
return _inspect.ismodule(tf_decorator.unwrap(object)[1])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user