Introduce Operation._add_while_inputs to allow adding inputs to a While op.

This is in preparation for changing while_v2 to rewrite the forward
pass to output intermediates needed by the gradient, instead of
outputting all intermediates. Since While ops always have the same
input and output types, we need to be able to add inputs in addition
to adding outputs.

PiperOrigin-RevId: 223812986
This commit is contained in:
Skye Wanderman-Milne 2018-12-03 09:56:03 -08:00 committed by TensorFlower Gardener
parent 9c8b5bf6e7
commit 62db4a3ccf
7 changed files with 93 additions and 1 deletions

View File

@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}
// Appends `new_src` as a new data input of the While op `dst`.
// Thin C-API wrapper over Graph::AddWhileInputHack; serializes graph
// mutations under the graph lock and reports failure via `status`.
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
                       TF_Status* status) {
  mutex_lock lock(graph->mu);
  status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
                                                  new_src.index, &dst->node);
  if (!status->status.ok()) return;
  // Only the destination node changes as far as running this graph in a
  // session is concerned, so the source node is not recorded as modified.
  RecordMutation(graph, *dst, "adding input tensor");
}
} // namespace tensorflow

View File

@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
// Updates 'dst' to consume 'new_src'.
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status);
@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// because I couldn't get SWIG to work otherwise.
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
size_t proto_len, TF_Status* status);
// This method is used to add a new input edge to 'dst', which must be a While
// op. The While op's "T" attribute must have already been updated to include
// the new edge. This is used to construct tf.while_loop gradients.
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_

View File

@ -548,6 +548,22 @@ Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
return Status::OK();
}
// Adds `new_src:new_src_index` as an extra data input edge of `dst`, which
// must be a "While" op, and mirrors the edge into dst's NodeDef so the graph
// serializes consistently. The caller must already have extended dst's "T"
// attr to account for the new input (see AddWhileInputHack in python_api.h).
// Returns Internal if dst is not a While op, or the status of the
// input/output-tensor validity checks.
Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
  if (dst->type_string() != "While") {
    // Error message previously named a nonexistent "AddWhileEdgeHack".
    return errors::Internal(
        "dst argument to AddWhileInputHack should be a While op, got: ",
        dst->DebugString());
  }
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  // Append after all existing inputs. NOTE(review): in_edges() also counts
  // control edges — assumes `dst` has no control inputs yet; confirm. The
  // IsValidInputTensor check below rejects (OutOfRange) an index beyond the
  // arity declared by dst's attrs.
  int dst_index = dst->in_edges().size();
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  AddEdge(new_src, new_src_index, dst, dst_index);
  // props_ may be shared with other nodes; copy-on-write before editing the
  // NodeDef.
  dst->MaybeCopyOnWrite();
  dst->props_->node_def.add_input(
      strings::StrCat(new_src->name(), ":", new_src_index));
  return Status::OK();
}
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
// Need a new-enough consumer to support the functions we add to the graph.
if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {

View File

@ -493,11 +493,17 @@ class Graph {
// the corresponding NodeDef to reflect the change.
// REQUIRES: The control edge must exist.
void RemoveControlEdge(const Edge* e);
// Updates the input to a node. The existing edge to `dst` is removed and an
// edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
// is also updated.
Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
// Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
// "While" op during gradient construction; see AddWhileInputHack in
// python_api.h for more details.
Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
// Adds the function and gradient definitions in `fdef_lib` to this graph's op
// registry. Ignores duplicate functions, and returns a bad status if an
// imported function differs from an existing function or op with the same

View File

@ -2183,7 +2183,6 @@ py_library(
":control_flow_util_v2",
":dtypes",
":framework_ops",
":framework_test_lib",
":function_def_to_graph",
":functional_ops_gen",
":gradients_impl",

View File

@ -2086,6 +2086,31 @@ class Operation(object):
tensor._as_tf_output(), # pylint: disable=protected-access
self._tf_input(index))
def _add_while_inputs(self, tensors):
  """Appends new input edges to this (While) op.

  See AddWhileInputHack in python_api.h.

  NOTE: This is for TF internal use only. Please don't use it.

  Args:
    tensors: list of Tensors

  Raises:
    TypeError: if tensor is not a Tensor,
      or if input tensor type is not convertible to dtype.
    ValueError: if the Tensor is from a different graph.
  """
  for new_input in tensors:
    if not isinstance(new_input, Tensor):
      raise TypeError("tensor must be a Tensor: %s" % new_input)
    _assert_same_graph(self, new_input)

    # Invalidate the cached inputs before mutating the underlying op.
    self._inputs_val = None
    c_api.AddWhileInputHack(
        self._graph._c_graph,  # pylint: disable=protected-access
        new_input._as_tf_output(),  # pylint: disable=protected-access
        self._c_op)
def _add_control_inputs(self, ops):
"""Add a list of new control inputs to this operation.

View File

@ -604,6 +604,31 @@ class OperationTest(test_util.TensorFlowTestCase):
):
x.op._update_input(1, x) # pylint: disable=protected-access
@test_util.enable_control_flow_v2
def testAddWhileInput(self):

  @eager_function.defun
  def test():
    loop_out = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
                                           [1])
    loop_op = loop_out.op.inputs[0].op
    self.assertEqual(loop_op.type, "While")
    num_inputs_before = len(loop_op.inputs)

    extra_float = constant_op.constant(1.0)
    extra_bool = constant_op.constant(True)
    # "T" must be widened first; _add_while_inputs only adds the edges.
    loop_op._set_type_list_attr("T",
                                [t.dtype for t in loop_op.inputs] +
                                [extra_float.dtype, extra_bool.dtype])
    loop_op._add_while_inputs([extra_float, extra_bool])
    # Can't add an edge beyond what's specified by "T"
    with self.assertRaises(errors.OutOfRangeError):
      loop_op._add_while_inputs([extra_bool])
    self.assertEqual(len(loop_op.inputs), num_inputs_before + 2)  # pylint: disable=g-deprecated-assert

  test()
@test_util.run_deprecated_v1
def testOpDef(self):
x = constant_op.constant(0)