From b55f61086ad3456a87ad7b12df920db33904ad51 Mon Sep 17 00:00:00 2001 From: Alexandre Passos <apassos@google.com> Date: Wed, 13 Mar 2019 10:32:17 -0700 Subject: [PATCH] Make Graph().create_op() dtype argument optional. This makes the signature of create_op closer to the signature of execute, which will allow us to merge the graph and eager execution codepaths. PiperOrigin-RevId: 238253922 --- tensorflow/python/framework/func_graph.py | 6 +++--- tensorflow/python/framework/function.py | 5 +++-- tensorflow/python/framework/op_def_library.py | 16 ++-------------- tensorflow/python/framework/ops.py | 6 +++--- .../tools/api/golden/v1/tensorflow.-graph.pbtxt | 2 +- .../tools/api/golden/v2/tensorflow.-graph.pbtxt | 2 +- 6 files changed, 13 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 0b9593ccea1..aafab297ca1 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -384,7 +384,7 @@ class FuncGraph(ops.Graph): self, op_type, inputs, - dtypes, # pylint: disable=redefined-outer-name + dtypes=None, # pylint: disable=redefined-outer-name input_types=None, name=None, attrs=None, @@ -401,8 +401,8 @@ class FuncGraph(ops.Graph): op_type: The `Operation` type to create. This corresponds to the `OpDef.name` field for the proto that defines the operation. inputs: A list of `Tensor` objects that will be inputs to the `Operation`. - dtypes: A list of `DType` objects that will be the types of the tensors - that the operation produces. + dtypes: (Optional) A list of `DType` objects that will be the types of the + tensors that the operation produces. input_types: (Optional.) A list of `DType`s that will be the types of the tensors that the operation consumes. By default, uses the base `DType` of each input in `inputs`. Operations that expect diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index d7d069872ba..5589c93cc46 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -755,11 +755,12 @@ class _FuncGraph(ops.Graph): return var.value() return var - def create_op(self, op_type, inputs, data_types, **kwargs): + def create_op(self, op_type, inputs, data_types=None, **kwargs): for i, x in enumerate(inputs): if isinstance(x, ops.EagerTensor) or x.graph is not self: inputs[i] = self.capture(x) - return super(_FuncGraph, self).create_op(op_type, inputs, data_types, + return super(_FuncGraph, self).create_op(op_type, inputs, + dtypes=data_types, **kwargs) def capture(self, tensor, name=None): diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 372763a862b..2e3757b9316 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -759,31 +759,19 @@ class OpDefLibrary(object): del attrs # attrs is no longer authoritative, use attr_protos instead # Determine output types (possibly using attrs) - output_types = [] output_structure = [] for arg in op_def.output_arg: - types = [] if arg.number_attr: n = _AttrValue(attr_protos, arg.number_attr).i - if arg.type_attr: - types = [_AttrValue(attr_protos, arg.type_attr).type] * n - else: - types = [arg.type] * n output_structure.append(n) elif arg.type_attr: t = _AttrValue(attr_protos, arg.type_attr) - types = [t.type] output_structure.append(None) elif arg.type_list_attr: t = _AttrValue(attr_protos, arg.type_list_attr) - types = t.list.type - output_structure.append(len(types)) + output_structure.append(len(t.list.type)) else: - types = [arg.type] output_structure.append(None) - if arg.is_ref: - types = [dtypes.as_dtype(x)._as_ref for x in types] # pylint: disable=protected-access - output_types.extend(types) if keywords: raise TypeError("apply_op() got unexpected keyword arguments: " + @@ -795,7 +783,7 @@ class OpDefLibrary(object): if arg.is_ref] with _MaybeColocateWith(must_colocate_inputs): # Add Op to graph - op = g.create_op(op_type_name, inputs, output_types, name=scope, + op = g.create_op(op_type_name, inputs, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) return output_structure, op_def.is_stateful, op diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 267b3873303..8dfcf381626 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3402,7 +3402,7 @@ class Graph(object): self, op_type, inputs, - dtypes, # pylint: disable=redefined-outer-name + dtypes=None, # pylint: disable=redefined-outer-name input_types=None, name=None, attrs=None, @@ -3420,8 +3420,8 @@ class Graph(object): op_type: The `Operation` type to create. This corresponds to the `OpDef.name` field for the proto that defines the operation. inputs: A list of `Tensor` objects that will be inputs to the `Operation`. - dtypes: A list of `DType` objects that will be the types of the tensors - that the operation produces. + dtypes: (Optional) A list of `DType` objects that will be the types of the + tensors that the operation produces. input_types: (Optional.) A list of `DType`s that will be the types of the tensors that the operation consumes. By default, uses the base `DType` of each input in `inputs`. Operations that expect diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt index cdaeb55e308..9193168c207 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt @@ -68,7 +68,7 @@ tf_class { } member_method { name: "create_op" - argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], " + argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], " } member_method { name: "device" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt index cdaeb55e308..9193168c207 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt @@ -68,7 +68,7 @@ tf_class { } member_method { name: "create_op" - argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], " + argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], " } member_method { name: "device"