Add optional input_tensors argument to Operation._control_flow_post_processing().

Currently, this method causes the `inputs` property of an `Operation` to be eagerly evaluated via the C API, which can be expensive for large ops (such as tf.while_loop() ops with a large number of loop variables). Since in most cases we already possess a list of tensors from the `Operation.__init__()` `inputs` argument, we can avoid evaluating `inputs`.

PiperOrigin-RevId: 268051042
This commit is contained in:
Derek Murray 2019-09-09 12:25:40 -07:00 committed by TensorFlower Gardener
parent f4a7bc7f93
commit 24888a277e

View File

@ -1804,15 +1804,22 @@ class Operation(object):
self._graph._add_op(self, self._id_value, name) # pylint: disable=protected-access
if not c_op:
self._control_flow_post_processing()
self._control_flow_post_processing(input_tensors=inputs)
def _control_flow_post_processing(self):
def _control_flow_post_processing(self, input_tensors=None):
"""Add this op to its control flow context.
This may add new ops and change this op's inputs. self.inputs must be
available before calling this method.
Args:
input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
of this op, which should be equivalent to `self.inputs`. Pass this
argument to avoid evaluating `self.inputs` unnecessarily.
"""
for input_tensor in self.inputs:
if input_tensors is None:
input_tensors = self.inputs
for input_tensor in input_tensors:
control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
if self._control_flow_context is not None:
self._control_flow_context.AddOp(self)