From 4d35b20572f3eaad09e1681371d2d09cec7bb8ef Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Mon, 30 Sep 2019 09:15:54 -0700
Subject: [PATCH] For ops with an input for the form "a: N * T", if N=0 and
 there's no other information about what type T should resolve to, then set T
 to its default value (if one is present).

PiperOrigin-RevId: 271996738
---
 tensorflow/python/eager/execute.py            |  2 ++
 tensorflow/python/eager/ops_test.py           | 14 ++++----
 tensorflow/python/framework/op_def_library.py | 25 +++++++++------
 .../python/framework/op_def_library_test.py   | 32 ++++++++++++++++---
 tensorflow/python/framework/test_ops.cc       |  2 +-
 5 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 30c9ca6217d..92e22beefae 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -235,6 +235,8 @@ def make_tensor(v, arg_name):
 
 def args_to_matching_eager(l, ctx, default_dtype=None):
   """Convert sequence `l` to eager same-type Tensors."""
+  if (not l) and (default_dtype is not None):
+    return default_dtype, []  # List is empty; assume default dtype.
   EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
   for x in l:
     if not isinstance(x, EagerTensor):
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index cebf786d6a8..c7748fd12e1 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -326,16 +326,18 @@ class OpsTest(test_util.TensorFlowTestCase):
     # Uses default
     ctx = context.context()
     t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int32)
-    self.assertEquals(t, dtypes.int32)
-    self.assertEquals(r[0].dtype, dtypes.int32)
+    self.assertEqual(t, dtypes.int32)
+    self.assertEqual(r[0].dtype, dtypes.int32)
     t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int64)
-    self.assertEquals(t, dtypes.int64)
-    self.assertEquals(r[0].dtype, dtypes.int64)
+    self.assertEqual(t, dtypes.int64)
+    self.assertEqual(r[0].dtype, dtypes.int64)
+    t, r = execute.args_to_matching_eager([], ctx, dtypes.int64)
+    self.assertEqual(t, dtypes.int64)
     # Doesn't use default
     t, r = execute.args_to_matching_eager(
         [['string', 'arg']], ctx, dtypes.int32)
-    self.assertEquals(t, dtypes.string)
-    self.assertEquals(r[0].dtype, dtypes.string)
+    self.assertEqual(t, dtypes.string)
+    self.assertEqual(r[0].dtype, dtypes.string)
 
   def testFlattenLayer(self):
     flatten_layer = core.Flatten()
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index ec780f26d6b..e8479b10bed 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -545,15 +545,19 @@ def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=in
           # <number-attr> * <type-attr> case, where we are now setting
           # the <type-attr> based on this input
           if not base_types:
-            raise TypeError(
-                "Don't know how to infer type variable from empty input "
-                "list passed to input '%s' of '%s' Op." %
-                (input_name, op_type_name))
-          attrs[input_arg.type_attr] = base_types[0]
-          inferred_from[input_arg.type_attr] = input_name
-          type_attr = _Attr(op_def, input_arg.type_attr)
-          _SatisfiesTypeConstraint(base_types[0], type_attr,
-                                   param_name=input_name)
+            # If it's in default_type_attr_map, then wait to set it
+            # (in "process remaining attrs", below).
+            if input_arg.type_attr not in default_type_attr_map:
+              raise TypeError(
+                  "Don't know how to infer type variable from empty input "
+                  "list passed to input '%s' of '%s' Op." %
+                  (input_name, op_type_name))
+          else:
+            attrs[input_arg.type_attr] = base_types[0]
+            inferred_from[input_arg.type_attr] = input_name
+            type_attr = _Attr(op_def, input_arg.type_attr)
+            _SatisfiesTypeConstraint(base_types[0], type_attr,
+                                     param_name=input_name)
       elif input_arg.type_attr:
         # <type-attr>
         attr_value = base_types[0]
@@ -620,6 +624,9 @@ def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=in
         # Attrs whose names match Python keywords have an extra '_'
         # appended, so we must check for that as well.
         attrs[attr.name] = keywords.pop(attr.name + "_")
+      elif attr.name in default_type_attr_map:
+        attrs[attr.name] = default_type_attr_map[attr.name]
+        inferred_from.setdefault(attr.name, "Default in OpDef")
       else:
         raise TypeError("No argument for attr " + attr.name)
 
diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py
index 83990d7648b..dda42f246e0 100644
--- a/tensorflow/python/framework/op_def_library_test.py
+++ b/tensorflow/python/framework/op_def_library_test.py
@@ -1048,11 +1048,33 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
         attr { key: 'M' value { i: 0 } }
         """, op.node_def)
 
-      with self.assertRaises(TypeError) as cm:
-        op_def_library.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
-      self.assertEqual(str(cm.exception),
-                       "Don't know how to infer type variable from empty input "
-                       "list passed to input 'a' of 'InPolymorphicTwice' Op.")
+      op = op_def_library.apply_op(
+          "InPolymorphicTwice", a=[], b=[3, 4], name="p")
+      self.assertProtoEquals("""
+        name: 'p' op: 'InPolymorphicTwice' input: 'p/b_0' input: 'p/b_1'
+        attr { key: 'T' value { type: DT_INT32 } }
+        attr { key: 'N' value { i: 0 } }
+        attr { key: 'M' value { i: 2 } }
+        """, op.node_def)
+
+      op = op_def_library.apply_op(
+          "InPolymorphicTwice", a=[], b=[3.0, 4.0], name="q")
+      self.assertProtoEquals("""
+        name: 'q' op: 'InPolymorphicTwice' input: 'q/b_0' input: 'q/b_1'
+        attr { key: 'T' value { type: DT_FLOAT } }
+        attr { key: 'N' value { i: 0 } }
+        attr { key: 'M' value { i: 2 } }
+        """, op.node_def)
+
+      # Empty input lists: assume defaut type for T.
+      op = op_def_library.apply_op(
+          "InPolymorphicTwice", a=[], b=[], name="r")
+      self.assertProtoEquals("""
+        name: 'r' op: 'InPolymorphicTwice'
+        attr { key: 'T' value { type: DT_INT32 } }
+        attr { key: 'N' value { i: 0 } }
+        attr { key: 'M' value { i: 0 } }
+        """, op.node_def)
 
       with self.assertRaises(TypeError) as cm:
         op_def_library.apply_op(
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index 623081c2dc5..fc864692b7b 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -598,7 +598,7 @@ REGISTER_OP("NInTwoTypeVariables")
 REGISTER_OP("InPolymorphicTwice")
     .Input("a: N * T")
     .Input("b: M * T")
-    .Attr("T: type")
+    .Attr("T: type = DT_INT32")
     .Attr("N: int >= 0")
     .Attr("M: int >= 0")
     .SetShapeFn(shape_inference::UnknownShape);