diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 89cdb308992..d912eb7857b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -2385,8 +2385,9 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   if (node->op() != "Case") return false;
   const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
   if (output_idx_node == nullptr ||
-      !CheckAttrExists(*output_idx_node, "value").ok())
+      !CheckAttrExists(*output_idx_node, "value").ok()) {
     return false;
+  }
   Tensor output_idx_t;
   if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
     return false;
@@ -2401,8 +2402,22 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   }
   auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
   *new_func = func_list.func(output_idx);
-  call_node.mutable_attr()->erase("branches");
+
+  // Move the output shape of the branch to _output_shapes if it is known.
+  const auto& output_shape_list =
+      (*node->mutable_attr())["output_shapes"].list();
+  if (output_shape_list.shape_size() > output_idx) {
+    TensorShapeProto* new_output_shape =
+        (*call_node.mutable_attr())["_output_shapes"]
+            .mutable_list()
+            ->add_shape();
+    *new_output_shape =
+        std::move(node->attr().at("output_shapes").list().shape(output_idx));
+  }
+  call_node.mutable_attr()->erase("output_shapes");
+  call_node.mutable_attr()->erase("branches");
+
   *node = std::move(call_node);
   return true;
 }
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 1d8899de989..87cf18548b6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -4095,54 +4095,83 @@ TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
 
 TEST_F(ConstantFoldingTest, SimplifyCase) {
   using test::function::NDef;
-  // Build a graph to compute y = Case(1, x, XTimesTwo(x), NonZero(x))
-  GrapplerItem item;
-  constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
-  AttrValue branches;
-  auto* f = branches.mutable_list()->add_func();
-  f->set_name("XTimesTwo");
-  (*f->mutable_attr())["T"].set_type(DT_FLOAT);
-  auto* g = branches.mutable_list()->add_func();
-  *g = *f;
-  g->set_name("NonZero");
+  for (int index = 0; index < 2; ++index) {
+    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
+    GrapplerItem item;
+    constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+    AttrValue branches;
+    auto* f = branches.mutable_list()->add_func();
+    f->set_name("XTimesTwo");
+    (*f->mutable_attr())["T"].set_type(DT_FLOAT);
+    auto* g = branches.mutable_list()->add_func();
+    *g = *f;
+    g->set_name("NonZero");
 
-  const Tensor kOne = test::AsScalar<int>(1);
-  item.graph = test::function::GDef(
-      {NDef("one", "Const", {}, {{"value", kOne}, {"dtype", DT_INT32}},
-            kDevice),
-       NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
-       NDef("case", "Case", {"one", "x"},
-            {{"Tin", DataTypeSlice{DT_FLOAT}},
-             {"Tout", DataTypeSlice{DT_FLOAT}},
-             {"branches", branches}},
-            kDevice),
-       NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
-      // FunctionLib
-      {
-          test::function::XTimesTwo(),
-          test::function::NonZero(),
-      });
-  VLOG(1) << "Before: " << item.graph.DebugString();
+    // Add a pair of somewhat arbitrary output shapes to
+    // test that they are correctly propagated to the _output_shapes
+    // attribute.
+    AttrValue output_shapes;
+    // The first shape is a scalar.
+    output_shapes.mutable_list()->add_shape();
+    // The second shape is unknown.
+    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
+    g_shape->set_unknown_rank(true);
 
-  item.fetch = {"y"};
-  const Tensor kTwo = test::AsScalar<float>(2.0f);
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
+    const Tensor kZero = test::AsScalar<int>(0);
+    const Tensor kOne = test::AsScalar<int>(1);
+    item.graph = test::function::GDef(
+        {NDef("one", "Const", {},
+              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
+              kDevice),
+         NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+         NDef("case", "Case", {"one", "x"},
+              {{"Tin", DataTypeSlice{DT_FLOAT}},
+               {"Tout", DataTypeSlice{DT_FLOAT}},
+               {"branches", branches},
+               {"output_shapes", output_shapes}},
+              kDevice),
+         NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
+        // FunctionLib
+        {
+            test::function::XTimesTwo(),
+            test::function::NonZero(),
+        });
+    VLOG(1) << "Before: " << item.graph.DebugString();
 
-  ConstantFolding optimizer(/*cpu_device=*/nullptr);
-  GraphDef optimized_graph;
-  TF_ASSERT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
-  VLOG(1) << "After: " << optimized_graph.DebugString();
+    item.fetch = {"y"};
+    const Tensor kTwo = test::AsScalar<float>(2.0f);
+    auto tensors_expected =
+        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
 
-  int pco_count = 0;
-  for (const auto& node : optimized_graph.node()) {
-    EXPECT_NE(node.op(), "Case");
-    if (node.op() == "PartitionedCall") ++pco_count;
+    ConstantFolding optimizer(/*cpu_device=*/nullptr);
+    GraphDef optimized_graph;
+    TF_ASSERT_OK(
+        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
+    VLOG(1) << "After: " << optimized_graph.DebugString();
+
+    int pco_count = 0;
+    for (const auto& node : optimized_graph.node()) {
+      EXPECT_NE(node.op(), "Case");
+      if (node.op() == "PartitionedCall") {
+        ++pco_count;
+        const auto& shape_list = node.attr().at("_output_shapes").list();
+        ASSERT_EQ(shape_list.shape_size(), 1);
+        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
+        if (index == 0) {
+          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
+        } else {
+          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
+        }
+      }
+    }
+    EXPECT_EQ(pco_count, 1);
+
+    auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
+    ASSERT_EQ(tensors.size(), tensors_expected.size());
+    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
   }
-  EXPECT_EQ(pco_count, 1);
-
-  auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
-  ASSERT_EQ(tensors.size(), tensors_expected.size());
-  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
 }
 
 TEST_F(ConstantFoldingTest, SimplifySelect) {