Handle shape inference for loop invariant Merge node.

PiperOrigin-RevId: 252505059
Tong Shen authored 2019-06-10 16:03:58 -07:00; committed by TensorFlower Gardener
parent 19292fe67f
commit ed0134474a
3 changed files with 100 additions and 1 deletion
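Background (sketched from the test graph added below): in the graph form of a while loop, each loop variable flows through a cycle of control-flow ops, and BackEdgeHelper removes the NextIteration->Merge back edge so that shape inference can visit an acyclic graph:

  Enter ---> Merge ---> Switch ---> Identity ---> NextIteration
               ^           |                            |
               |           +---> Exit                   |
               +---------- back edge (removed) ---------+

For a loop-invariant variable the loop body is a pass-through, so walking up from the NextIteration node through Identity nodes lands on the Switch whose data input is the Merge itself; this commit detects that pattern and propagates the Enter node's resource shape to the Merge output.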

tensorflow/compiler/jit/BUILD

@@ -452,6 +452,7 @@ tf_cc_test(
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:ops",
+        "//tensorflow/cc:resource_variable_ops",
         "//tensorflow/core:framework",
         "//tensorflow/core:ops",
         "//tensorflow/core:test",

tensorflow/compiler/jit/shape_inference.cc

@@ -41,7 +41,15 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
 Status PropagateShapes(const Graph& graph,
                        const std::map<int, InferredShape>& arg_shapes,
+                       const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                        ShapeRefiner* shape_refiner) {
+  std::map<const Node*, const Node*> merge_to_next_iteration;
+  for (const auto& e : back_edges) {
+    if (e.src->IsNextIteration() && e.dst->IsMerge()) {
+      merge_to_next_iteration[e.dst] = e.src;
+    }
+  }
+
   // Visits the nodes in topological order (reverse post-order), inferring
   // shapes.
   // TODO(phawkins): handle cyclic graphs.
@@ -90,6 +98,45 @@ Status PropagateShapes(const Graph& graph,
         TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
       }
     }
+
+    // Merge node causes a loop so we remove NextIteration->Merge edge before
+    // performing shape inference. But removing those edges also prevents us
+    // from inferring output shape for Merge node (we need shapes for all its
+    // inputs).
+    // For loop invariant resource input's Merge node, we set output resource
+    // shape as Enter node's resource shape.
+    // TODO(b/129367850): clean this up.
+    if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
+      // Check if this is a loop invariant input's Merge node. We do it by
+      // checking if corresponding NextIteration node comes from Switch node
+      // directly.
+      auto iter = merge_to_next_iteration.find(n);
+      if (iter != merge_to_next_iteration.end()) {
+        const Node *next_iter = iter->second, *node = next_iter;
+        do {
+          TF_RETURN_IF_ERROR(node->input_node(0, &node));
+        } while (node->IsIdentity());
+        const Node* switch_input;
+        bool is_loop_invariant = node->IsSwitch() &&
+                                 node->input_node(0, &switch_input).ok() &&
+                                 switch_input == n;
+        if (is_loop_invariant) {
+          shape_inference::InferenceContext* context =
+              shape_refiner->GetContext(n);
+          for (int i = 0; i < n->num_inputs(); i++) {
+            const Node* input_node;
+            if (n->input_node(i, &input_node).ok()) {
+              auto shapes_and_types = context->input_handle_shapes_and_types(i);
+              if (shapes_and_types) {
+                context->set_output_handle_shapes_and_types(0,
+                                                            *shapes_and_types);
+              }
+              break;
+            }
+          }
+        }
+      }
+    }
   }
   return Status::OK();
 }
@@ -149,7 +196,8 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
   // the shape inference is complete.
   BackEdgeHelper back_edge;
   TF_RETURN_IF_ERROR(back_edge.Remove(graph));
-  TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes, &shape_refiner));
+  TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
+                                     back_edge.RemovedEdges(), &shape_refiner));
   TF_RETURN_IF_ERROR(back_edge.Replace());
   // Currently information does not flow "backward" from consumers to producers

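Read on its own, the loop-invariance test in the hunk above reduces to a small predicate. A minimal sketch for clarity (the helper name and free-standing form are illustrative, not part of the commit):

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Illustrative restatement of the check above: a resource-typed Merge node is
// loop invariant when its NextIteration input traces back, through any chain
// of Identity nodes, to a Switch whose data input (input 0) is the Merge
// itself, i.e. the loop body passes the resource through unchanged.
Status IsLoopInvariantResourceMerge(const Node* merge,
                                    const Node* next_iteration, bool* result) {
  const Node* node = next_iteration;
  do {
    TF_RETURN_IF_ERROR(node->input_node(0, &node));
  } while (node->IsIdentity());
  const Node* switch_input = nullptr;
  *result = node->IsSwitch() && node->input_node(0, &switch_input).ok() &&
            switch_input == merge;
  return Status::OK();
}

}  // namespace tensorflow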
tensorflow/compiler/jit/shape_inference_test.cc

@@ -21,7 +21,9 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/jit/test_util.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -120,5 +122,53 @@ TEST(ShapeInferenceTest, WhileLoop) {
   TF_EXPECT_OK(ShapeAnnotationsMatch(graph, shape_info, expected));
 }
 
+TEST(ShapeInferenceTest, WhileLoopWithResource) {
+  // Graph:
+  // x = resource_variable_ops.var_handle_op(dtype=dtypes.float32, shape=[2, 3])
+  // y = control_flow_ops.while_loop(lambda _: true, lambda x: x, [x])
+  Graph graph(OpRegistry::Global());
+  {
+    Scope scope = Scope::NewRootScope().ExitOnError();
+
+    auto x =
+        ops::VarHandleOp(scope.WithOpName("x"), DT_FLOAT, TensorShape({2, 3}));
+    auto enter =
+        ops::internal::Enter(scope.WithOpName("while/Enter"), x, "aloop");
+    auto dummy = ops::Placeholder(scope.WithOpName("dummy"), DT_RESOURCE);
+    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
+                            std::initializer_list<Input>{enter, dummy});
+    auto false_value = ops::Const<bool>(scope.WithOpName("false"), false);
+    auto loop_cond =
+        ops::LoopCond(scope.WithOpName("while/LoopCond"), false_value);
+    auto switch_node =
+        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
+    auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
+                                    switch_node.output_false);
+    auto identity = ops::Identity(scope.WithOpName("while/Identity"),
+                                  switch_node.output_true);
+    auto next_iteration =
+        ops::NextIteration(scope.WithOpName("while/NextIteration"), identity);
+    auto sink = ops::Identity(scope.WithOpName("sink"), exit);
+
+    // Remove the dummy node and add the loop backedge.
+    scope.graph()->RemoveNode(dummy.node());
+    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
+
+    TF_EXPECT_OK(scope.ToGraph(&graph));
+  }
+
+  // Check that we can infer shape for "sink" node (Merge node output).
+  GraphShapeInfo shape_info;
+  TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr,
+                           &shape_info));
+  auto iter = shape_info.find("sink");
+  EXPECT_NE(iter, shape_info.end());
+  EXPECT_EQ(iter->second.size(), 1);
+  EXPECT_EQ(iter->second.at(0).handle_type, DT_FLOAT);
+  TensorShape resource_shape;
+  EXPECT_TRUE(iter->second.at(0).handle_shape.AsTensorShape(&resource_shape));
+  EXPECT_EQ(resource_shape, TensorShape({2, 3}));
+}
+
 }  // namespace
 }  // namespace tensorflow
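With the //tensorflow/cc:resource_variable_ops dependency added in the BUILD hunk above, the new test runs under the existing tf_cc_test target; assuming that target is named after the file, `bazel test //tensorflow/compiler/jit:shape_inference_test` exercises it.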