Unify signatures of Graph().create_op on subclasses.
PiperOrigin-RevId: 238470192
This commit is contained in:
parent
8e248b89c7
commit
32edfdd8e4
@ -755,13 +755,12 @@ class _FuncGraph(ops.Graph):
|
||||
return var.value()
|
||||
return var
|
||||
|
||||
def create_op(self, op_type, inputs, data_types=None, **kwargs):
|
||||
def create_op(self, op_type, inputs, dtypes=None, **kwargs): # pylint: disable=redefined-outer-name
|
||||
for i, x in enumerate(inputs):
|
||||
if isinstance(x, ops.EagerTensor) or x.graph is not self:
|
||||
inputs[i] = self.capture(x)
|
||||
return super(_FuncGraph, self).create_op(op_type, inputs,
|
||||
dtypes=data_types,
|
||||
**kwargs)
|
||||
dtypes=dtypes, **kwargs)
|
||||
|
||||
def capture(self, tensor, name=None):
|
||||
"""Adds the given tensor to this graph and returns the captured tensor."""
|
||||
|
@ -783,7 +783,7 @@ class OpDefLibrary(object):
|
||||
if arg.is_ref]
|
||||
with _MaybeColocateWith(must_colocate_inputs):
|
||||
# Add Op to graph
|
||||
op = g.create_op(op_type_name, inputs, name=scope,
|
||||
op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
|
||||
input_types=input_types, attrs=attr_protos,
|
||||
op_def=op_def)
|
||||
return output_structure, op_def.is_stateful, op
|
||||
|
Loading…
Reference in New Issue
Block a user