Handle shape inference for loop invariant Merge node.
PiperOrigin-RevId: 252505059
This commit is contained in:
parent
19292fe67f
commit
ed0134474a
@ -452,6 +452,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:resource_variable_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -41,7 +41,15 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
|
||||
|
||||
Status PropagateShapes(const Graph& graph,
|
||||
const std::map<int, InferredShape>& arg_shapes,
|
||||
const std::vector<BackEdgeHelper::BackEdge>& back_edges,
|
||||
ShapeRefiner* shape_refiner) {
|
||||
std::map<const Node*, const Node*> merge_to_next_iteration;
|
||||
for (const auto& e : back_edges) {
|
||||
if (e.src->IsNextIteration() && e.dst->IsMerge()) {
|
||||
merge_to_next_iteration[e.dst] = e.src;
|
||||
}
|
||||
}
|
||||
|
||||
// Visits the nodes in topological order (reverse post-order), inferring
|
||||
// shapes.
|
||||
// TODO(phawkins): handle cyclic graphs.
|
||||
@ -90,6 +98,45 @@ Status PropagateShapes(const Graph& graph,
|
||||
TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
|
||||
}
|
||||
}
|
||||
|
||||
// Merge node causes a loop so we remove NextIteration->Merge edge before
|
||||
// performing shape inference. But removing those edges also prevents us
|
||||
// from inferring output shape for Merge node (we need shapes for all its
|
||||
// inputs).
|
||||
// For loop invariant resource input's Merge node, we set output resource
|
||||
// shape as Enter node's resource shape.
|
||||
// TODO(b/129367850): clean this up.
|
||||
if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
|
||||
// Check if this is a loop invariant input's Merge node. We do it by
|
||||
// checking if corresponding NextIteration node comes from Switch node
|
||||
// directly.
|
||||
auto iter = merge_to_next_iteration.find(n);
|
||||
if (iter != merge_to_next_iteration.end()) {
|
||||
const Node *next_iter = iter->second, *node = next_iter;
|
||||
do {
|
||||
TF_RETURN_IF_ERROR(node->input_node(0, &node));
|
||||
} while (node->IsIdentity());
|
||||
const Node* switch_input;
|
||||
bool is_loop_invariant = node->IsSwitch() &&
|
||||
node->input_node(0, &switch_input).ok() &&
|
||||
switch_input == n;
|
||||
if (is_loop_invariant) {
|
||||
shape_inference::InferenceContext* context =
|
||||
shape_refiner->GetContext(n);
|
||||
for (int i = 0; i < n->num_inputs(); i++) {
|
||||
const Node* input_node;
|
||||
if (n->input_node(i, &input_node).ok()) {
|
||||
auto shapes_and_types = context->input_handle_shapes_and_types(i);
|
||||
if (shapes_and_types) {
|
||||
context->set_output_handle_shapes_and_types(0,
|
||||
*shapes_and_types);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -149,7 +196,8 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
|
||||
// the shape inference is complete.
|
||||
BackEdgeHelper back_edge;
|
||||
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
|
||||
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes, &shape_refiner));
|
||||
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
|
||||
back_edge.RemovedEdges(), &shape_refiner));
|
||||
TF_RETURN_IF_ERROR(back_edge.Replace());
|
||||
|
||||
// Currently information does not flow "backward" from consumers to producers
|
||||
|
@ -21,7 +21,9 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/resource_variable_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
@ -120,5 +122,53 @@ TEST(ShapeInferenceTest, WhileLoop) {
|
||||
TF_EXPECT_OK(ShapeAnnotationsMatch(graph, shape_info, expected));
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, WhileLoopWithResource) {
|
||||
// Graph:
|
||||
// x = resource_variable_ops.var_handle_op(dtype=dtypes.float32, shape=[2, 3])
|
||||
// y = control_flow_ops.while_loop(lambda _: true, lambda x: x, [x])
|
||||
Graph graph(OpRegistry::Global());
|
||||
{
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
auto x =
|
||||
ops::VarHandleOp(scope.WithOpName("x"), DT_FLOAT, TensorShape({2, 3}));
|
||||
auto enter =
|
||||
ops::internal::Enter(scope.WithOpName("while/Enter"), x, "aloop");
|
||||
auto dummy = ops::Placeholder(scope.WithOpName("dummy"), DT_RESOURCE);
|
||||
auto merge = ops::Merge(scope.WithOpName("while/Merge"),
|
||||
std::initializer_list<Input>{enter, dummy});
|
||||
auto false_value = ops::Const<bool>(scope.WithOpName("false"), false);
|
||||
auto loop_cond =
|
||||
ops::LoopCond(scope.WithOpName("while/LoopCond"), false_value);
|
||||
auto switch_node =
|
||||
ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
|
||||
auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
|
||||
switch_node.output_false);
|
||||
auto identity = ops::Identity(scope.WithOpName("while/Identity"),
|
||||
switch_node.output_true);
|
||||
auto next_iteration =
|
||||
ops::NextIteration(scope.WithOpName("while/NextIteration"), identity);
|
||||
auto sink = ops::Identity(scope.WithOpName("sink"), exit);
|
||||
|
||||
// Remove the dummy node and add the loop backedge.
|
||||
scope.graph()->RemoveNode(dummy.node());
|
||||
scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
|
||||
|
||||
TF_EXPECT_OK(scope.ToGraph(&graph));
|
||||
}
|
||||
|
||||
// Check that we can infer shape for "sink" node (Merge node output).
|
||||
GraphShapeInfo shape_info;
|
||||
TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr,
|
||||
&shape_info));
|
||||
auto iter = shape_info.find("sink");
|
||||
EXPECT_NE(iter, shape_info.end());
|
||||
EXPECT_EQ(iter->second.size(), 1);
|
||||
EXPECT_EQ(iter->second.at(0).handle_type, DT_FLOAT);
|
||||
TensorShape resource_shape;
|
||||
EXPECT_TRUE(iter->second.at(0).handle_shape.AsTensorShape(&resource_shape));
|
||||
EXPECT_EQ(resource_shape, TensorShape({2, 3}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user