Unify signatures of Graph().create_op on subclasses.
PiperOrigin-RevId: 238470192
This commit is contained in:
parent
8e248b89c7
commit
32edfdd8e4
@ -755,13 +755,12 @@ class _FuncGraph(ops.Graph):
|
|||||||
return var.value()
|
return var.value()
|
||||||
return var
|
return var
|
||||||
|
|
||||||
def create_op(self, op_type, inputs, data_types=None, **kwargs):
|
def create_op(self, op_type, inputs, dtypes=None, **kwargs): # pylint: disable=redefined-outer-name
|
||||||
for i, x in enumerate(inputs):
|
for i, x in enumerate(inputs):
|
||||||
if isinstance(x, ops.EagerTensor) or x.graph is not self:
|
if isinstance(x, ops.EagerTensor) or x.graph is not self:
|
||||||
inputs[i] = self.capture(x)
|
inputs[i] = self.capture(x)
|
||||||
return super(_FuncGraph, self).create_op(op_type, inputs,
|
return super(_FuncGraph, self).create_op(op_type, inputs,
|
||||||
dtypes=data_types,
|
dtypes=dtypes, **kwargs)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def capture(self, tensor, name=None):
|
def capture(self, tensor, name=None):
|
||||||
"""Adds the given tensor to this graph and returns the captured tensor."""
|
"""Adds the given tensor to this graph and returns the captured tensor."""
|
||||||
|
@ -783,7 +783,7 @@ class OpDefLibrary(object):
|
|||||||
if arg.is_ref]
|
if arg.is_ref]
|
||||||
with _MaybeColocateWith(must_colocate_inputs):
|
with _MaybeColocateWith(must_colocate_inputs):
|
||||||
# Add Op to graph
|
# Add Op to graph
|
||||||
op = g.create_op(op_type_name, inputs, name=scope,
|
op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
|
||||||
input_types=input_types, attrs=attr_protos,
|
input_types=input_types, attrs=attr_protos,
|
||||||
op_def=op_def)
|
op_def=op_def)
|
||||||
return output_structure, op_def.is_stateful, op
|
return output_structure, op_def.is_stateful, op
|
||||||
|
Loading…
Reference in New Issue
Block a user