Make Graph().create_op() dtype argument optional.
This makes the signature of create_op closer to the signature of execute, which will allow us to merge the graph and eager execution codepaths. PiperOrigin-RevId: 238253922
This commit is contained in:
parent
959cb35b46
commit
b55f61086a
tensorflow
python/framework
tools/api/golden
@ -384,7 +384,7 @@ class FuncGraph(ops.Graph):
|
||||
self,
|
||||
op_type,
|
||||
inputs,
|
||||
dtypes, # pylint: disable=redefined-outer-name
|
||||
dtypes=None, # pylint: disable=redefined-outer-name
|
||||
input_types=None,
|
||||
name=None,
|
||||
attrs=None,
|
||||
@ -401,8 +401,8 @@ class FuncGraph(ops.Graph):
|
||||
op_type: The `Operation` type to create. This corresponds to the
|
||||
`OpDef.name` field for the proto that defines the operation.
|
||||
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
||||
dtypes: A list of `DType` objects that will be the types of the tensors
|
||||
that the operation produces.
|
||||
dtypes: (Optional) A list of `DType` objects that will be the types of the
|
||||
tensors that the operation produces.
|
||||
input_types: (Optional.) A list of `DType`s that will be the types of
|
||||
the tensors that the operation consumes. By default, uses the base
|
||||
`DType` of each input in `inputs`. Operations that expect
|
||||
|
@ -755,11 +755,12 @@ class _FuncGraph(ops.Graph):
|
||||
return var.value()
|
||||
return var
|
||||
|
||||
def create_op(self, op_type, inputs, data_types, **kwargs):
|
||||
def create_op(self, op_type, inputs, data_types=None, **kwargs):
|
||||
for i, x in enumerate(inputs):
|
||||
if isinstance(x, ops.EagerTensor) or x.graph is not self:
|
||||
inputs[i] = self.capture(x)
|
||||
return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
|
||||
return super(_FuncGraph, self).create_op(op_type, inputs,
|
||||
dtypes=data_types,
|
||||
**kwargs)
|
||||
|
||||
def capture(self, tensor, name=None):
|
||||
|
@ -759,31 +759,19 @@ class OpDefLibrary(object):
|
||||
del attrs # attrs is no longer authoritative, use attr_protos instead
|
||||
|
||||
# Determine output types (possibly using attrs)
|
||||
output_types = []
|
||||
output_structure = []
|
||||
for arg in op_def.output_arg:
|
||||
types = []
|
||||
if arg.number_attr:
|
||||
n = _AttrValue(attr_protos, arg.number_attr).i
|
||||
if arg.type_attr:
|
||||
types = [_AttrValue(attr_protos, arg.type_attr).type] * n
|
||||
else:
|
||||
types = [arg.type] * n
|
||||
output_structure.append(n)
|
||||
elif arg.type_attr:
|
||||
t = _AttrValue(attr_protos, arg.type_attr)
|
||||
types = [t.type]
|
||||
output_structure.append(None)
|
||||
elif arg.type_list_attr:
|
||||
t = _AttrValue(attr_protos, arg.type_list_attr)
|
||||
types = t.list.type
|
||||
output_structure.append(len(types))
|
||||
output_structure.append(len(t.list.type))
|
||||
else:
|
||||
types = [arg.type]
|
||||
output_structure.append(None)
|
||||
if arg.is_ref:
|
||||
types = [dtypes.as_dtype(x)._as_ref for x in types] # pylint: disable=protected-access
|
||||
output_types.extend(types)
|
||||
|
||||
if keywords:
|
||||
raise TypeError("apply_op() got unexpected keyword arguments: " +
|
||||
@ -795,7 +783,7 @@ class OpDefLibrary(object):
|
||||
if arg.is_ref]
|
||||
with _MaybeColocateWith(must_colocate_inputs):
|
||||
# Add Op to graph
|
||||
op = g.create_op(op_type_name, inputs, output_types, name=scope,
|
||||
op = g.create_op(op_type_name, inputs, name=scope,
|
||||
input_types=input_types, attrs=attr_protos,
|
||||
op_def=op_def)
|
||||
return output_structure, op_def.is_stateful, op
|
||||
|
@ -3402,7 +3402,7 @@ class Graph(object):
|
||||
self,
|
||||
op_type,
|
||||
inputs,
|
||||
dtypes, # pylint: disable=redefined-outer-name
|
||||
dtypes=None, # pylint: disable=redefined-outer-name
|
||||
input_types=None,
|
||||
name=None,
|
||||
attrs=None,
|
||||
@ -3420,8 +3420,8 @@ class Graph(object):
|
||||
op_type: The `Operation` type to create. This corresponds to the
|
||||
`OpDef.name` field for the proto that defines the operation.
|
||||
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
||||
dtypes: A list of `DType` objects that will be the types of the tensors
|
||||
that the operation produces.
|
||||
dtypes: (Optional) A list of `DType` objects that will be the types of the
|
||||
tensors that the operation produces.
|
||||
input_types: (Optional.) A list of `DType`s that will be the types of
|
||||
the tensors that the operation consumes. By default, uses the base
|
||||
`DType` of each input in `inputs`. Operations that expect
|
||||
|
@ -68,7 +68,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "create_op"
|
||||
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
||||
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "device"
|
||||
|
@ -68,7 +68,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "create_op"
|
||||
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
||||
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "device"
|
||||
|
Loading…
Reference in New Issue
Block a user