Add non-deprecated Graph._create_op_internal() method.

Each time an op is added to a function, we were previously calling
`Graph.create_op()`, which has a deprecation wrapper. In some
circumstances, this wrapper can contribute ~15% of the cost of tracing
a function graph. This change adds an internal endpoint that code
inside the TensorFlow library can call, to avoid that overhead.

PiperOrigin-RevId: 253720858
This commit is contained in:
Derek Murray 2019-06-17 21:27:14 -07:00 committed by TensorFlower Gardener
parent f9dd464df8
commit 06a1ca8390
2 changed files with 50 additions and 6 deletions

View File

@ -430,7 +430,7 @@ class FuncGraph(ops.Graph):
compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list, compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
context.context()) context.context())
else: else:
op = ops.get_default_graph().create_op( op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access
op_type, op_type,
uncaptured_inputs, uncaptured_inputs,
dtypes, dtypes,
@ -438,7 +438,7 @@ class FuncGraph(ops.Graph):
name, name,
attrs, attrs,
op_def, op_def,
compute_device=compute_device) compute_device)
value = op.outputs[0] value = op.outputs[0]
captured_value = self.capture(value) captured_value = self.capture(value)
return captured_value.op return captured_value.op
@ -510,9 +510,9 @@ class FuncGraph(ops.Graph):
inp = ctxt.AddValue(inp) inp = ctxt.AddValue(inp)
inp = self.capture(inp) inp = self.capture(inp)
inputs[i] = inp inputs[i] = inp
return super(FuncGraph, self).create_op( return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
op_type, inputs, dtypes, input_types, name, attrs, op_def, op_type, inputs, dtypes, input_types, name, attrs, op_def,
compute_device=compute_device) compute_device)
def capture(self, tensor, name=None): def capture(self, tensor, name=None):
"""Captures `tensor` if it's external to this graph. """Captures `tensor` if it's external to this graph.

View File

@ -3269,11 +3269,55 @@ class Graph(object):
An `Operation` object. An `Operation` object.
""" """
del compute_shapes del compute_shapes
self._check_not_finalized()
for idx, a in enumerate(inputs): for idx, a in enumerate(inputs):
if not isinstance(a, Tensor): if not isinstance(a, Tensor):
raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
attrs, op_def, compute_device)
def _create_op_internal(
self,
op_type,
inputs,
dtypes=None, # pylint: disable=redefined-outer-name
input_types=None,
name=None,
attrs=None,
op_def=None,
compute_device=True):
"""Creates an `Operation` in this graph.
Implements `Graph.create_op()` without the overhead of the deprecation
wrapper.
Args:
op_type: The `Operation` type to create. This corresponds to the
`OpDef.name` field for the proto that defines the operation.
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
dtypes: (Optional) A list of `DType` objects that will be the types of the
tensors that the operation produces.
input_types: (Optional.) A list of `DType`s that will be the types of the
tensors that the operation consumes. By default, uses the base `DType`
of each input in `inputs`. Operations that expect reference-typed inputs
must specify `input_types` explicitly.
name: (Optional.) A string name for the operation. If not specified, a
name is generated based on `op_type`.
attrs: (Optional.) A dictionary where the key is the attribute name (a
string) and the value is the respective `attr` attribute of the
`NodeDef` proto that will represent the operation (an `AttrValue`
proto).
op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
the operation will have.
compute_device: (Optional.) If True, device functions will be executed to
compute the device property of the Operation.
Raises:
ValueError: if colocation conflicts with existing device assignment.
Returns:
An `Operation` object.
"""
self._check_not_finalized()
if name is None: if name is None:
name = op_type name = op_type
# If a names ends with a '/' it is a "name scope" and we use it as-is, # If a names ends with a '/' it is a "name scope" and we use it as-is,