Add non-deprecated Graph._create_op_internal() method.
Each time an op is added to a function, we previously called `Graph.create_op()`, which has a deprecation wrapper. In some circumstances, this wrapper can contribute ~15% of the cost of tracing a function graph. This change adds an internal endpoint that code inside the TensorFlow library can call, to avoid that overhead.

PiperOrigin-RevId: 253720858
This commit is contained in:
parent
f9dd464df8
commit
06a1ca8390
@ -430,7 +430,7 @@ class FuncGraph(ops.Graph):
|
|||||||
compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
|
compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
|
||||||
context.context())
|
context.context())
|
||||||
else:
|
else:
|
||||||
op = ops.get_default_graph().create_op(
|
op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access
|
||||||
op_type,
|
op_type,
|
||||||
uncaptured_inputs,
|
uncaptured_inputs,
|
||||||
dtypes,
|
dtypes,
|
||||||
@ -438,7 +438,7 @@ class FuncGraph(ops.Graph):
|
|||||||
name,
|
name,
|
||||||
attrs,
|
attrs,
|
||||||
op_def,
|
op_def,
|
||||||
compute_device=compute_device)
|
compute_device)
|
||||||
value = op.outputs[0]
|
value = op.outputs[0]
|
||||||
captured_value = self.capture(value)
|
captured_value = self.capture(value)
|
||||||
return captured_value.op
|
return captured_value.op
|
||||||
@ -510,9 +510,9 @@ class FuncGraph(ops.Graph):
|
|||||||
inp = ctxt.AddValue(inp)
|
inp = ctxt.AddValue(inp)
|
||||||
inp = self.capture(inp)
|
inp = self.capture(inp)
|
||||||
inputs[i] = inp
|
inputs[i] = inp
|
||||||
return super(FuncGraph, self).create_op(
|
return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
|
||||||
op_type, inputs, dtypes, input_types, name, attrs, op_def,
|
op_type, inputs, dtypes, input_types, name, attrs, op_def,
|
||||||
compute_device=compute_device)
|
compute_device)
|
||||||
|
|
||||||
def capture(self, tensor, name=None):
|
def capture(self, tensor, name=None):
|
||||||
"""Captures `tensor` if it's external to this graph.
|
"""Captures `tensor` if it's external to this graph.
|
||||||
|
@ -3269,11 +3269,55 @@ class Graph(object):
|
|||||||
An `Operation` object.
|
An `Operation` object.
|
||||||
"""
|
"""
|
||||||
del compute_shapes
|
del compute_shapes
|
||||||
|
|
||||||
self._check_not_finalized()
|
|
||||||
for idx, a in enumerate(inputs):
|
for idx, a in enumerate(inputs):
|
||||||
if not isinstance(a, Tensor):
|
if not isinstance(a, Tensor):
|
||||||
raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
|
raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
|
||||||
|
return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
|
||||||
|
attrs, op_def, compute_device)
|
||||||
|
|
||||||
|
def _create_op_internal(
|
||||||
|
self,
|
||||||
|
op_type,
|
||||||
|
inputs,
|
||||||
|
dtypes=None, # pylint: disable=redefined-outer-name
|
||||||
|
input_types=None,
|
||||||
|
name=None,
|
||||||
|
attrs=None,
|
||||||
|
op_def=None,
|
||||||
|
compute_device=True):
|
||||||
|
"""Creates an `Operation` in this graph.
|
||||||
|
|
||||||
|
Implements `Graph.create_op()` without the overhead of the deprecation
|
||||||
|
wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op_type: The `Operation` type to create. This corresponds to the
|
||||||
|
`OpDef.name` field for the proto that defines the operation.
|
||||||
|
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
||||||
|
dtypes: (Optional) A list of `DType` objects that will be the types of the
|
||||||
|
tensors that the operation produces.
|
||||||
|
input_types: (Optional.) A list of `DType`s that will be the types of the
|
||||||
|
tensors that the operation consumes. By default, uses the base `DType`
|
||||||
|
of each input in `inputs`. Operations that expect reference-typed inputs
|
||||||
|
must specify `input_types` explicitly.
|
||||||
|
name: (Optional.) A string name for the operation. If not specified, a
|
||||||
|
name is generated based on `op_type`.
|
||||||
|
attrs: (Optional.) A dictionary where the key is the attribute name (a
|
||||||
|
string) and the value is the respective `attr` attribute of the
|
||||||
|
`NodeDef` proto that will represent the operation (an `AttrValue`
|
||||||
|
proto).
|
||||||
|
op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
|
||||||
|
the operation will have.
|
||||||
|
compute_device: (Optional.) If True, device functions will be executed to
|
||||||
|
compute the device property of the Operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if colocation conflicts with existing device assignment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An `Operation` object.
|
||||||
|
"""
|
||||||
|
self._check_not_finalized()
|
||||||
if name is None:
|
if name is None:
|
||||||
name = op_type
|
name = op_type
|
||||||
# If a names ends with a '/' it is a "name scope" and we use it as-is,
|
# If a names ends with a '/' it is a "name scope" and we use it as-is,
|
||||||
|
Loading…
Reference in New Issue
Block a user