diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 6f235b6475b..bc41a070e0b 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -1000,46 +1000,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3) self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2) - def testFunctionArgumentChecking(self): - - class A(object): - - def func(self, position_arg1, position_arg2): - return position_arg1, position_arg2 - - def func_pos(position_arg1, position_arg2): - return position_arg1, position_arg2 - - def func_named(position_arg, named_arg=None): - return position_arg, named_arg - - def func_pos_3args(position_arg1, position_arg2, position_arg3): - return position_arg1, position_arg2, position_arg3 - - a_instance = A() - tf_method_pos = def_function.function(a_instance.func) - tf_func_pos = def_function.function(func_pos) - tf_func_named = def_function.function(func_named) - tf_func_pos_3args = def_function.function(func_pos_3args) - with self.assertRaisesRegex( - TypeError, '.* missing 1 required argument: position_arg1'): - tf_method_pos(position_arg2='foo') - - with self.assertRaisesRegex( - TypeError, '.* missing 1 required argument: position_arg1'): - tf_func_pos(position_arg2='foo') - - with self.assertRaisesRegex(TypeError, - '.* missing 1 required argument: position_arg'): - tf_func_named(named_arg='foo') - - with self.assertRaisesRegex( - TypeError, - '.* missing required arguments: position_arg1, position_arg3'): - tf_func_pos_3args(position_arg2='foo') - - tf_func_named(position_arg='bar') - if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 3a8fad827ed..828af8f52e8 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2494,8 +2494,6 @@ class FunctionSpec(object): offset + index: default for index, default in enumerate(default_values or []) } - self._arg_indices_no_default_values = set(range(len(args))) - set( - self._arg_indices_to_default_values) if input_signature is None: self._input_signature = None else: @@ -2656,14 +2654,12 @@ class FunctionSpec(object): args, kwargs = self._convert_variables_to_tensors(args, kwargs) if self._experimental_follow_type_hints: args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs) - # Pre-calculate to reduce overhead - arglen = len(args) if self._input_signature is not None: - if arglen > len(self._input_signature): + if len(args) > len(self._input_signature): raise TypeError("{} takes {} positional arguments (as specified by the " "input_signature) but {} were given".format( self.signature_summary(), - len(self._input_signature), arglen)) + len(self._input_signature), len(args))) for arg in six.iterkeys(kwargs): index = self._args_to_indices.get(arg, None) if index is None: @@ -2678,12 +2674,13 @@ class FunctionSpec(object): inputs = args if self._arg_indices_to_default_values: try: - inputs += tuple(self._arg_indices_to_default_values[i] - for i in range(arglen, len(self._arg_names))) + inputs += tuple( + self._arg_indices_to_default_values[i] + for i in range(len(args), len(self._arg_names))) except KeyError: missing_args = [ self._arg_names[i] - for i in range(arglen, len(self._arg_names)) + for i in range(len(args), len(self._arg_names)) if i not in self._arg_indices_to_default_values ] raise TypeError("{} missing required arguments: {}".format( @@ -2697,36 +2694,22 @@ class FunctionSpec(object): # aren't in `args`. arg_indices_to_values = { index: default for index, default in six.iteritems( - self._arg_indices_to_default_values) if index >= arglen + self._arg_indices_to_default_values) if index >= len(args) } consumed_args = [] - missing_arg_indices = self._arg_indices_no_default_values - set( - range(arglen)) for arg, value in six.iteritems(kwargs): index = self._args_to_indices.get(arg, None) if index is not None: - if index < arglen: + if index < len(args): raise TypeError("{} got two values for argument '{}'".format( self.signature_summary(), arg)) arg_indices_to_values[index] = value - # These arguments in 'kwargs' might also belong to - # positional arguments - missing_arg_indices.discard(index) consumed_args.append(arg) for arg in consumed_args: # After this loop, `kwargs` will only contain keyword_only arguments, # and all positional_or_keyword arguments have been moved to `inputs`. kwargs.pop(arg) inputs = args + _deterministic_dict_values(arg_indices_to_values) - # Exclude positional args with values - if missing_arg_indices: - missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)] - if len(missing_args) == 1: - raise TypeError("{} missing 1 required argument: {}".format( - self.signature_summary(), missing_args[0])) - else: - raise TypeError("{} missing required arguments: {}".format( - self.signature_summary(), ", ".join(missing_args))) if kwargs and self._input_signature is not None: raise TypeError(