Add checks to ensure all arguments have valid values during input canonicalization.
Before this update, the behavior of tf.function is not consistent with a "naked" function call:
```
>>> def func(position_arg, named_arg=None):
      return position_arg, named_arg
>>> func(named_arg="Hello")
TypeError: func() missing 1 required positional argument: 'position_arg'
>>> tf_func = tf.function(func)
>>> tf_func(named_arg="Hello")
(<tf.Tensor: shape=(), dtype=string, numpy=b'Hello'>, None)
```
After the update:
```
>>> def func(position_arg, named_arg=None):
      return position_arg, named_arg
>>> tf_func = tf.function(func)
>>> tf_func(named_arg="Hello")
TypeError: func(position_arg, named_arg) missing 1 required argument: position_arg
```
PiperOrigin-RevId: 350703016
Change-Id: I371cfd041ab466c486fc1dc8a2e746e9a1721280
			
			
This commit is contained in:
		
							parent
							
								
									fe004863b7
								
							
						
					
					
						commit
						d40e992f9f
					
				@ -1000,6 +1000,46 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
 | 
			
		||||
    self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
 | 
			
		||||
    self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
 | 
			
		||||
 | 
			
		||||
  def testFunctionArgumentChecking(self):
 | 
			
		||||
 | 
			
		||||
    class A(object):
 | 
			
		||||
 | 
			
		||||
      def func(self, position_arg1, position_arg2):
 | 
			
		||||
        return position_arg1, position_arg2
 | 
			
		||||
 | 
			
		||||
    def func_pos(position_arg1, position_arg2):
 | 
			
		||||
      return position_arg1, position_arg2
 | 
			
		||||
 | 
			
		||||
    def func_named(position_arg, named_arg=None):
 | 
			
		||||
      return position_arg, named_arg
 | 
			
		||||
 | 
			
		||||
    def func_pos_3args(position_arg1, position_arg2, position_arg3):
 | 
			
		||||
      return position_arg1, position_arg2, position_arg3
 | 
			
		||||
 | 
			
		||||
    a_instance = A()
 | 
			
		||||
    tf_method_pos = def_function.function(a_instance.func)
 | 
			
		||||
    tf_func_pos = def_function.function(func_pos)
 | 
			
		||||
    tf_func_named = def_function.function(func_named)
 | 
			
		||||
    tf_func_pos_3args = def_function.function(func_pos_3args)
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        TypeError, '.* missing 1 required argument: position_arg1'):
 | 
			
		||||
      tf_method_pos(position_arg2='foo')
 | 
			
		||||
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        TypeError, '.* missing 1 required argument: position_arg1'):
 | 
			
		||||
      tf_func_pos(position_arg2='foo')
 | 
			
		||||
 | 
			
		||||
    with self.assertRaisesRegex(TypeError,
 | 
			
		||||
                                '.* missing 1 required argument: position_arg'):
 | 
			
		||||
      tf_func_named(named_arg='foo')
 | 
			
		||||
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        TypeError,
 | 
			
		||||
        '.* missing required arguments: position_arg1, position_arg3'):
 | 
			
		||||
      tf_func_pos_3args(position_arg2='foo')
 | 
			
		||||
 | 
			
		||||
    tf_func_named(position_arg='bar')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  ops.enable_eager_execution()
 | 
			
		||||
 | 
			
		||||
@ -2494,6 +2494,8 @@ class FunctionSpec(object):
 | 
			
		||||
        offset + index: default
 | 
			
		||||
        for index, default in enumerate(default_values or [])
 | 
			
		||||
    }
 | 
			
		||||
    self._arg_indices_no_default_values = set(range(len(args))) - set(
 | 
			
		||||
        self._arg_indices_to_default_values)
 | 
			
		||||
    if input_signature is None:
 | 
			
		||||
      self._input_signature = None
 | 
			
		||||
    else:
 | 
			
		||||
