From 318ed830c7af977457f1633cddd08fef8caa7a9e Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Thu, 23 May 2019 13:49:16 -0700
Subject: [PATCH] Auto-uniquify ConcreteFunction kwargs, and allow passing
 anything as a positional argument

Same format as ConcreteFunction.inputs without the captures, nest.flatten(args, expand_composites=True) + nest.flatten(kwargs, expand_composities=True) filtered for Tensors only.

ConcreteFunction won't complain about anything now. Actually calling the resulting function may be confusing, and we can in the future make a more restricted API which only accepts pure Tensor structures where the ConcreteFunction can be called with those structures. But taking mixed Python/Tensor structures is probably not a good idea.

PiperOrigin-RevId: 249708123
---
 tensorflow/python/eager/def_function_test.py  | 25 +++++++++++
 tensorflow/python/eager/function.py           | 41 ++++++-------------
 .../eager/function_argument_naming_test.py    |  7 ----
 tensorflow/python/eager/function_test.py      | 14 +++----
 tensorflow/python/saved_model/save_test.py    |  5 ---
 5 files changed, 45 insertions(+), 47 deletions(-)

diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index fb248459527..2fcfde86bad 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -448,6 +448,31 @@ class DefFunctionTest(test.TestCase):
         re.compile('An op outside of the function.*passed.*Const', re.DOTALL)):
       failing_function()
 
+  def testNonUniqueNamesGetConcreteFunction(self):
+    @def_function.function
+    def non_unique_arg_names(x, **kwargs):
+      a, b, c = x
+      d = kwargs['d']
+      return a + b + c + d
+
+    concrete = non_unique_arg_names.get_concrete_function(
+        (tensor_spec.TensorSpec(None, dtypes.float32),
+         tensor_spec.TensorSpec(None, dtypes.float32),
+         tensor_spec.TensorSpec(None, dtypes.float32)),
+        d=tensor_spec.TensorSpec(None, dtypes.float32))
+    self.assertAllClose(
+        10.,
+        concrete(x=constant_op.constant(1.),
+                 x_1=constant_op.constant(2.),
+                 x_2=constant_op.constant(3.),
+                 d=constant_op.constant(4.)))
+    self.assertAllClose(
+        10.,
+        concrete(constant_op.constant(1.),
+                 constant_op.constant(2.),
+                 constant_op.constant(3.),
+                 constant_op.constant(4.)))
+
   def testVariableCreatorScope(self):
     created_variables = []
     captured_variables = []
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 52f6ad908ab..24377614031 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1406,40 +1406,25 @@ class Function(object):
       kwargs = {}
     seen_names = set()
     captured = frozenset(graph_function.graph.internal_captures)
-    allowed_positional = 0
-    if args:
-      for outer_arg in args:
-        # TODO(allenl): Consider allowing arguments with defaults in the Python
-        # function's signature to be passed as positional arguments to the
-        # concrete function.
-        if not isinstance(
-            outer_arg,
-            (ops.Tensor, resource_variable_ops.ResourceVariable,
-             tensor_spec.TensorSpec)):
-          break
-        allowed_positional += 1
     # pylint: disable=protected-access
-    graph_function._num_positional_args = allowed_positional
     graph_function._arg_keywords = []
+    prefix_counts = {}
     # pylint: enable=protected-access
+    num_positional = 0
     for arg in graph_function.graph.inputs:
       if arg in captured:
         break
-      user_arg_name = arg.op.get_attr("_user_specified_name")
-      if user_arg_name in seen_names:
-        raise ValueError(
-            ("Unable to construct a concrete function for {} since some "
-             "arguments do not have unique names. Got two arguments named "
-             "'{}'. When constructing a concrete TensorFlow function from a "
-             "Python function which takes nested structures or variadic "
-             "positional arguments, pass unique names to tf.TensorSpec objects "
-             "used to identify these Tensor inputs. These names may then be "
-             "used as keyword arguments to the concrete function.")
-            .format(
-                self._python_function,
-                compat.as_str(arg.op.get_attr("_user_specified_name"))))
-      seen_names.add(user_arg_name)
-      graph_function._arg_keywords.append(user_arg_name)  # pylint: disable=protected-access
+      num_positional += 1
+      user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
+      proposal = user_arg_name
+      while proposal in seen_names:
+        index = prefix_counts.get(user_arg_name, 1)
+        proposal = "{}_{}".format(user_arg_name, index)
+        prefix_counts[user_arg_name] = index + 1
+      seen_names.add(proposal)
+      graph_function._arg_keywords.append(proposal)  # pylint: disable=protected-access
+    # Anything can be a positional argument, in the same order as .inputs
+    graph_function._num_positional_args = num_positional  # pylint: disable=protected-access
     return graph_function
 
   def __get__(self, instance, owner):
diff --git a/tensorflow/python/eager/function_argument_naming_test.py b/tensorflow/python/eager/function_argument_naming_test.py
index 08a50a8f513..4e6a60e0d27 100644
--- a/tensorflow/python/eager/function_argument_naming_test.py
+++ b/tensorflow/python/eager/function_argument_naming_test.py
@@ -100,13 +100,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase):
     self.assertEqual({'alpha', 'beta'},
                      set(fn_op.graph.structured_outputs.keys()))
 
-    with self.assertRaisesRegexp(ValueError, "two arguments named 'z'"):
-      fn.get_concrete_function(
-          z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32),
-             tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)),
-          y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32,
-                                   name='custom'),
-          x=4.)
     fn_op2 = fn.get_concrete_function(
         z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32,
                                   name='z_first'),
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 9ef84b86d75..84c6f529069 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -385,10 +385,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       ((a, b),) = mats
       return matmul(a, b)
 
-    with self.assertRaisesRegexp(ValueError, "two arguments named 'mats'"):
-      sq.get_concrete_function(
-          [(tensor_spec.TensorSpec((None, None), dtypes.float32),
-            tensor_spec.TensorSpec((None, None), dtypes.float32))])
+    sq_op_autonamed = sq.get_concrete_function(
+        [(tensor_spec.TensorSpec((None, None), dtypes.float32),
+          tensor_spec.TensorSpec((None, None), dtypes.float32))])
+    self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())
+
     sq_op = sq.get_concrete_function(
         [(tensor_spec.TensorSpec((None, None), dtypes.float32,
                                  name='first_mat'),
@@ -398,11 +399,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
 
     t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
     t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
-    with self.assertRaisesRegexp(
-        TypeError, 'bound to Tensors within nested structures'):
-      sq_op(t1, t2)
     out = sq_op(first_mat=t1, second_mat=t2)
     self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
+    self.assertAllEqual(sq_op_autonamed(t1, t2),
+                        math_ops.matmul(t1, t2).numpy())
 
   def testExecutingStatelessDefunConcurrently(self):
 
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index 14cb5abc075..a2ee86f6a9a 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -171,11 +171,6 @@ class SaveTest(test.TestCase):
         input_signature=([tensor_spec.TensorSpec(None, dtypes.float32),
                           tensor_spec.TensorSpec(None, dtypes.float32)],))
     root.f([constant_op.constant(1.), constant_op.constant(1.)])
-    # Concrete functions must always have uniquely named Tensor inputs. Save
-    # relies on this.
-    with self.assertRaisesRegexp(
-        ValueError, "two arguments named 'x'"):
-      root.f.get_concrete_function()
 
   def test_nested_outputs(self):
     root = tracking.AutoTrackable()