Internal change

PiperOrigin-RevId: 350707357
Change-Id: I6f4fb57a6a943ee39ce87e32dbdcd4262004e9c0
This commit is contained in:
A. Unique TensorFlower 2021-01-07 23:20:36 -08:00 committed by TensorFlower Gardener
parent 940ed6b14e
commit ae1b19aaaa
2 changed files with 8 additions and 65 deletions

View File

@ -1000,46 +1000,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
def testFunctionArgumentChecking(self):
class A(object):
def func(self, position_arg1, position_arg2):
return position_arg1, position_arg2
def func_pos(position_arg1, position_arg2):
return position_arg1, position_arg2
def func_named(position_arg, named_arg=None):
return position_arg, named_arg
def func_pos_3args(position_arg1, position_arg2, position_arg3):
return position_arg1, position_arg2, position_arg3
a_instance = A()
tf_method_pos = def_function.function(a_instance.func)
tf_func_pos = def_function.function(func_pos)
tf_func_named = def_function.function(func_named)
tf_func_pos_3args = def_function.function(func_pos_3args)
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_method_pos(position_arg2='foo')
with self.assertRaisesRegex(
TypeError, '.* missing 1 required argument: position_arg1'):
tf_func_pos(position_arg2='foo')
with self.assertRaisesRegex(TypeError,
'.* missing 1 required argument: position_arg'):
tf_func_named(named_arg='foo')
with self.assertRaisesRegex(
TypeError,
'.* missing required arguments: position_arg1, position_arg3'):
tf_func_pos_3args(position_arg2='foo')
tf_func_named(position_arg='bar')
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@ -2494,8 +2494,6 @@ class FunctionSpec(object):
offset + index: default
for index, default in enumerate(default_values or [])
}
self._arg_indices_no_default_values = set(range(len(args))) - set(
self._arg_indices_to_default_values)
if input_signature is None:
self._input_signature = None
else:
@ -2656,14 +2654,12 @@ class FunctionSpec(object):
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
if self._experimental_follow_type_hints:
args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
# Pre-calculate to reduce overhead
arglen = len(args)
if self._input_signature is not None:
if arglen > len(self._input_signature):
if len(args) > len(self._input_signature):
raise TypeError("{} takes {} positional arguments (as specified by the "
"input_signature) but {} were given".format(
self.signature_summary(),
len(self._input_signature), arglen))
len(self._input_signature), len(args)))
for arg in six.iterkeys(kwargs):
index = self._args_to_indices.get(arg, None)
if index is None:
@ -2678,12 +2674,13 @@ class FunctionSpec(object):
inputs = args
if self._arg_indices_to_default_values:
try:
inputs += tuple(self._arg_indices_to_default_values[i]
for i in range(arglen, len(self._arg_names)))
inputs += tuple(
self._arg_indices_to_default_values[i]
for i in range(len(args), len(self._arg_names)))
except KeyError:
missing_args = [
self._arg_names[i]
for i in range(arglen, len(self._arg_names))
for i in range(len(args), len(self._arg_names))
if i not in self._arg_indices_to_default_values
]
raise TypeError("{} missing required arguments: {}".format(
@ -2697,36 +2694,22 @@ class FunctionSpec(object):
# aren't in `args`.
arg_indices_to_values = {
index: default for index, default in six.iteritems(
self._arg_indices_to_default_values) if index >= arglen
self._arg_indices_to_default_values) if index >= len(args)
}
consumed_args = []
missing_arg_indices = self._arg_indices_no_default_values - set(
range(arglen))
for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
if index < arglen:
if index < len(args):
raise TypeError("{} got two values for argument '{}'".format(
self.signature_summary(), arg))
arg_indices_to_values[index] = value
# These arguments in 'kwargs' might also belong to
# positional arguments
missing_arg_indices.discard(index)
consumed_args.append(arg)
for arg in consumed_args:
# After this loop, `kwargs` will only contain keyword_only arguments,
# and all positional_or_keyword arguments have been moved to `inputs`.
kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
# Exclude positional args with values
if missing_arg_indices:
missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
if len(missing_args) == 1:
raise TypeError("{} missing 1 required argument: {}".format(
self.signature_summary(), missing_args[0]))
else:
raise TypeError("{} missing required arguments: {}".format(
self.signature_summary(), ", ".join(missing_args)))
if kwargs and self._input_signature is not None:
raise TypeError(