diff --git a/tensorflow/core/graph/shape_refiner.cc b/tensorflow/core/graph/shape_refiner.cc
index e45e4e0d633..66a5202a147 100644
--- a/tensorflow/core/graph/shape_refiner.cc
+++ b/tensorflow/core/graph/shape_refiner.cc
@@ -113,6 +113,36 @@ Status ShapeRefiner::AddNode(const Node* node) {
   return Status::OK();
 }
 
+Status ShapeRefiner::SetShape(const Node* node, int output_port,
+                              shape_inference::ShapeHandle shape) {
+  auto c = GetContext(node);
+  if (c == nullptr) {
+    return errors::Internal("Could not find context for ", node->name());
+  }
+
+  if (output_port < 0 || output_port >= node->num_outputs()) {
+    return errors::InvalidArgument(
+        "output_port '", output_port, "' is out of range, ", "node '",
+        node->name(), "' has ", node->num_outputs(), " outputs");
+  }
+
+  // Check that 'shape' is compatible with the existing shape of the
+  // output; Merge returns an error if the two cannot be reconciled.
+  shape_inference::ShapeHandle existing_shape = c->output(output_port);
+  shape_inference::ShapeHandle unused;
+  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &unused));
+
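+  // Record the requested shape as-is; the merged shape above is only
+  // computed to verify compatibility and is otherwise discarded.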
+  c->set_output(output_port, shape);
+
+  // TODO(vrv): Do we need to propagate the new shape through all
+  // consumers so that they can update their outputs?  At the moment,
+  // the Python shape inference does not do this, but it seems like a
+  // nice feature.
+
+  // TODO(vrv): We might need to keep track of the fact that the
+  // existing shape has been invalidated, in case we need to propagate
+  // this information to remote workers.
+  return Status::OK();
+}
+
 Status ShapeRefiner::ConstantValue(const Node* node, Tensor* tensor_storage,
                                    const Tensor** input_tensor) const {
   *input_tensor = nullptr;
diff --git a/tensorflow/core/graph/shape_refiner.h b/tensorflow/core/graph/shape_refiner.h
index 21551903c03..63838e1cfdd 100644
--- a/tensorflow/core/graph/shape_refiner.h
+++ b/tensorflow/core/graph/shape_refiner.h
@@ -46,6 +46,14 @@ class ShapeRefiner {
   //  - The shape inference function returns an error.
   Status AddNode(const Node* node);
 
+  // Sets output 'output_port' of node 'node' to have shape 'shape'.
+  //
+  // Returns an error if 'node' was not previously added to this
+  // object, if 'output_port' is invalid, or if 'shape' is
+  // not compatible with the existing shape of the output.
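+  //
+  // Illustrative usage (a sketch; 'refiner' and 'n' are placeholder
+  // names, not part of this change):
+  //   shape_inference::InferenceContext* c = refiner.GetContext(n);
+  //   shape_inference::ShapeHandle s = c->MakeShape({2, c->UnknownDim()});
+  //   TF_CHECK_OK(refiner.SetShape(n, 0, s));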
+  Status SetShape(const Node* node, int output_port,
+                  shape_inference::ShapeHandle shape);
+
   // Returns the InferenceContext for 'node', if present.
   shape_inference::InferenceContext* GetContext(const Node* node) const {
     auto it = node_to_context_.find(node);
diff --git a/tensorflow/core/graph/shape_refiner_test.cc b/tensorflow/core/graph/shape_refiner_test.cc
index 94cd6dc74a9..ac4cf94546c 100644
--- a/tensorflow/core/graph/shape_refiner_test.cc
+++ b/tensorflow/core/graph/shape_refiner_test.cc
@@ -92,6 +92,33 @@ TEST(ShapeRefinerTest, BadShapes) {
   ASSERT_EQ("Dimensions must be equal, but are 1 and 2", s.error_message());
 }
 
+TEST(ShapeRefinerTest, SetShape) {
+  ShapeRefiner m;
+
+  Scope root = Scope::NewRootScope();
+  auto a = ops::Const(root, {{1.0f}, {2.0f}});
+
+  TF_ASSERT_OK(m.AddNode(a.node()));
+
+  auto ic = m.GetContext(a.node());
+  ASSERT_NE(nullptr, ic);
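+  // 'a' has inferred shape [2,1]; [2,?] is compatible with it, so the
+  // refiner should accept the new shape.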
+  shape_inference::ShapeHandle h = ic->MakeShape({2, ic->UnknownDim()});
+  TF_ASSERT_OK(m.SetShape(a.node(), 0, h));
+  EXPECT_SHAPE("[2,?]", m, a, 0);
+
+  // Out of range.
+  ASSERT_FALSE(m.SetShape(a.node(), 1, h).ok());
+  ASSERT_FALSE(m.SetShape(a.node(), -1, h).ok());
+
+  auto b = ops::Const(root, {{1.0f}, {2.0f}});
+  // 'b' was never added to the refiner, so SetShape should fail.
+  ASSERT_FALSE(m.SetShape(b.node(), 0, h).ok());
+
+  // Setting an incompatible shape (first dimension 3 vs. 2) should fail.
+  h = ic->MakeShape({3, ic->UnknownDim()});
+  ASSERT_FALSE(m.SetShape(a.node(), 0, h).ok());
+}
+
 TEST(ShapeRefinerTest, PropagateConstants) {
   // Reduction dimension is a variable, so we don't know its value.
   // So the output shape value is unknown (though its rank is known).