Introduce Operation._add_while_inputs to allow adding inputs to a While op.
This is in preparation for changing while_v2 to rewrite the forward pass to output intermediates needed by the gradient, instead of outputting all intermediates. Since While ops always have the same inputs and output types, we need to be able to add inputs in addition to adding outputs. PiperOrigin-RevId: 223812986
This commit is contained in:
parent
9c8b5bf6e7
commit
62db4a3ccf
@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
|
||||
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
|
||||
}
|
||||
|
||||
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
|
||||
TF_Status* status) {
|
||||
mutex_lock l(graph->mu);
|
||||
status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
|
||||
new_src.index, &dst->node);
|
||||
if (status->status.ok()) {
|
||||
// This modification only updates the destination node for
|
||||
// the purposes of running this graph in a session. Thus, we don't
|
||||
// record the source node as being modified.
|
||||
RecordMutation(graph, *dst, "adding input tensor");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||
|
||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
|
||||
|
||||
// Updates 'dst' to consume 'new_src'.
|
||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||
TF_Status* status);
|
||||
|
||||
@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
|
||||
// because I couldn't get SWIG to work otherwise.
|
||||
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
|
||||
size_t proto_len, TF_Status* status);
|
||||
|
||||
// This method is used to add a new input edge to 'dst', which must be a While
|
||||
// op. The While op's "T" attribute must have already been updated to include
|
||||
// the new edge. This is used to construct tf.while_loop gradients.
|
||||
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
|
||||
TF_Status* status);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_PYTHON_API_H_
|
||||
|
@ -548,6 +548,22 @@ Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
|
||||
if (dst->type_string() != "While") {
|
||||
return errors::Internal(
|
||||
"dst argument to AddWhileEdgeHack should be a While op, got: ",
|
||||
dst->DebugString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
|
||||
int dst_index = dst->in_edges().size();
|
||||
TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
|
||||
AddEdge(new_src, new_src_index, dst, dst_index);
|
||||
dst->MaybeCopyOnWrite();
|
||||
dst->props_->node_def.add_input(
|
||||
strings::StrCat(new_src->name(), ":", new_src_index));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
|
||||
// Need a new-enough consumer to support the functions we add to the graph.
|
||||
if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
|
||||
|
@ -493,11 +493,17 @@ class Graph {
|
||||
// the corresponding NodeDef to reflect the change.
|
||||
// REQUIRES: The control edge must exist.
|
||||
void RemoveControlEdge(const Edge* e);
|
||||
|
||||
// Updates the input to a node. The existing edge to `dst` is removed and an
|
||||
// edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
|
||||
// is also updated.
|
||||
Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
|
||||
|
||||
// Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
|
||||
// "While" op during gradient construction, see AddInputWhileHack in
|
||||
// python_api.h for more details.
|
||||
Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
|
||||
|
||||
// Adds the function and gradient definitions in `fdef_lib` to this graph's op
|
||||
// registry. Ignores duplicate functions, and returns a bad status if an
|
||||
// imported function differs from an existing function or op with the same
|
||||
|
@ -2183,7 +2183,6 @@ py_library(
|
||||
":control_flow_util_v2",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":framework_test_lib",
|
||||
":function_def_to_graph",
|
||||
":functional_ops_gen",
|
||||
":gradients_impl",
|
||||
|
@ -2086,6 +2086,31 @@ class Operation(object):
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
self._tf_input(index))
|
||||
|
||||
def _add_while_inputs(self, tensors):
|
||||
"""See AddWhileInputHack in python_api.h.
|
||||
|
||||
NOTE: This is for TF internal use only. Please don't use it.
|
||||
|
||||
Args:
|
||||
tensors: list of Tensors
|
||||
|
||||
Raises:
|
||||
TypeError: if tensor is not a Tensor,
|
||||
or if input tensor type is not convertible to dtype.
|
||||
ValueError: if the Tensor is from a different graph.
|
||||
"""
|
||||
for tensor in tensors:
|
||||
if not isinstance(tensor, Tensor):
|
||||
raise TypeError("tensor must be a Tensor: %s" % tensor)
|
||||
_assert_same_graph(self, tensor)
|
||||
|
||||
# Reset cached inputs.
|
||||
self._inputs_val = None
|
||||
c_api.AddWhileInputHack(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
self._c_op)
|
||||
|
||||
def _add_control_inputs(self, ops):
|
||||
"""Add a list of new control inputs to this operation.
|
||||
|
||||
|
@ -604,6 +604,31 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
):
|
||||
x.op._update_input(1, x) # pylint: disable=protected-access
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def testAddWhileInput(self):
|
||||
@eager_function.defun
|
||||
def test():
|
||||
output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
|
||||
[1])
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "While")
|
||||
orig_num_inputs = len(while_op.inputs)
|
||||
|
||||
new_input1 = constant_op.constant(1.0)
|
||||
new_input2 = constant_op.constant(True)
|
||||
|
||||
while_op._set_type_list_attr("T",
|
||||
[t.dtype for t in while_op.inputs] +
|
||||
[new_input1.dtype, new_input2.dtype])
|
||||
|
||||
while_op._add_while_inputs([new_input1, new_input2])
|
||||
# Can't add an edge beyond what's specified by "T"
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while_op._add_while_inputs([new_input2])
|
||||
self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert
|
||||
|
||||
test()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testOpDef(self):
|
||||
x = constant_op.constant(0)
|
||||
|
Loading…
Reference in New Issue
Block a user