Make Graph().create_op() dtype argument optional.
This makes the signature of create_op closer to the signature of execute, which will allow us to merge the graph and eager execution codepaths. PiperOrigin-RevId: 238253922
This commit is contained in:
parent
959cb35b46
commit
b55f61086a
@ -384,7 +384,7 @@ class FuncGraph(ops.Graph):
|
|||||||
self,
|
self,
|
||||||
op_type,
|
op_type,
|
||||||
inputs,
|
inputs,
|
||||||
dtypes, # pylint: disable=redefined-outer-name
|
dtypes=None, # pylint: disable=redefined-outer-name
|
||||||
input_types=None,
|
input_types=None,
|
||||||
name=None,
|
name=None,
|
||||||
attrs=None,
|
attrs=None,
|
||||||
@ -401,8 +401,8 @@ class FuncGraph(ops.Graph):
|
|||||||
op_type: The `Operation` type to create. This corresponds to the
|
op_type: The `Operation` type to create. This corresponds to the
|
||||||
`OpDef.name` field for the proto that defines the operation.
|
`OpDef.name` field for the proto that defines the operation.
|
||||||
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
||||||
dtypes: A list of `DType` objects that will be the types of the tensors
|
dtypes: (Optional) A list of `DType` objects that will be the types of the
|
||||||
that the operation produces.
|
tensors that the operation produces.
|
||||||
input_types: (Optional.) A list of `DType`s that will be the types of
|
input_types: (Optional.) A list of `DType`s that will be the types of
|
||||||
the tensors that the operation consumes. By default, uses the base
|
the tensors that the operation consumes. By default, uses the base
|
||||||
`DType` of each input in `inputs`. Operations that expect
|
`DType` of each input in `inputs`. Operations that expect
|
||||||
|
|||||||
@ -755,11 +755,12 @@ class _FuncGraph(ops.Graph):
|
|||||||
return var.value()
|
return var.value()
|
||||||
return var
|
return var
|
||||||
|
|
||||||
def create_op(self, op_type, inputs, data_types, **kwargs):
|
def create_op(self, op_type, inputs, data_types=None, **kwargs):
|
||||||
for i, x in enumerate(inputs):
|
for i, x in enumerate(inputs):
|
||||||
if isinstance(x, ops.EagerTensor) or x.graph is not self:
|
if isinstance(x, ops.EagerTensor) or x.graph is not self:
|
||||||
inputs[i] = self.capture(x)
|
inputs[i] = self.capture(x)
|
||||||
return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
|
return super(_FuncGraph, self).create_op(op_type, inputs,
|
||||||
|
dtypes=data_types,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
def capture(self, tensor, name=None):
|
def capture(self, tensor, name=None):
|
||||||
|
|||||||
@ -759,31 +759,19 @@ class OpDefLibrary(object):
|
|||||||
del attrs # attrs is no longer authoritative, use attr_protos instead
|
del attrs # attrs is no longer authoritative, use attr_protos instead
|
||||||
|
|
||||||
# Determine output types (possibly using attrs)
|
# Determine output types (possibly using attrs)
|
||||||
output_types = []
|
|
||||||
output_structure = []
|
output_structure = []
|
||||||
for arg in op_def.output_arg:
|
for arg in op_def.output_arg:
|
||||||
types = []
|
|
||||||
if arg.number_attr:
|
if arg.number_attr:
|
||||||
n = _AttrValue(attr_protos, arg.number_attr).i
|
n = _AttrValue(attr_protos, arg.number_attr).i
|
||||||
if arg.type_attr:
|
|
||||||
types = [_AttrValue(attr_protos, arg.type_attr).type] * n
|
|
||||||
else:
|
|
||||||
types = [arg.type] * n
|
|
||||||
output_structure.append(n)
|
output_structure.append(n)
|
||||||
elif arg.type_attr:
|
elif arg.type_attr:
|
||||||
t = _AttrValue(attr_protos, arg.type_attr)
|
t = _AttrValue(attr_protos, arg.type_attr)
|
||||||
types = [t.type]
|
|
||||||
output_structure.append(None)
|
output_structure.append(None)
|
||||||
elif arg.type_list_attr:
|
elif arg.type_list_attr:
|
||||||
t = _AttrValue(attr_protos, arg.type_list_attr)
|
t = _AttrValue(attr_protos, arg.type_list_attr)
|
||||||
types = t.list.type
|
output_structure.append(len(t.list.type))
|
||||||
output_structure.append(len(types))
|
|
||||||
else:
|
else:
|
||||||
types = [arg.type]
|
|
||||||
output_structure.append(None)
|
output_structure.append(None)
|
||||||
if arg.is_ref:
|
|
||||||
types = [dtypes.as_dtype(x)._as_ref for x in types] # pylint: disable=protected-access
|
|
||||||
output_types.extend(types)
|
|
||||||
|
|
||||||
if keywords:
|
if keywords:
|
||||||
raise TypeError("apply_op() got unexpected keyword arguments: " +
|
raise TypeError("apply_op() got unexpected keyword arguments: " +
|
||||||
@ -795,7 +783,7 @@ class OpDefLibrary(object):
|
|||||||
if arg.is_ref]
|
if arg.is_ref]
|
||||||
with _MaybeColocateWith(must_colocate_inputs):
|
with _MaybeColocateWith(must_colocate_inputs):
|
||||||
# Add Op to graph
|
# Add Op to graph
|
||||||
op = g.create_op(op_type_name, inputs, output_types, name=scope,
|
op = g.create_op(op_type_name, inputs, name=scope,
|
||||||
input_types=input_types, attrs=attr_protos,
|
input_types=input_types, attrs=attr_protos,
|
||||||
op_def=op_def)
|
op_def=op_def)
|
||||||
return output_structure, op_def.is_stateful, op
|
return output_structure, op_def.is_stateful, op
|
||||||
|
|||||||
@ -3402,7 +3402,7 @@ class Graph(object):
|
|||||||
self,
|
self,
|
||||||
op_type,
|
op_type,
|
||||||
inputs,
|
inputs,
|
||||||
dtypes, # pylint: disable=redefined-outer-name
|
dtypes=None, # pylint: disable=redefined-outer-name
|
||||||
input_types=None,
|
input_types=None,
|
||||||
name=None,
|
name=None,
|
||||||
attrs=None,
|
attrs=None,
|
||||||
@ -3420,8 +3420,8 @@ class Graph(object):
|
|||||||
op_type: The `Operation` type to create. This corresponds to the
|
op_type: The `Operation` type to create. This corresponds to the
|
||||||
`OpDef.name` field for the proto that defines the operation.
|
`OpDef.name` field for the proto that defines the operation.
|
||||||
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
||||||
dtypes: A list of `DType` objects that will be the types of the tensors
|
dtypes: (Optional) A list of `DType` objects that will be the types of the
|
||||||
that the operation produces.
|
tensors that the operation produces.
|
||||||
input_types: (Optional.) A list of `DType`s that will be the types of
|
input_types: (Optional.) A list of `DType`s that will be the types of
|
||||||
the tensors that the operation consumes. By default, uses the base
|
the tensors that the operation consumes. By default, uses the base
|
||||||
`DType` of each input in `inputs`. Operations that expect
|
`DType` of each input in `inputs`. Operations that expect
|
||||||
|
|||||||
@ -68,7 +68,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "create_op"
|
name: "create_op"
|
||||||
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "device"
|
name: "device"
|
||||||
|
|||||||
@ -68,7 +68,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "create_op"
|
name: "create_op"
|
||||||
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "device"
|
name: "device"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user