Unify signatures of Graph().create_op on subclasses.

PiperOrigin-RevId: 238470192
This commit is contained in:
Alexandre Passos 2019-03-14 10:30:14 -07:00 committed by TensorFlower Gardener
parent 8e248b89c7
commit 32edfdd8e4
2 changed files with 3 additions and 4 deletions

View File

@ -755,13 +755,12 @@ class _FuncGraph(ops.Graph):
return var.value() return var.value()
return var return var
def create_op(self, op_type, inputs, data_types=None, **kwargs): def create_op(self, op_type, inputs, dtypes=None, **kwargs): # pylint: disable=redefined-outer-name
for i, x in enumerate(inputs): for i, x in enumerate(inputs):
if isinstance(x, ops.EagerTensor) or x.graph is not self: if isinstance(x, ops.EagerTensor) or x.graph is not self:
inputs[i] = self.capture(x) inputs[i] = self.capture(x)
return super(_FuncGraph, self).create_op(op_type, inputs, return super(_FuncGraph, self).create_op(op_type, inputs,
dtypes=data_types, dtypes=dtypes, **kwargs)
**kwargs)
def capture(self, tensor, name=None): def capture(self, tensor, name=None):
"""Adds the given tensor to this graph and returns the captured tensor.""" """Adds the given tensor to this graph and returns the captured tensor."""

View File

@ -783,7 +783,7 @@ class OpDefLibrary(object):
if arg.is_ref] if arg.is_ref]
with _MaybeColocateWith(must_colocate_inputs): with _MaybeColocateWith(must_colocate_inputs):
# Add Op to graph # Add Op to graph
op = g.create_op(op_type_name, inputs, name=scope, op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
input_types=input_types, attrs=attr_protos, input_types=input_types, attrs=attr_protos,
op_def=op_def) op_def=op_def)
return output_structure, op_def.is_stateful, op return output_structure, op_def.is_stateful, op