diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 6c81d4e9f82..89d92173c37 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -161,7 +161,7 @@ Status XlaCompiler::CompileFunction(
   opts.set_do_constant_folding(true);
   GraphOptimizer optimizer(opts);
   optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(),
-                     /*device=*/nullptr, &graph);
+                     /*device=*/nullptr, &graph, /*shape_map=*/nullptr);
 
   VLOG(1) << "====================================================";
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 914683d9fa3..8dfb8b45de2 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -43,11 +43,182 @@ namespace tensorflow {
 
 namespace {
 
-bool IsConstantFoldable(const Node* n,
-                        const std::function<bool(const Node*)>& consider) {
+// Returns true if the Op is one that turns into a constant once its
+// inputs' shapes are known.
+bool IsShapeOp(const Node* n) {
+  const auto& ts = n->type_string();
+  return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size";
+}
+
+// Reads the partially-known shapes of n's inputs from shape_map and stores
+// them in input_shapes. Returns false if any input does not have a shape in
+// shape_map.
+bool ReadPartialShapesFromShapeMap(
+    const Node* n,
+    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+        shape_map,
+    std::vector<PartialTensorShape>* input_shapes) {
+  CHECK(shape_map != nullptr);
+  for (const Edge* in : n->in_edges()) {
+    // Don't need to check if incoming control edges have known shapes.
+    if (in->IsControlEdge()) continue;
+    if (shape_map->count(in->src()) == 0) {
+      // One of n's inputs doesn't have known shapes, so don't replace n.
+      return false;
+    }
+    const auto& known_shape = shape_map->at(in->src());
+    CHECK_GT(known_shape.size(), in->src_output());
+    input_shapes->push_back(known_shape[in->src_output()]);
+  }
+  return true;
+}
+
+// If all of n's inputs have fully-defined shapes, inserts those shapes as a
+// vector of Tensors in the shape_replacement_map.
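+// For example, a Shape node whose input has the fully-defined shape [2, 3]
+// maps to a single int32 Tensor holding {2, 3}; a ShapeN node contributes one
+// such Tensor per input.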
+bool MaybeReplaceShapeOrShapeNOp(
+    const Node* n, const std::vector<PartialTensorShape>& input_shapes,
+    std::unordered_map<const Node*, std::vector<Tensor>>*
+        shape_replacement_map) {
+  std::vector<Tensor> defined_shape;
+  for (const auto& shape : input_shapes) {
+    if (!shape.IsFullyDefined()) {
+      return false;
+    }
+    const int rank = shape.dims();
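+    // Shape and ShapeN carry an out_type attr that may be DT_INT32 or
+    // DT_INT64; the replacement constant must use the same type.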
+    DataType op_type = n->output_type(0);
+    Tensor t(op_type, TensorShape({rank}));
+    if (op_type == DT_INT64) {
+      auto vec = t.vec<int64>();
+      for (int i = 0; i < rank; ++i) {
+        vec(i) = shape.dim_size(i);
+      }
+    } else {
+      CHECK(op_type == DT_INT32);
+      auto vec = t.vec<int32>();
+      for (int i = 0; i < rank; ++i) {
+        if (shape.dim_size(i) > INT_MAX) {
+          VLOG(1) << "Node " << n->name() << " has input shape dimension " << i
+                  << " of " << shape.dim_size(i) << " but type INT32, so not "
+                     "replacing as constant: this will trigger a runtime "
+                     "error later.";
+          return false;
+        }
+        vec(i) = static_cast<int32>(shape.dim_size(i));
+      }
+    }
+    defined_shape.push_back(t);
+  }
+  // All the inputs had known shapes so we can replace the node by constants
+  // later in the rewrite.
+  shape_replacement_map->insert({n, defined_shape});
+  return true;
+}
+
+// If n's input has defined rank, inserts that rank as a Tensor in the
+// shape_replacement_map.
+bool MaybeReplaceRankOp(const Node* n,
+                        const std::vector<PartialTensorShape>& input_shapes,
+                        std::unordered_map<const Node*, std::vector<Tensor>>*
+                            shape_replacement_map) {
+  CHECK_EQ(input_shapes.size(), 1);
+  if (input_shapes[0].unknown_rank()) {
+    return false;
+  }
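+  // Rank always produces an int32 scalar, so no dtype dispatch is needed.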
+  Tensor t(DT_INT32, TensorShape({}));
+  t.scalar<int32>()() = input_shapes[0].dims();
+  shape_replacement_map->insert({n, {t}});
+  return true;
+}
+
+// If n's input has defined size, inserts that size as a Tensor in the
+// shape_replacement_map.
+bool MaybeReplaceSizeOp(const Node* n,
+                        const std::vector<PartialTensorShape>& input_shapes,
+                        std::unordered_map<const Node*, std::vector<Tensor>>*
+                            shape_replacement_map) {
+  CHECK_EQ(input_shapes.size(), 1);
+  if (!input_shapes[0].IsFullyDefined()) {
+    return false;
+  }
+  DataType op_type = n->output_type(0);
+  Tensor t(op_type, TensorShape({}));
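+  // For a fully-defined shape, num_elements() is the product of all dimension
+  // sizes.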
+  int64 size = input_shapes[0].num_elements();
+  if (op_type == DT_INT64) {
+    t.scalar<int64>()() = size;
+  } else {
+    CHECK(op_type == DT_INT32);
+    if (size > INT_MAX) {
+    VLOG(1) << "Node " << n->name() << " has input shape size " << size
+            << " but type INT32, so not replacing as constant: this will "
+               "trigger a runtime error later.";
+      return false;
+    }
+    t.scalar<int32>()() = static_cast<int32>(size);
+  }
+  shape_replacement_map->insert({n, {t}});
+  return true;
+}
+
+// If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their
+// shapes specified in shape_map, then adds to shape_replacement_map a mapping
+// from n to a vector of Tensors, where Tensor k is the (statically known) value
+// on n's kth output edge. shape_replacement_map has an entry for n iff
+// MaybeReplaceShapeOp returns true, so it's valid to use
+// shape_replacement_map->count(n) as a test to see if n is a shape op that can
+// be replaced.
+bool MaybeReplaceShapeOp(
+    const Node* n,
+    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+        shape_map,
+    std::unordered_map<const Node*, std::vector<Tensor>>*
+        shape_replacement_map) {
+  if (shape_map == nullptr || !IsShapeOp(n)) {
+    return false;
+  }
+  // input_shapes will contain the shapes of each of n's inputs.
+  std::vector<PartialTensorShape> input_shapes;
+  if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) {
+    return false;
+  }
+  const auto& ts = n->type_string();
+  if (ts == "Shape" || ts == "ShapeN") {
+    if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) {
+      return false;
+    }
+  } else if (ts == "Rank") {
+    if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) {
+      return false;
+    }
+  } else {
+    CHECK_EQ(ts, "Size");
+    if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Returns true if n can be evaluated as constant. shape_map maps from
+// nodes to the partially-known shapes of their outputs. If consider is
+// non-null, it returns whether a given (non-Const, non-Shape) node is
+// eligible to be constant-propagated. shape_replacement_map is filled in
+// with a vector of constant output tensors for constant-foldable shape
+// nodes (Shape, ShapeN, Size, or Rank).
+bool IsConstantFoldable(
+    const Node* n,
+    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+        shape_map,
+    const std::function<bool(const Node*)>& consider,
+    std::unordered_map<const Node*, std::vector<Tensor>>*
+        shape_replacement_map) {
   if (n->IsConstant()) {
     return true;
   }
+  if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) {
+    return true;
+  }
   if (n->op_def().is_stateful()) {
     return false;
   }
@@ -82,56 +253,81 @@ bool IsConstantFoldable(const Node* n,
   return true;
 }
 
+// If n is eligible for constant-folding, adds it to nodes, and places its
+// control dependencies, and transitively those of its constant-foldable
+// inputs, into constant_control_deps. If n is a constant-foldable shape node
+// (Shape, ShapeN, Rank, or Size), also puts its outputs into
+// shape_replacement_map.
+void ConsiderConstantFoldableNode(
+    Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes,
+    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
+    std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
+    bool* internal_node_inserted) {
+  if (IsConstantFoldable(n, opts.shape_map, opts.consider,
+                         shape_replacement_map)) {
+    // A node is constant provided all of its non-control incoming Tensors come
+    // from constant nodes, or it's a shape Op with statically known inputs,
+    // in which case it is placed in shape_replacement_map.
+    //
+    // We allow control dependencies from non-constant nodes to constant nodes,
+    // but to preserve the graph structure we must transfer the control
+    // dependency onto any constant replacement.
+    bool all_parents_constant = true;
+    for (const Edge* in : n->in_edges()) {
+      // Allows non-constant -> constant control edges.
+      if (!in->IsControlEdge() &&
+          constant_control_deps->count(in->src()) == 0) {
+        all_parents_constant = false;
+        break;
+      }
+    }
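+    // A shape node with statically known inputs is foldable even when its
+    // parents are not constant: its replacement value comes from
+    // shape_replacement_map, not from evaluating its inputs.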
+    if (all_parents_constant || shape_replacement_map->count(n) != 0) {
+      gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
+      for (const Edge* e : n->in_edges()) {
+        if (constant_control_deps->count(e->src()) == 0) {
+          // This branch is taken when the source of the incoming edge is not
+          // constant-foldable: either the edge is a control dependency from a
+          // non-constant node, which we want to add to the dependencies being
+          // accumulated for this node, or it is a data edge from a
+          // non-constant source. The latter can happen when n is a shape node
+          // whose source has a known shape. In that case add a control
+          // dependency from the source node, since there was previously a
+          // data dependency and we want to preserve sequencing constraints.
+          if (!e->src()->IsSource()) {
+            control_deps.insert(e->src());
+          }
+        } else {
+          // If the parent has been accumulating control dependencies, add all
+          // of its transitive control deps.
+          const gtl::FlatSet<Node*>& parent_deps =
+              (*constant_control_deps)[e->src()];
+          control_deps.insert(parent_deps.begin(), parent_deps.end());
+        }
+      }
+      nodes->push_back(n);
+      if (!n->IsConstant()) {
+        *internal_node_inserted = true;
+      }
+    }
+  }
+}
+
 // Returns the constant foldable nodes in `nodes` in topological order.
 // Populates `constant_control_deps` with the non-constant control dependencies
 // of each constant node.
 void FindConstantFoldableNodes(
-    const Graph* graph, ConstantFoldingOptions opts, std::vector<Node*>* nodes,
-    std::unordered_map<const Node*, gtl::FlatSet<Node*>>*
-        constant_control_deps) {
+    const Graph* graph, const ConstantFoldingOptions& opts,
+    std::vector<Node*>* nodes,
+    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
+    std::unordered_map<const Node*, std::vector<Tensor>>*
+        shape_replacement_map) {
   bool internal_node_inserted = false;
-  // Walk the nodes in data flow order
-  ReverseDFS(
-      *graph, nullptr,
-      [nodes, constant_control_deps, &internal_node_inserted, opts](Node* n) {
-        if (IsConstantFoldable(n, opts.consider)) {
-          // A node is constant provided all of its non-control
-          // incoming Tensors come from constant nodes.
-          //
-          // We allow control dependencies from non-constant nodes to constant
-          // nodes, but to preserve the graph structure we must transfer the
-          // control dependency onto any constant replacement.
-          bool all_parents_constant = true;
-          for (const Edge* in : n->in_edges()) {
-            // Allows non-constant -> constant control edges.
-            if (!in->IsControlEdge() &&
-                constant_control_deps->count(in->src()) == 0) {
-              all_parents_constant = false;
-              break;
-            }
-          }
-          if (all_parents_constant) {
-            gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
-            for (const Edge* e : n->in_edges()) {
-              if (constant_control_deps->count(e->src()) == 0) {
-                if (!e->src()->IsSource()) {
-                  control_deps.insert(e->src());
-                }
-              } else {
-                // If the parent is constant, add all of its transitive control
-                // deps.
-                const gtl::FlatSet<Node*>& parent_deps =
-                    (*constant_control_deps)[e->src()];
-                control_deps.insert(parent_deps.begin(), parent_deps.end());
-              }
-            }
-            nodes->push_back(n);
-            if (!n->IsConstant()) {
-              internal_node_inserted = true;
-            }
-          }
-        }
-      });
+  // Walk the nodes in data flow order.
+  ReverseDFS(*graph, nullptr,
+             [nodes, constant_control_deps, shape_replacement_map,
+              &internal_node_inserted, &opts](Node* n) {
+               ConsiderConstantFoldableNode(
+                   n, opts, nodes, constant_control_deps, shape_replacement_map,
+                   &internal_node_inserted);
+             });
   // If we have inserted just leaf level nodes, then there is nothing to fold.
   if (!internal_node_inserted) {
     nodes->clear();
@@ -141,31 +337,93 @@ void FindConstantFoldableNodes(
 
 typedef std::pair<Node*, int> NodeAndOutput;
 
+int64 UniqueConstantId() {
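+  // The zero-initialized counter is shared process-wide, so each call returns
+  // a distinct id.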
+  static std::atomic_int_fast64_t id;
+  return id.fetch_add(1);
+}
+
+// Adds n to constant_graph which is being built up for subsequent evaluation of
+// constant propagation. node_map is the mapping of nodes in the original graph
+// to nodes in the constant graph. The value of an entry in node_map is a vector
+// of nodes because a ShapeN node in the original graph is replaced by a vector
+// of Constant nodes in the constant graph.
+void AddNodeToConstantGraph(
+    Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map,
+    Graph* constant_graph) {
+  std::vector<Node*>& added = (*node_map)[n];
+  added.push_back(constant_graph->CopyNode(n));
+  for (const Edge* in_edge : n->in_edges()) {
+    // Don't copy control edges to the constant graph.
+    if (!in_edge->IsControlEdge()) {
+      Node* in = in_edge->src();
+      auto it = node_map->find(in);
+      CHECK(it != node_map->end())
+          << n->DebugString() << " <- " << in->DebugString();
+      if (it->second.size() == 1) {
+        constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0],
+                                in_edge->dst_input());
+      } else {
+        // The original source node had multiple outputs and was replaced by a
+        // vector of constants, so the edge comes from the 0th output of the kth
+        // added constant, rather than the kth output of the added node as in
+        // the standard case above.
+        constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0],
+                                in_edge->dst_input());
+      }
+    }
+  }
+}
+
+// Replaces constant-foldable shape node n by a vector of constants in
+// constant_graph, which is being built up for subsequent evaluation of constant
+// propagation. node_map is the mapping of nodes in the original graph to nodes
+// in the constant graph. The value of an entry in node_map is a vector of nodes
+// because a ShapeN node in the original graph is replaced by a vector of
+// Constant nodes in the constant graph.
+void AddShapeNodeToConstantGraph(
+    Node* n,
+    const std::unordered_map<const Node*, std::vector<Tensor>>&
+        shape_replacement_map,
+    std::unordered_map<Node*, std::vector<Node*>>* node_map,
+    Graph* constant_graph) {
+  std::vector<Node*>& added = (*node_map)[n];
+  const string& node_name = n->name();
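+  // Each replacement constant is named after the original node plus a unique
+  // "__cf__<id>" suffix.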
+  for (const Tensor& t : shape_replacement_map.at(n)) {
+    auto builder =
+        NodeDefBuilder(strings::StrCat(constant_graph->NewName(node_name),
+                                       "__cf__", UniqueConstantId()),
+                       "Const")
+            .Attr("dtype", t.dtype())
+            .Attr("value", t);
+    NodeDef def;
+    CHECK(builder.Finalize(&def).ok());
+    Node* constant_node;
+    CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok());
+    added.push_back(constant_node);
+  }
+  // Don't copy incoming edges to shape nodes that are being replaced.
+}
+
 // Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
 // will contain copies of the nodes in 'nodes'. In addition, if there is an edge
 // going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
 // 'nodes', then 'tensors_to_fetch' will contain the mapping from the
 // corresponding copy of 'n' and the edge number in 'g' to 'n'.
-Graph* GetConstantGraph(const Graph* orig_graph,
-                        const std::vector<Node*>& nodes,
-                        std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
+Graph* GetConstantGraph(
+    const Graph* orig_graph, const std::vector<Node*>& nodes,
+    const std::unordered_map<const Node*, std::vector<Tensor>>&
+        shape_replacement_map,
+    std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
   Graph* constant_graph = new Graph(orig_graph->op_registry());
-  std::unordered_map<Node*, Node*> node_map;
-  node_map[orig_graph->source_node()] = constant_graph->source_node();
-  node_map[orig_graph->sink_node()] = constant_graph->sink_node();
+  std::unordered_map<Node*, std::vector<Node*>> node_map;
+  node_map[orig_graph->source_node()] = {constant_graph->source_node()};
+  node_map[orig_graph->sink_node()] = {constant_graph->sink_node()};
   for (Node* n : nodes) {
-    Node* added = constant_graph->CopyNode(n);
-    node_map[n] = added;
-    for (const Edge* in_edge : n->in_edges()) {
-      // Don't copy control edges to the constant graph.
-      if (!in_edge->IsControlEdge()) {
-        Node* in = in_edge->src();
-        auto it = node_map.find(in);
-        CHECK(it != node_map.end())
-            << n->DebugString() << " <-" << in->DebugString();
-        constant_graph->AddEdge(it->second, in_edge->src_output(), added,
-                                in_edge->dst_input());
-      }
+    if (shape_replacement_map.count(n) == 0) {
+      AddNodeToConstantGraph(n, &node_map, constant_graph);
+    } else {
+      AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
+                                  constant_graph);
     }
   }
 
@@ -173,8 +431,19 @@ Graph* GetConstantGraph(const Graph* orig_graph,
     for (const Edge* out_edge : added_nodes.first->out_edges()) {
       if (node_map.count(out_edge->dst()) == 0) {
         if (out_edge->IsControlEdge()) continue;
-        tensors_to_fetch->insert(
-            {{added_nodes.second, out_edge->src_output()}, added_nodes.first});
+        if (added_nodes.second.size() == 1) {
+          tensors_to_fetch->insert(
+              {{added_nodes.second[0], out_edge->src_output()},
+               added_nodes.first});
+        } else {
+          // The node had multiple outputs and was replaced by a
+          // vector of constants, so the NodeAndOutput is the 0th
+          // output of the kth added constant, rather than the kth
+          // output of the added node as in the standard case above.
+          tensors_to_fetch->insert(
+              {{added_nodes.second[out_edge->src_output()], 0},
+               added_nodes.first});
+        }
       }
     }
   }
@@ -182,11 +451,6 @@ Graph* GetConstantGraph(const Graph* orig_graph,
   return constant_graph;
 }
 
-int64 UniqueConstantId() {
-  static std::atomic_int_fast64_t id;
-  return id.fetch_add(1);
-}
-
 // Replaces the identified Tensor in 'graph' by a 'Const' node with
 // the value supplied in 'constant'. 'partition_device', if non-null
 // is the device where the graph executes. Returns true if the
@@ -291,8 +555,9 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
 
   std::vector<Node*> constant_foldable_nodes;
   std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
+  std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
   FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
-                            &constant_control_deps);
+                            &constant_control_deps, &shape_replacement_map);
   if (constant_foldable_nodes.empty()) {
     VLOG(1) << "No constant foldable nodes found";
     *was_mutated = false;
@@ -302,7 +567,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
 
   std::map<NodeAndOutput, Node*> tensors_to_fetch;
   std::unique_ptr<Graph> constant_graph(
-      GetConstantGraph(graph, constant_foldable_nodes, &tensors_to_fetch));
+      GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
+                       &tensors_to_fetch));
   DumpGraph("Constant graph", constant_graph.get());
 
   if (tensors_to_fetch.empty()) {
@@ -337,7 +603,6 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
   if (!s.ok()) {
     VLOG(1) << "Could not fetch constants: " << s;
     *was_mutated = false;
-    // This is not an error, so return the status as OK.
     return s;
   }
 
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index 93289b875f5..07cb5def8b4 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -29,6 +29,11 @@ struct ConstantFoldingOptions {
   // If "consider" is not a nullptr, then only constant fold a node "n" if
   // consider(n) returns true.
   std::function<bool(const Node*)> consider = nullptr;
+  // If shape_map is not a nullptr, it is a map from node n to a
+  // vector of the (potentially partially-known) shapes of its
+  // outputs.
+  const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+      shape_map = nullptr;  // not owned
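+  //
+  // A typical use, mirroring constant_folding_test.cc: populate the map from
+  // inferred shapes and keep it alive for the duration of the call:
+  //   std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
+  //   map[node].push_back(shape_of_output_0);  // one entry per output
+  //   ConstantFoldingOptions opts;
+  //   opts.shape_map = &map;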
 };
 
 // Perform constant folding optimization on "graph".
diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc
index 4a8560960ed..c76ad647a0a 100644
--- a/tensorflow/core/common_runtime/constant_folding_test.cc
+++ b/tensorflow/core/common_runtime/constant_folding_test.cc
@@ -363,6 +363,197 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
   }
 }
 
+TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
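+  // Shape, ShapeN, Rank, and Size of inputs with fully-defined shapes in
+  // shape_map should fold to constants (and arithmetic on them folds too),
+  // with control dependencies from the original inputs preserved on the
+  // replacements.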
+  Graph g(OpRegistry::Global());
+  {
+    Scope s = Scope::NewRootScope();
+    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
+                              "sender", 0, "receiver");
+    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
+    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
+                              "sender", 0, "receiver");
+    auto shape_n = ops::ShapeN(s.WithOpName("shape_n"), {recv0, recv1});
+    auto rank = ops::Rank(s.WithOpName("rank"), recv0);
+    auto size = ops::Size(s.WithOpName("size"), recv1);
+    auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
+                            0, "receiver");
+    auto c = ops::Const<int>(s.WithControlDependencies(recv2), 3);
+    auto add0 = ops::Add(s.WithControlDependencies(c), rank, size);
+    auto add1 = ops::Add(s, shape, shape_n[0]);
+    auto add2 = ops::Add(s, shape_n[1], shape_n[1]);
+    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
+                            "receiver");
+    auto send1 = ops::_Send(s.WithOpName("send1"), add1, "send1", "sender", 0,
+                            "receiver");
+    auto send2 = ops::_Send(s.WithOpName("send2"), add2, "send2", "sender", 0,
+                            "receiver");
+    TF_ASSERT_OK(s.ToGraph(&g));
+  }
+  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
+  Node* recv0 = orig_index.at("recv0");
+  Node* recv1 = orig_index.at("recv1");
+  PartialTensorShape ps0;
+  int r0_dims[] = {1, 2};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
+  PartialTensorShape ps1;
+  int r1_dims[] = {2, 3, 4};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape<int>(r1_dims, 3, &ps1));
+  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
+  map[recv0].push_back(ps0);
+  map[recv1].push_back(ps1);
+  ConstantFoldingOptions opts;
+  opts.shape_map = &map;
+  bool was_mutated;
+  TF_EXPECT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+  EXPECT_TRUE(was_mutated);
+
+  std::unordered_map<string, Node*> index = NodeNameIndex(g);
+  Node* recv2 = index.at("recv2");
+  Node* send0 = index.at("send0");
+  Node* send1 = index.at("send1");
+  Node* send2 = index.at("send2");
+
+  ASSERT_EQ(1, send0->num_inputs());
+  Node* cf0 = *(send0->in_nodes().begin());
+  ExpectNodeEqual<int>(cf0, {26}, {});
+
+  ASSERT_EQ(1, send1->num_inputs());
+  Node* cf1 = *(send1->in_nodes().begin());
+  ExpectNodeEqual<int>(cf1, {2, 4}, {2});
+
+  ASSERT_EQ(1, send2->num_inputs());
+  Node* cf2 = *(send2->in_nodes().begin());
+  ExpectNodeEqual<int>(cf2, {4, 6, 8}, {3});
+
+  ASSERT_EQ(3, cf0->in_edges().size());
+  for (const Edge* e : cf0->in_edges()) {
+    EXPECT_TRUE(e->IsControlEdge());
+    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1 || e->src() == recv2)
+        << e->src()->name();
+  }
+
+  ASSERT_EQ(2, cf1->in_edges().size());
+  for (const Edge* e : cf1->in_edges()) {
+    EXPECT_TRUE(e->IsControlEdge());
+    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
+  }
+
+  ASSERT_EQ(2, cf2->in_edges().size());
+  for (const Edge* e : cf2->in_edges()) {
+    EXPECT_TRUE(e->IsControlEdge());
+    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
+  }
+}
+
+TEST_F(ConstantFoldingTest, PartialShape) {
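+  // Shape ops should fold only when the needed information is statically
+  // known: Rank folds for a known-rank input with unknown dims, while Shape
+  // and Size of that input, and Rank of an unknown-rank input, do not fold.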
+  Graph g(OpRegistry::Global());
+  {
+    Scope s = Scope::NewRootScope();
+    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
+                              "sender", 0, "receiver");
+    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
+                              "sender", 0, "receiver");
+    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
+    auto rank0 = ops::Rank(s.WithOpName("rank0"), recv0);
+    auto rank1 = ops::Rank(s.WithOpName("rank1"), recv1);
+    auto size = ops::Size(s.WithOpName("size"), recv0);
+    auto send0 = ops::_Send(s.WithOpName("send0"), rank0, "send0", "sender", 0,
+                            "receiver");
+    auto send1 = ops::_Send(s.WithOpName("send1"), shape, "send1", "sender", 0,
+                            "receiver");
+    auto send2 = ops::_Send(s.WithOpName("send2"), size, "send2", "sender", 0,
+                            "receiver");
+    auto send3 = ops::_Send(s.WithOpName("send3"), rank1, "send3", "sender", 0,
+                            "receiver");
+    TF_ASSERT_OK(s.ToGraph(&g));
+  }
+  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
+  Node* recv0 = orig_index.at("recv0");
+  Node* recv1 = orig_index.at("recv1");
+  PartialTensorShape ps0;
+  int r0_dims[] = {-1, -1};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
+  PartialTensorShape ps1;
+  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
+  map[recv0].push_back(ps0);
+  map[recv1].push_back(ps1);
+  ConstantFoldingOptions opts;
+  opts.shape_map = &map;
+  bool was_mutated;
+  TF_EXPECT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+  EXPECT_TRUE(was_mutated);
+
+  std::unordered_map<string, Node*> index = NodeNameIndex(g);
+  Node* shape = index.at("shape");
+  Node* size = index.at("size");
+  Node* rank1 = index.at("rank1");
+  Node* send0 = index.at("send0");
+  Node* send1 = index.at("send1");
+  Node* send2 = index.at("send2");
+  Node* send3 = index.at("send3");
+
+  ASSERT_EQ(1, send0->num_inputs());
+  Node* cf0 = *(send0->in_nodes().begin());
+  ExpectNodeEqual<int>(cf0, {2}, {});
+
+  ASSERT_EQ(1, send1->num_inputs());
+  Node* ncf1 = *(send1->in_nodes().begin());
+  EXPECT_EQ(ncf1, shape);
+
+  ASSERT_EQ(1, send2->num_inputs());
+  Node* ncf2 = *(send2->in_nodes().begin());
+  EXPECT_EQ(ncf2, size);
+
+  ASSERT_EQ(1, send3->num_inputs());
+  Node* ncf3 = *(send3->in_nodes().begin());
+  EXPECT_EQ(ncf3, rank1);
+}
+
+TEST_F(ConstantFoldingTest, ConstShapeKnown) {
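+  // A control dependency into the folded subgraph (recv0 -> c0) should be
+  // transferred onto the replacement constant.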
+  Graph g(OpRegistry::Global());
+  {
+    Scope s = Scope::NewRootScope();
+    auto recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0", "sender",
+                            0, "receiver");
+    auto c0 =
+        ops::Const<int>(s.WithOpName("c0").WithControlDependencies(recv0), 1);
+    auto rank = ops::Rank(s.WithOpName("rank"), c0);
+    auto add0 = ops::Add(s, rank, rank);
+    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
+                            "receiver");
+    TF_ASSERT_OK(s.ToGraph(&g));
+  }
+  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
+  Node* c0 = orig_index.at("c0");
+  PartialTensorShape ps0;
+  int c0_dims[] = {};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(c0_dims, 0, &ps0));
+  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
+  map[c0].push_back(ps0);
+  ConstantFoldingOptions opts;
+  opts.shape_map = &map;
+  bool was_mutated;
+  TF_EXPECT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+  EXPECT_TRUE(was_mutated);
+
+  std::unordered_map<string, Node*> index = NodeNameIndex(g);
+  Node* recv0 = index.at("recv0");
+  Node* send0 = index.at("send0");
+
+  ASSERT_EQ(1, send0->num_inputs());
+  Node* cf0 = *(send0->in_nodes().begin());
+  ExpectNodeEqual<int>(cf0, {0}, {});
+
+  ASSERT_EQ(1, cf0->in_edges().size());
+  for (const Edge* e : cf0->in_edges()) {
+    EXPECT_TRUE(e->IsControlEdge());
+    EXPECT_TRUE(e->src() == recv0) << e->src()->name();
+  }
+}
+
 namespace {
 
 const char kTestMemRegionName[] = "test://test";
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 294ad13a0a4..aaba2aa7875 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1194,7 +1194,8 @@ Status DirectSession::GetOrCreateExecutors(
     };
     params.node_outputs_cb = node_outputs_callback_;
 
-    optimizer.Optimize(lib, options_.env, device, &iter->second);
+    optimizer.Optimize(lib, options_.env, device, &iter->second,
+                       /*shape_map=*/nullptr);
 
     // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
     if (!options.debug_options.debug_tensor_watch_opts().empty()) {
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 64c3747ce1c..b7fad68de76 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -461,7 +461,7 @@ void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
   opts.set_do_function_inlining(true);
   opts.set_do_constant_folding(true);
   GraphOptimizer optimizer(opts);
-  optimizer.Optimize(lib, lib->env(), lib->device(), g);
+  optimizer.Optimize(lib, lib->env(), lib->device(), g, /*shape_map=*/nullptr);
 }
 
 Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
@@ -470,7 +470,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
   std::unique_ptr<Graph> g(new Graph(lib_def_));
   CopyGraph(*fbody->graph, g.get());
 
-  optimizer_.Optimize(this, env(), device(), &g);
+  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
   TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                        device()->name(), g.get()));
 
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index edfecfae06e..c32d9f8a2e9 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -33,8 +33,11 @@ GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
 
 GraphOptimizer::~GraphOptimizer() {}
 
-void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
-                              Device* device, std::unique_ptr<Graph>* graph) {
+void GraphOptimizer::Optimize(
+    FunctionLibraryRuntime* runtime, Env* env, Device* device,
+    std::unique_ptr<Graph>* graph,
+    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+        shape_map) {
   Graph* g = graph->get();
   DumpGraph("Initial", g);
 
@@ -57,6 +60,7 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
 
     if (opts_.do_constant_folding()) {
       ConstantFoldingOptions cf_opts;
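+      // Forward the caller-supplied shape map (possibly null) so constant
+      // folding can fold shape ops whose input shapes are known.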
+      cf_opts.shape_map = shape_map;
       bool was_mutated;
       ConstantFold(cf_opts, runtime, env, device, g, &was_mutated)
           .IgnoreError();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index a6b10356ce1..c145adde829 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -30,12 +30,17 @@ class GraphOptimizer {
   ~GraphOptimizer();
 
   // Applies optimization passes specified in 'opts' to 'graph'.
-  // Maybe replace *graph with a new graph object.
-  // 'device' is device on which the 'graph' will execute. It's passed to the
-  // optimizers so that they can respect constraints if any, that should be
-  // respected.
-  void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device,
-                std::unique_ptr<Graph>* graph);
+  // May replace *graph with a new graph object. 'device' is the device
+  // on which 'graph' will execute. It's passed to the optimizers so
+  // that they can respect any device-specific constraints. If shape_map
+  // is not null, it maps from nodes in graph to the partially-known
+  // shapes of their outputs, and may be used, e.g., in the constant
+  // folding pass.
+  void Optimize(
+      FunctionLibraryRuntime* runtime, Env* env, Device* device,
+      std::unique_ptr<Graph>* graph,
+      const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+          shape_map);
 
  private:
   OptimizerOptions opts_;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 8d4b1be0261..ce186700b33 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -244,7 +244,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
       }
     };
 
-    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph);
+    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
+                       /*shape_map=*/nullptr);
 
     // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph.
     if (!debug_options.debug_tensor_watch_opts().empty()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 0c2801e8bc3..ed8bbe58c46 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -127,7 +127,8 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
 
   // Optimize the graph.
   GraphOptimizer optimizer(*optimizer_opts);
-  optimizer.Optimize(flib.get(), env, devices[0], &graphptr);
+  optimizer.Optimize(flib.get(), env, devices[0], &graphptr,
+                     /*shape_map=*/nullptr);
   graphptr->ToGraphDef(output_graph_def);
 
   return Status::OK();