From 603e328a1c8aab77dcf70b58ddf0099a1e127750 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 11 Jun 2020 16:01:36 -0700
Subject: [PATCH] When constant folding Case nodes, propagate known output
 shapes to the "_output_shapes" attribute, to aid in shape inference.

PiperOrigin-RevId: 315996338
Change-Id: I2da7e935dacea11511229588bfd143ad46c1e5ac
---
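Note for reviewers (illustration only; `git am` ignores this region): when the
branch index of a Case node is a known constant, SimplifyCase rewrites the node
into a PartitionedCall on the selected branch. A rough before/after sketch of
that rewrite as extended by this change, using the node and function names from
the updated test; the Tin/Tout and device attributes are elided:

    Before (Const "one" holds branch index 0):

      node {
        name: "case"  op: "Case"
        input: "one"  input: "x"
        attr { key: "branches"
               value { list { func { name: "XTimesTwo" }
                              func { name: "NonZero" } } } }
        attr { key: "output_shapes"
               value { list { shape { }                      # scalar
                              shape { unknown_rank: true } } } }
      }

    After (this patch additionally fills in _output_shapes):

      node {
        name: "case"  op: "PartitionedCall"
        input: "x"
        attr { key: "f" value { func { name: "XTimesTwo" } } }
        attr { key: "_output_shapes" value { list { shape { } } } }
      }
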
 .../grappler/optimizers/constant_folding.cc   |  19 ++-
 .../optimizers/constant_folding_test.cc       | 115 +++++++++++-------
 2 files changed, 89 insertions(+), 45 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 89cdb308992..d912eb7857b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -2385,8 +2385,9 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   if (node->op() != "Case") return false;
   const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
   if (output_idx_node == nullptr ||
-      !CheckAttrExists(*output_idx_node, "value").ok())
+      !CheckAttrExists(*output_idx_node, "value").ok()) {
     return false;
+  }
   Tensor output_idx_t;
   if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
     return false;
@@ -2401,8 +2402,22 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   }
   auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
   *new_func = func_list.func(output_idx);
-  call_node.mutable_attr()->erase("branches");
+
+  // Copy the output shape of the chosen branch to _output_shapes if known.
+  const auto& output_shape_list =
+      (*node->mutable_attr())["output_shapes"].list();
+  if (output_shape_list.shape_size() > output_idx) {
+    TensorShapeProto* new_output_shape =
+        (*call_node.mutable_attr())["_output_shapes"]
+            .mutable_list()
+            ->add_shape();
+    *new_output_shape =
+        output_shape_list.shape(output_idx);
+  }
+
   call_node.mutable_attr()->erase("output_shapes");
+  call_node.mutable_attr()->erase("branches");
+
   *node = std::move(call_node);
   return true;
 }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 1d8899de989..87cf18548b6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -4095,54 +4095,83 @@ TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
 TEST_F(ConstantFoldingTest, SimplifyCase) {
   using test::function::NDef;
 
-  // Build a graph to compute y = Case(1, x, XTimesTwo(x), NonZero(x))
-  GrapplerItem item;
-  constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
-  AttrValue branches;
-  auto* f = branches.mutable_list()->add_func();
-  f->set_name("XTimesTwo");
-  (*f->mutable_attr())["T"].set_type(DT_FLOAT);
-  auto* g = branches.mutable_list()->add_func();
-  *g = *f;
-  g->set_name("NonZero");
+  for (int index = 0; index < 2; ++index) {
+    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
+    GrapplerItem item;
+    constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+    AttrValue branches;
+    auto* f = branches.mutable_list()->add_func();
+    f->set_name("XTimesTwo");
+    (*f->mutable_attr())["T"].set_type(DT_FLOAT);
+    auto* g = branches.mutable_list()->add_func();
+    *g = *f;
+    g->set_name("NonZero");
 
-  const Tensor kOne = test::AsScalar<int32>(1);
-  item.graph = test::function::GDef(
-      {NDef("one", "Const", {}, {{"value", kOne}, {"dtype", DT_INT32}},
-            kDevice),
-       NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
-       NDef("case", "Case", {"one", "x"},
-            {{"Tin", DataTypeSlice{DT_FLOAT}},
-             {"Tout", DataTypeSlice{DT_FLOAT}},
-             {"branches", branches}},
-            kDevice),
-       NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
-      // FunctionLib
-      {
-          test::function::XTimesTwo(),
-          test::function::NonZero(),
-      });
-  VLOG(1) << "Before: " << item.graph.DebugString();
+    // Add a pair of somewhat arbitrary output shapes to
+    // test that they are correctly propagated to the _output_shapes
+    // attribute.
+    AttrValue output_shapes;
+    // The first shape is a scalar.
+    output_shapes.mutable_list()->add_shape();
+    // The second shape is unknown.
+    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
+    g_shape->set_unknown_rank(true);
 
-  item.fetch = {"y"};
-  const Tensor kTwo = test::AsScalar<float>(2.0f);
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
+    const Tensor kZero = test::AsScalar<int32>(0);
+    const Tensor kOne = test::AsScalar<int32>(1);
+    item.graph = test::function::GDef(
+        {NDef("one", "Const", {},
+              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
+              kDevice),
+         NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+         NDef("case", "Case", {"one", "x"},
+              {{"Tin", DataTypeSlice{DT_FLOAT}},
+               {"Tout", DataTypeSlice{DT_FLOAT}},
+               {"branches", branches},
+               {"output_shapes", output_shapes}},
+              kDevice),
+         NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
+        // FunctionLib
+        {
+            test::function::XTimesTwo(),
+            test::function::NonZero(),
+        });
+    VLOG(1) << "Before: " << item.graph.DebugString();
 
-  ConstantFolding optimizer(/*cpu_device=*/nullptr);
-  GraphDef optimized_graph;
-  TF_ASSERT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
-  VLOG(1) << "After: " << optimized_graph.DebugString();
+    item.fetch = {"y"};
+    const Tensor kTwo = test::AsScalar<float>(2.0f);
+    auto tensors_expected =
+        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
 
-  int pco_count = 0;
-  for (const auto& node : optimized_graph.node()) {
-    EXPECT_NE(node.op(), "Case");
-    if (node.op() == "PartitionedCall") ++pco_count;
+    ConstantFolding optimizer(/*cpu_device=*/nullptr);
+    GraphDef optimized_graph;
+    TF_ASSERT_OK(
+        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
+    VLOG(1) << "After: " << optimized_graph.DebugString();
+
+    int pco_count = 0;
+    for (const auto& node : optimized_graph.node()) {
+      EXPECT_NE(node.op(), "Case");
+      if (node.op() == "PartitionedCall") {
+        ++pco_count;
+        const auto& shape_list = node.attr().at("_output_shapes").list();
+        ASSERT_EQ(shape_list.shape_size(), 1);
+        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
+        if (index == 0) {
+          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
+        } else {
+          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
+        }
+      }
+    }
+    EXPECT_EQ(pco_count, 1);
+
+    auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
+    ASSERT_EQ(tensors.size(), tensors_expected.size());
+    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
   }
-  EXPECT_EQ(pco_count, 1);
-
-  auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
-  ASSERT_EQ(tensors.size(), tensors_expected.size());
-  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
 }
 
 TEST_F(ConstantFoldingTest, SimplifySelect) {