From b55f61086ad3456a87ad7b12df920db33904ad51 Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Wed, 13 Mar 2019 10:32:17 -0700
Subject: [PATCH] Make Graph().create_op() dtype argument optional.

This makes the signature of create_op closer to the signature of execute, which will allow us to merge the graph and eager execution codepaths.

PiperOrigin-RevId: 238253922
---
 tensorflow/python/framework/func_graph.py        |  6 +++---
 tensorflow/python/framework/function.py          |  5 +++--
 tensorflow/python/framework/op_def_library.py    | 16 ++--------------
 tensorflow/python/framework/ops.py               |  6 +++---
 .../tools/api/golden/v1/tensorflow.-graph.pbtxt  |  2 +-
 .../tools/api/golden/v2/tensorflow.-graph.pbtxt  |  2 +-
 6 files changed, 13 insertions(+), 24 deletions(-)

diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index 0b9593ccea1..aafab297ca1 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -384,7 +384,7 @@ class FuncGraph(ops.Graph):
       self,
       op_type,
       inputs,
-      dtypes,  # pylint: disable=redefined-outer-name
+      dtypes=None,  # pylint: disable=redefined-outer-name
       input_types=None,
       name=None,
       attrs=None,
@@ -401,8 +401,8 @@ class FuncGraph(ops.Graph):
       op_type: The `Operation` type to create. This corresponds to the
         `OpDef.name` field for the proto that defines the operation.
       inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
-      dtypes: A list of `DType` objects that will be the types of the tensors
-        that the operation produces.
+      dtypes: (Optional) A list of `DType` objects that will be the types of the
+        tensors that the operation produces.
       input_types: (Optional.) A list of `DType`s that will be the types of
         the tensors that the operation consumes. By default, uses the base
         `DType` of each input in `inputs`. Operations that expect
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index d7d069872ba..5589c93cc46 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -755,11 +755,12 @@ class _FuncGraph(ops.Graph):
         return var.value()
       return var
 
-  def create_op(self, op_type, inputs, data_types, **kwargs):
+  def create_op(self, op_type, inputs, data_types=None, **kwargs):
     for i, x in enumerate(inputs):
       if isinstance(x, ops.EagerTensor) or x.graph is not self:
         inputs[i] = self.capture(x)
-    return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
+    return super(_FuncGraph, self).create_op(op_type, inputs,
+                                             dtypes=data_types,
                                              **kwargs)
 
   def capture(self, tensor, name=None):
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index 372763a862b..2e3757b9316 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -759,31 +759,19 @@ class OpDefLibrary(object):
       del attrs  # attrs is no longer authoritative, use attr_protos instead
 
       # Determine output types (possibly using attrs)
-      output_types = []
       output_structure = []
       for arg in op_def.output_arg:
-        types = []
         if arg.number_attr:
           n = _AttrValue(attr_protos, arg.number_attr).i
-          if arg.type_attr:
-            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
-          else:
-            types = [arg.type] * n
           output_structure.append(n)
         elif arg.type_attr:
           t = _AttrValue(attr_protos, arg.type_attr)
-          types = [t.type]
           output_structure.append(None)
         elif arg.type_list_attr:
           t = _AttrValue(attr_protos, arg.type_list_attr)
-          types = t.list.type
-          output_structure.append(len(types))
+          output_structure.append(len(t.list.type))
         else:
-          types = [arg.type]
           output_structure.append(None)
-        if arg.is_ref:
-          types = [dtypes.as_dtype(x)._as_ref for x in types]  # pylint: disable=protected-access
-        output_types.extend(types)
 
       if keywords:
         raise TypeError("apply_op() got unexpected keyword arguments: " +
@@ -795,7 +783,7 @@ class OpDefLibrary(object):
                               if arg.is_ref]
       with _MaybeColocateWith(must_colocate_inputs):
         # Add Op to graph
-        op = g.create_op(op_type_name, inputs, output_types, name=scope,
+        op = g.create_op(op_type_name, inputs, name=scope,
                          input_types=input_types, attrs=attr_protos,
                          op_def=op_def)
       return output_structure, op_def.is_stateful, op
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 267b3873303..8dfcf381626 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3402,7 +3402,7 @@ class Graph(object):
       self,
       op_type,
       inputs,
-      dtypes,  # pylint: disable=redefined-outer-name
+      dtypes=None,  # pylint: disable=redefined-outer-name
       input_types=None,
       name=None,
       attrs=None,
@@ -3420,8 +3420,8 @@ class Graph(object):
       op_type: The `Operation` type to create. This corresponds to the
         `OpDef.name` field for the proto that defines the operation.
       inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
-      dtypes: A list of `DType` objects that will be the types of the tensors
-        that the operation produces.
+      dtypes: (Optional) A list of `DType` objects that will be the types of the
+        tensors that the operation produces.
       input_types: (Optional.) A list of `DType`s that will be the types of
         the tensors that the operation consumes. By default, uses the base
         `DType` of each input in `inputs`. Operations that expect
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt
index cdaeb55e308..9193168c207 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt
@@ -68,7 +68,7 @@ tf_class {
   }
   member_method {
     name: "create_op"
-    argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
+    argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
   }
   member_method {
     name: "device"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt
index cdaeb55e308..9193168c207 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt
@@ -68,7 +68,7 @@ tf_class {
   }
   member_method {
     name: "create_op"
-    argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
+    argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
   }
   member_method {
     name: "device"