Add non-deprecated Graph._create_op_internal()
method.
Each time an op is added to a function, we were previously calling `Graph.create_op()`, which has a deprecation wrapper. In some circumstances, this wrapper can contribute ~15% of the cost of tracing a function graph. This change adds an internal endpoint that code inside the TensorFlow library can call, to avoid that overhead. PiperOrigin-RevId: 253720858
This commit is contained in:
parent
f9dd464df8
commit
06a1ca8390
@ -430,7 +430,7 @@ class FuncGraph(ops.Graph):
|
||||
compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
|
||||
context.context())
|
||||
else:
|
||||
op = ops.get_default_graph().create_op(
|
||||
op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access
|
||||
op_type,
|
||||
uncaptured_inputs,
|
||||
dtypes,
|
||||
@ -438,7 +438,7 @@ class FuncGraph(ops.Graph):
|
||||
name,
|
||||
attrs,
|
||||
op_def,
|
||||
compute_device=compute_device)
|
||||
compute_device)
|
||||
value = op.outputs[0]
|
||||
captured_value = self.capture(value)
|
||||
return captured_value.op
|
||||
@ -510,9 +510,9 @@ class FuncGraph(ops.Graph):
|
||||
inp = ctxt.AddValue(inp)
|
||||
inp = self.capture(inp)
|
||||
inputs[i] = inp
|
||||
return super(FuncGraph, self).create_op(
|
||||
return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
|
||||
op_type, inputs, dtypes, input_types, name, attrs, op_def,
|
||||
compute_device=compute_device)
|
||||
compute_device)
|
||||
|
||||
def capture(self, tensor, name=None):
|
||||
"""Captures `tensor` if it's external to this graph.
|
||||
|
@ -3269,11 +3269,55 @@ class Graph(object):
|
||||
An `Operation` object.
|
||||
"""
|
||||
del compute_shapes
|
||||
|
||||
self._check_not_finalized()
|
||||
for idx, a in enumerate(inputs):
|
||||
if not isinstance(a, Tensor):
|
||||
raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
|
||||
return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
|
||||
attrs, op_def, compute_device)
|
||||
|
||||
def _create_op_internal(
|
||||
self,
|
||||
op_type,
|
||||
inputs,
|
||||
dtypes=None, # pylint: disable=redefined-outer-name
|
||||
input_types=None,
|
||||
name=None,
|
||||
attrs=None,
|
||||
op_def=None,
|
||||
compute_device=True):
|
||||
"""Creates an `Operation` in this graph.
|
||||
|
||||
Implements `Graph.create_op()` without the overhead of the deprecation
|
||||
wrapper.
|
||||
|
||||
Args:
|
||||
op_type: The `Operation` type to create. This corresponds to the
|
||||
`OpDef.name` field for the proto that defines the operation.
|
||||
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
||||
dtypes: (Optional) A list of `DType` objects that will be the types of the
|
||||
tensors that the operation produces.
|
||||
input_types: (Optional.) A list of `DType`s that will be the types of the
|
||||
tensors that the operation consumes. By default, uses the base `DType`
|
||||
of each input in `inputs`. Operations that expect reference-typed inputs
|
||||
must specify `input_types` explicitly.
|
||||
name: (Optional.) A string name for the operation. If not specified, a
|
||||
name is generated based on `op_type`.
|
||||
attrs: (Optional.) A dictionary where the key is the attribute name (a
|
||||
string) and the value is the respective `attr` attribute of the
|
||||
`NodeDef` proto that will represent the operation (an `AttrValue`
|
||||
proto).
|
||||
op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
|
||||
the operation will have.
|
||||
compute_device: (Optional.) If True, device functions will be executed to
|
||||
compute the device property of the Operation.
|
||||
|
||||
Raises:
|
||||
ValueError: if colocation conflicts with existing device assignment.
|
||||
|
||||
Returns:
|
||||
An `Operation` object.
|
||||
"""
|
||||
self._check_not_finalized()
|
||||
if name is None:
|
||||
name = op_type
|
||||
# If a names ends with a '/' it is a "name scope" and we use it as-is,
|
||||
|
Loading…
Reference in New Issue
Block a user