Make Graph().create_op() dtype argument optional.

This makes the signature of create_op closer to the signature of execute, which will allow us to merge the graph and eager execution codepaths.

PiperOrigin-RevId: 238253922
This commit is contained in:
Alexandre Passos 2019-03-13 10:32:17 -07:00 committed by TensorFlower Gardener
parent 959cb35b46
commit b55f61086a
6 changed files with 13 additions and 24 deletions

View File

@ -384,7 +384,7 @@ class FuncGraph(ops.Graph):
self,
op_type,
inputs,
dtypes, # pylint: disable=redefined-outer-name
dtypes=None, # pylint: disable=redefined-outer-name
input_types=None,
name=None,
attrs=None,
@ -401,8 +401,8 @@ class FuncGraph(ops.Graph):
op_type: The `Operation` type to create. This corresponds to the
`OpDef.name` field for the proto that defines the operation.
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
dtypes: A list of `DType` objects that will be the types of the tensors
that the operation produces.
dtypes: (Optional) A list of `DType` objects that will be the types of the
tensors that the operation produces.
input_types: (Optional.) A list of `DType`s that will be the types of
the tensors that the operation consumes. By default, uses the base
`DType` of each input in `inputs`. Operations that expect

View File

@ -755,11 +755,12 @@ class _FuncGraph(ops.Graph):
return var.value()
return var
def create_op(self, op_type, inputs, data_types, **kwargs):
def create_op(self, op_type, inputs, data_types=None, **kwargs):
for i, x in enumerate(inputs):
if isinstance(x, ops.EagerTensor) or x.graph is not self:
inputs[i] = self.capture(x)
return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
return super(_FuncGraph, self).create_op(op_type, inputs,
dtypes=data_types,
**kwargs)
def capture(self, tensor, name=None):

View File

@ -759,31 +759,19 @@ class OpDefLibrary(object):
del attrs # attrs is no longer authoritative, use attr_protos instead
# Determine output types (possibly using attrs)
output_types = []
output_structure = []
for arg in op_def.output_arg:
types = []
if arg.number_attr:
n = _AttrValue(attr_protos, arg.number_attr).i
if arg.type_attr:
types = [_AttrValue(attr_protos, arg.type_attr).type] * n
else:
types = [arg.type] * n
output_structure.append(n)
elif arg.type_attr:
t = _AttrValue(attr_protos, arg.type_attr)
types = [t.type]
output_structure.append(None)
elif arg.type_list_attr:
t = _AttrValue(attr_protos, arg.type_list_attr)
types = t.list.type
output_structure.append(len(types))
output_structure.append(len(t.list.type))
else:
types = [arg.type]
output_structure.append(None)
if arg.is_ref:
types = [dtypes.as_dtype(x)._as_ref for x in types] # pylint: disable=protected-access
output_types.extend(types)
if keywords:
raise TypeError("apply_op() got unexpected keyword arguments: " +
@ -795,7 +783,7 @@ class OpDefLibrary(object):
if arg.is_ref]
with _MaybeColocateWith(must_colocate_inputs):
# Add Op to graph
op = g.create_op(op_type_name, inputs, output_types, name=scope,
op = g.create_op(op_type_name, inputs, name=scope,
input_types=input_types, attrs=attr_protos,
op_def=op_def)
return output_structure, op_def.is_stateful, op

View File

@ -3402,7 +3402,7 @@ class Graph(object):
self,
op_type,
inputs,
dtypes, # pylint: disable=redefined-outer-name
dtypes=None, # pylint: disable=redefined-outer-name
input_types=None,
name=None,
attrs=None,
@ -3420,8 +3420,8 @@ class Graph(object):
op_type: The `Operation` type to create. This corresponds to the
`OpDef.name` field for the proto that defines the operation.
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
dtypes: A list of `DType` objects that will be the types of the tensors
that the operation produces.
dtypes: (Optional) A list of `DType` objects that will be the types of the
tensors that the operation produces.
input_types: (Optional.) A list of `DType`s that will be the types of
the tensors that the operation consumes. By default, uses the base
`DType` of each input in `inputs`. Operations that expect

View File

@ -68,7 +68,7 @@ tf_class {
}
member_method {
name: "create_op"
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
}
member_method {
name: "device"

View File

@ -68,7 +68,7 @@ tf_class {
}
member_method {
name: "create_op"
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
}
member_method {
name: "device"