@ -2654,12 +2656,14 @@ class FunctionSpec(object):
 | 
			
		||||
      args, kwargs = self._convert_variables_to_tensors(args, kwargs)
 | 
			
		||||
    if self._experimental_follow_type_hints:
 | 
			
		||||
      args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
 | 
			
		||||
    # Pre-calculate to reduce overhead
 | 
			
		||||
    arglen = len(args)
 | 
			
		||||
    if self._input_signature is not None:
 | 
			
		||||
      if len(args) > len(self._input_signature):
 | 
			
		||||
      if arglen > len(self._input_signature):
 | 
			
		||||
        raise TypeError("{} takes {} positional arguments (as specified by the "
 | 
			
		||||
                        "input_signature) but {} were given".format(
 | 
			
		||||
                            self.signature_summary(),
 | 
			
		||||
                            len(self._input_signature), len(args)))
 | 
			
		||||
                            len(self._input_signature), arglen))
 | 
			
		||||
      for arg in six.iterkeys(kwargs):
 | 
			
		||||
        index = self._args_to_indices.get(arg, None)
 | 
			
		||||
        if index is None:
 | 
			
		||||
@ -2674,13 +2678,12 @@ class FunctionSpec(object):
 | 
			
		||||
      inputs = args
 | 
			
		||||
      if self._arg_indices_to_default_values:
 | 
			
		||||
        try:
 | 
			
		||||
          inputs += tuple(
 | 
			
		||||
              self._arg_indices_to_default_values[i]
 | 
			
		||||
              for i in range(len(args), len(self._arg_names)))
 | 
			
		||||
          inputs += tuple(self._arg_indices_to_default_values[i]
 | 
			
		||||
                          for i in range(arglen, len(self._arg_names)))
 | 
			
		||||
        except KeyError:
 | 
			
		||||
          missing_args = [
 | 
			
		||||
              self._arg_names[i]
 | 
			
		||||
              for i in range(len(args), len(self._arg_names))
 | 
			
		||||
              for i in range(arglen, len(self._arg_names))
 | 
			
		||||
              if i not in self._arg_indices_to_default_values
 | 
			
		||||
          ]
 | 
			
		||||
          raise TypeError("{} missing required arguments: {}".format(
 | 
			
		||||
@ -2694,22 +2697,36 @@ class FunctionSpec(object):
 | 
			
		||||
      # aren't in `args`.
 | 
			
		||||
      arg_indices_to_values = {
 | 
			
		||||
          index: default for index, default in six.iteritems(
 | 
			
		||||
              self._arg_indices_to_default_values) if index >= len(args)
 | 
			
		||||
              self._arg_indices_to_default_values) if index >= arglen
 | 
			
		||||
      }
 | 
			
		||||
      consumed_args = []
 | 
			
		||||
      missing_arg_indices = self._arg_indices_no_default_values - set(
 | 
			
		||||
          range(arglen))
 | 
			
		||||
      for arg, value in six.iteritems(kwargs):
 | 
			
		||||
        index = self._args_to_indices.get(arg, None)
 | 
			
		||||
        if index is not None:
 | 
			
		||||
          if index < len(args):
 | 
			
		||||
          if index < arglen:
 | 
			
		||||
            raise TypeError("{} got two values for argument '{}'".format(
 | 
			
		||||
                self.signature_summary(), arg))
 | 
			
		||||
          arg_indices_to_values[index] = value
 | 
			
		||||
          # These arguments in 'kwargs' might also belong to
 | 
			
		||||
          # positional arguments
 | 
			
		||||
          missing_arg_indices.discard(index)
 | 
			
		||||
          consumed_args.append(arg)
 | 
			
		||||
      for arg in consumed_args:
 | 
			
		||||
        # After this loop, `kwargs` will only contain keyword_only arguments,
 | 
			
		||||
        # and all positional_or_keyword arguments have been moved to `inputs`.
 | 
			
		||||
        kwargs.pop(arg)
 | 
			
		||||
      inputs = args + _deterministic_dict_values(arg_indices_to_values)
 | 
			
		||||
      # Exclude positional args with values
 | 
			
		||||
      if missing_arg_indices:
 | 
			
		||||
        missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
 | 
			
		||||
        if len(missing_args) == 1:
 | 
			
		||||
          raise TypeError("{} missing 1 required argument: {}".format(
 | 
			
		||||
              self.signature_summary(), missing_args[0]))
 | 
			
		||||
        else:
 | 
			
		||||
          raise TypeError("{} missing required arguments: {}".format(
 | 
			
		||||
              self.signature_summary(), ", ".join(missing_args)))
 | 
			
		||||
 | 
			
		||||
      if kwargs and self._input_signature is not None:
 | 
			
		||||
        raise TypeError(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user